type hints on main,plugin,updater,utilites.localsocket

This commit is contained in:
marios8543
2023-09-18 00:31:54 +03:00
parent e2d708a6af
commit a7c358844c
5 changed files with 92 additions and 63 deletions
+22 -15
View File
@@ -1,10 +1,13 @@
import asyncio, time, random
import asyncio, time
from typing import Awaitable, Callable
import random
from localplatform import ON_WINDOWS
BUFFER_LIMIT = 2 ** 20 # 1 MiB
class UnixSocket:
def __init__(self, on_new_message):
def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]):
'''
on_new_message takes 1 string argument.
It's return value gets used, if not None, to write data to the socket.
@@ -46,28 +49,32 @@ class UnixSocket:
self.reader = None
async def read_single_line(self) -> str|None:
reader, writer = await self.get_socket_connection()
reader, _ = await self.get_socket_connection()
if self.reader == None:
return None
try:
assert reader
except AssertionError:
return
return await self._read_single_line(reader)
async def write_single_line(self, message : str):
reader, writer = await self.get_socket_connection()
_, writer = await self.get_socket_connection()
if self.writer == None:
return;
try:
assert writer
except AssertionError:
return
await self._write_single_line(writer, message)
async def _read_single_line(self, reader) -> str:
async def _read_single_line(self, reader: asyncio.StreamReader) -> str:
line = bytearray()
while True:
try:
line.extend(await reader.readuntil())
except asyncio.LimitOverrunError:
line.extend(await reader.read(reader._limit))
line.extend(await reader.read(reader._limit)) # type: ignore
continue
except asyncio.IncompleteReadError as err:
line.extend(err.partial)
@@ -77,27 +84,27 @@ class UnixSocket:
return line.decode("utf-8")
async def _write_single_line(self, writer, message : str):
async def _write_single_line(self, writer: asyncio.StreamWriter, message : str):
if not message.endswith("\n"):
message += "\n"
writer.write(message.encode("utf-8"))
await writer.drain()
async def _listen_for_method_call(self, reader, writer):
async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
while True:
line = await self._read_single_line(reader)
try:
res = await self.on_new_message(line)
except Exception as e:
except Exception:
return
if res != None:
await self._write_single_line(writer, res)
class PortSocket (UnixSocket):
def __init__(self, on_new_message):
def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]):
'''
on_new_message takes 1 string argument.
It's return value gets used, if not None, to write data to the socket.
@@ -125,7 +132,7 @@ class PortSocket (UnixSocket):
return True
if ON_WINDOWS:
class LocalSocket (PortSocket):
class LocalSocket (PortSocket): # type: ignore
pass
else:
class LocalSocket (UnixSocket):
+8 -7
View File
@@ -1,5 +1,6 @@
# Change PyInstaller files permissions
import sys
from typing import Dict
from localplatform import (chmod, chown, service_stop, service_start,
ON_WINDOWS, get_log_level, get_live_reload,
get_server_port, get_server_host, get_chown_plugin_path,
@@ -16,7 +17,7 @@ import multiprocessing
import aiohttp_cors # type: ignore
# Partial imports
from aiohttp import client_exceptions
from aiohttp.web import Application, Response, get, run_app, static # type: ignore
from aiohttp.web import Application, Response, Request, get, run_app, static # type: ignore
from aiohttp_jinja2 import setup as jinja_setup
# local modules
@@ -70,7 +71,7 @@ class PluginManager:
jinja_setup(self.web_app)
async def startup(_):
async def startup(_: Application):
if self.settings.getSetting("cef_forward", False):
self.loop.create_task(service_start(REMOTE_DEBUGGER_UNIT))
else:
@@ -84,16 +85,16 @@ class PluginManager:
self.web_app.add_routes([get("/auth/token", self.get_auth_token)])
for route in list(self.web_app.router.routes()):
self.cors.add(route)
self.cors.add(route) # type: ignore
self.web_app.add_routes([static("/static", path.join(path.dirname(__file__), 'static'))])
self.web_app.add_routes([static("/legacy", path.join(path.dirname(__file__), 'legacy'))])
def exception_handler(self, loop, context):
def exception_handler(self, loop: AbstractEventLoop, context: Dict[str, str]):
if context["message"] == "Unclosed connection":
return
loop.default_exception_handler(context)
async def get_auth_token(self, request):
async def get_auth_token(self, request: Request):
return Response(text=get_csrf_token())
async def load_plugins(self):
@@ -144,7 +145,7 @@ class PluginManager:
# This is because of https://github.com/aio-libs/aiohttp/blob/3ee7091b40a1bc58a8d7846e7878a77640e96996/aiohttp/client_ws.py#L321
logger.info("CEF has disconnected...")
# At this point the loop starts again and we connect to the freshly started Steam client once it is ready.
except Exception as e:
except Exception:
logger.error("Exception while reading page events " + format_exc())
await tab.close_websocket()
pass
@@ -154,7 +155,7 @@ class PluginManager:
# logger.info("Plugin loader isn't present in Steam anymore, reinjecting...")
# await self.inject_javascript(tab)
async def inject_javascript(self, tab: Tab, first=False, request=None):
async def inject_javascript(self, tab: Tab, first: bool=False, request: Request|None=None):
logger.info("Loading Decky frontend!")
try:
if first:
+1 -1
View File
@@ -20,7 +20,7 @@ class PluginWrapper:
self.plugin_path = plugin_path
self.plugin_directory = plugin_directory
self.method_call_lock = Lock()
self.socket = LocalSocket(self._on_new_message)
self.socket: LocalSocket = LocalSocket(self._on_new_message)
self.version = None
+18 -9
View File
@@ -1,23 +1,31 @@
import os
import shutil
import uuid
from asyncio import sleep
from ensurepip import version
from json.decoder import JSONDecodeError
from logging import getLogger
from os import getcwd, path, remove
from typing import List, TypedDict
from backend.main import PluginManager
from localplatform import chmod, service_restart, ON_LINUX, get_keep_systemd_service, get_selinux
from aiohttp import ClientSession, web
import helpers
from injector import get_gamepadui_tab, inject_to_tab
from injector import get_gamepadui_tab
from settings import SettingsManager
logger = getLogger("Updater")
class RemoteVerAsset(TypedDict):
name: str
browser_download_url: str
class RemoteVer(TypedDict):
tag_name: str
prerelease: bool
assets: List[RemoteVerAsset]
class Updater:
def __init__(self, context) -> None:
def __init__(self, context: PluginManager) -> None:
self.context = context
self.settings = self.context.settings
# Exposes updater methods to frontend
@@ -28,8 +36,8 @@ class Updater:
"do_restart": self.do_restart,
"check_for_updates": self.check_for_updates
}
self.remoteVer = None
self.allRemoteVers = None
self.remoteVer: RemoteVer | None = None
self.allRemoteVers: List[RemoteVer] = []
self.localVer = helpers.get_loader_version()
try:
@@ -44,7 +52,7 @@ class Updater:
])
context.loop.create_task(self.version_reloader())
async def _handle_server_method_call(self, request):
async def _handle_server_method_call(self, request: web.Request):
method_name = request.match_info["method_name"]
try:
args = await request.json()
@@ -52,7 +60,7 @@ class Updater:
args = {}
res = {}
try:
r = await self.updater_methods[method_name](**args)
r = await self.updater_methods[method_name](**args) # type: ignore
res["result"] = r
res["success"] = True
except Exception as e:
@@ -105,7 +113,7 @@ class Updater:
selectedBranch = self.get_branch(self.context.settings)
async with ClientSession() as web:
async with web.request("GET", "https://api.github.com/repos/SteamDeckHomebrew/decky-loader/releases", ssl=helpers.get_ssl_context()) as res:
remoteVersions = await res.json()
remoteVersions: List[RemoteVer] = await res.json()
if selectedBranch == 0:
logger.debug("release type: release")
remoteVersions = list(filter(lambda ver: ver["tag_name"].startswith("v") and not ver["prerelease"] and not ver["tag_name"].find("-pre") > 0 and ver["tag_name"], remoteVersions))
@@ -142,6 +150,7 @@ class Updater:
async def do_update(self):
logger.debug("Starting update.")
assert self.remoteVer
version = self.remoteVer["tag_name"]
download_url = None
download_filename = "PluginLoader" if ON_LINUX else "PluginLoader.exe"
+43 -31
View File
@@ -1,3 +1,4 @@
from os import stat_result
import uuid
from json.decoder import JSONDecodeError
from os.path import splitext
@@ -5,12 +6,12 @@ import re
from traceback import format_exc
from stat import FILE_ATTRIBUTE_HIDDEN # type: ignore
from asyncio import start_server, gather, open_connection
from asyncio import StreamReader, StreamWriter, start_server, gather, open_connection
from aiohttp import ClientSession, web
from typing import Dict
from typing import Callable, Coroutine, Dict, Any, List, TypedDict
from logging import getLogger
from backend.browser import PluginInstallType
from backend.browser import PluginInstallRequest, PluginInstallType
from backend.main import PluginManager
from injector import inject_to_tab, get_gamepadui_tab, close_old_tabs, get_tab
from pathlib import Path
@@ -18,10 +19,15 @@ from localplatform import ON_WINDOWS
import helpers
from localplatform import service_stop, service_start, get_home_path, get_username
class FilePickerObj(TypedDict):
file: Path
filest: stat_result
is_dir: bool
class Utilities:
def __init__(self, context: PluginManager) -> None:
self.context = context
self.util_methods: Dict[] = {
self.util_methods: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {
"ping": self.ping,
"http_request": self.http_request,
"install_plugin": self.install_plugin,
@@ -54,7 +60,7 @@ class Utilities:
web.post("/methods/{method_name}", self._handle_server_method_call)
])
async def _handle_server_method_call(self, request):
async def _handle_server_method_call(self, request: web.Request):
method_name = request.match_info["method_name"]
try:
args = await request.json()
@@ -70,7 +76,7 @@ class Utilities:
res["success"] = False
return web.json_response(res)
async def install_plugin(self, artifact="", name="No name", version="dev", hash=False, install_type=PluginInstallType.INSTALL):
async def install_plugin(self, artifact: str="", name: str="No name", version: str="dev", hash: str="", install_type: PluginInstallType=PluginInstallType.INSTALL):
return await self.context.plugin_browser.request_plugin_install(
artifact=artifact,
name=name,
@@ -79,21 +85,21 @@ class Utilities:
install_type=install_type
)
async def install_plugins(self, requests):
async def install_plugins(self, requests: List[PluginInstallRequest]):
return await self.context.plugin_browser.request_multiple_plugin_installs(
requests=requests
)
async def confirm_plugin_install(self, request_id):
async def confirm_plugin_install(self, request_id: str):
return await self.context.plugin_browser.confirm_plugin_install(request_id)
def cancel_plugin_install(self, request_id):
async def cancel_plugin_install(self, request_id: str):
return self.context.plugin_browser.cancel_plugin_install(request_id)
async def uninstall_plugin(self, name):
async def uninstall_plugin(self, name: str):
return await self.context.plugin_browser.uninstall_plugin(name)
async def http_request(self, method="", url="", **kwargs):
async def http_request(self, method: str="", url: str="", **kwargs: Any):
async with ClientSession() as web:
res = await web.request(method, url, ssl=helpers.get_ssl_context(), **kwargs)
text = await res.text()
@@ -103,12 +109,13 @@ class Utilities:
"body": text
}
async def ping(self, **kwargs):
async def ping(self, **kwargs: Any):
return "pong"
async def execute_in_tab(self, tab, run_async, code):
async def execute_in_tab(self, tab: str, run_async: bool, code: str):
try:
result = await inject_to_tab(tab, code, run_async)
assert result
if "exceptionDetails" in result["result"]:
return {
"success": False,
@@ -125,7 +132,7 @@ class Utilities:
"result": e
}
async def inject_css_into_tab(self, tab, style):
async def inject_css_into_tab(self, tab: str, style: str):
try:
css_id = str(uuid.uuid4())
@@ -139,7 +146,7 @@ class Utilities:
}})()
""", False)
if "exceptionDetails" in result["result"]:
if result and "exceptionDetails" in result["result"]:
return {
"success": False,
"result": result["result"]
@@ -155,7 +162,7 @@ class Utilities:
"result": e
}
async def remove_css_from_tab(self, tab, css_id):
async def remove_css_from_tab(self, tab: str, css_id: str):
try:
result = await inject_to_tab(tab,
f"""
@@ -167,7 +174,7 @@ class Utilities:
}})()
""", False)
if "exceptionDetails" in result["result"]:
if result and "exceptionDetails" in result["result"]:
return {
"success": False,
"result": result
@@ -182,10 +189,10 @@ class Utilities:
"result": e
}
async def get_setting(self, key, default):
async def get_setting(self, key: str, default: Any):
return self.context.settings.getSetting(key, default)
async def set_setting(self, key, value):
async def set_setting(self, key: str, value: Any):
return self.context.settings.setSetting(key, value)
async def allow_remote_debugging(self):
@@ -210,17 +217,18 @@ class Utilities:
if path == None:
path = get_home_path()
path = Path(path).resolve()
path_obj = Path(path).resolve()
files, folders = [], []
files: List[FilePickerObj] = []
folders: List[FilePickerObj] = []
#Resolving all files/folders in the requested directory
for file in path.iterdir():
for file in path_obj.iterdir():
if file.exists():
filest = file.stat()
is_hidden = file.name.startswith('.')
if ON_WINDOWS and not is_hidden:
is_hidden = bool(filest.st_file_attributes & FILE_ATTRIBUTE_HIDDEN)
is_hidden = bool(filest.st_file_attributes & FILE_ATTRIBUTE_HIDDEN) # type: ignore
if include_folders and file.is_dir():
if (is_hidden and include_hidden) or not is_hidden:
folders.append({"file": file, "filest": filest, "is_dir": True})
@@ -234,9 +242,9 @@ class Utilities:
if filter_for is not None:
try:
if re.compile(filter_for):
files = filter(lambda file: re.search(filter_for, file.name) != None, files)
files = list(filter(lambda file: re.search(filter_for, file["file"].name) != None, files))
except re.error:
files = filter(lambda file: file.name.find(filter_for) != -1, files)
files = list(filter(lambda file: file["file"].name.find(filter_for) != -1, files))
# Ordering logic
ord_arg = order_by.split("_")
@@ -256,6 +264,9 @@ class Utilities:
files.sort(key=lambda x: x['filest'].st_size, reverse = not rev)
# Folders has no file size, order by name instead
folders.sort(key=lambda x: x['file'].name.casefold())
case _:
files.sort(key=lambda x: x['file'].name.casefold(), reverse = rev)
folders.sort(key=lambda x: x['file'].name.casefold(), reverse = rev)
#Constructing the final file list, folders first
all = [{
@@ -275,14 +286,14 @@ class Utilities:
# Based on https://stackoverflow.com/a/46422554/13174603
def start_rdt_proxy(self, ip, port):
async def pipe(reader, writer):
def start_rdt_proxy(self, ip: str, port: int):
async def pipe(reader: StreamReader, writer: StreamWriter):
try:
while not reader.at_eof():
writer.write(await reader.read(2048))
finally:
writer.close()
async def handle_client(local_reader, local_writer):
async def handle_client(local_reader: StreamReader, local_writer: StreamWriter):
try:
remote_reader, remote_writer = await open_connection(
ip, port)
@@ -298,7 +309,8 @@ class Utilities:
def stop_rdt_proxy(self):
if self.rdt_proxy_server:
self.rdt_proxy_server.close()
self.rdt_proxy_task.cancel()
if self.rdt_proxy_task:
self.rdt_proxy_task.cancel()
async def _enable_rdt(self):
# TODO un-hardcode port
@@ -348,11 +360,11 @@ class Utilities:
await tab.evaluate_js("location.reload();", False, True, False)
self.logger.info("React DevTools disabled")
async def get_user_info(self) -> dict:
async def get_user_info(self) -> Dict[str, str]:
return {
"username": get_username(),
"path": get_home_path()
}
async def get_tab_id(self, name):
async def get_tab_id(self, name: str):
return (await get_tab(name)).id