mirror of
https://github.com/SteamDeckHomebrew/decky-loader.git
synced 2026-06-13 12:15:09 +03:00
more progress on websockets
This commit is contained in:
@@ -88,7 +88,7 @@ class Loader:
|
||||
self.observer.start()
|
||||
self.loop.create_task(self.enable_reload_wait())
|
||||
|
||||
server_instance.add_routes([
|
||||
server_instance.web_app.add_routes([
|
||||
web.get("/frontend/{path:.*}", self.handle_frontend_assets),
|
||||
web.get("/locales/{path:.*}", self.handle_frontend_locales),
|
||||
web.get("/plugins", self.get_plugins),
|
||||
|
||||
@@ -166,7 +166,7 @@ class Utilities:
|
||||
style.textContent = `{style}`;
|
||||
}})()
|
||||
""", False)
|
||||
|
||||
assert result is not None # TODO remove this once it has proper typings
|
||||
if "exceptionDetails" in result["result"]:
|
||||
raise result["result"]["exceptionDetails"]
|
||||
|
||||
@@ -233,7 +233,7 @@ class Utilities:
|
||||
folders.append({"file": file, "filest": filest, "is_dir": True})
|
||||
elif include_files:
|
||||
# Handle requested extensions if present
|
||||
if len(include_ext) == 0 or 'all_files' in include_ext \
|
||||
if include_ext == None or len(include_ext) == 0 or 'all_files' in include_ext \
|
||||
or splitext(file.name)[1].lstrip('.') in include_ext:
|
||||
if (is_hidden and include_hidden) or not is_hidden:
|
||||
files.append({"file": file, "filest": filest, "is_dir": False})
|
||||
|
||||
@@ -1,37 +1,51 @@
|
||||
from logging import getLogger
|
||||
|
||||
from asyncio import AbstractEventLoop, Future
|
||||
from asyncio import AbstractEventLoop, Future, create_task
|
||||
|
||||
from aiohttp import WSMsgType
|
||||
from aiohttp import WSMsgType, WSMessage
|
||||
from aiohttp.web import Application, WebSocketResponse, Request, Response, get
|
||||
|
||||
from enum import Enum
|
||||
from enum import IntEnum
|
||||
|
||||
from typing import Dict
|
||||
from typing import Callable, Dict, Any, cast, TypeVar, Type
|
||||
from dataclasses import dataclass
|
||||
|
||||
from traceback import format_exc
|
||||
|
||||
from helpers import get_csrf_token
|
||||
|
||||
class MessageType(Enum):
|
||||
# Call-reply
|
||||
class MessageType(IntEnum):
|
||||
ERROR = -1
|
||||
# Call-reply, Frontend -> Backend
|
||||
CALL = 0
|
||||
REPLY = 1
|
||||
ERROR = 2
|
||||
# # Pub/sub
|
||||
# SUBSCRIBE = 3
|
||||
# UNSUBSCRIBE = 4
|
||||
# PUBLISH = 5
|
||||
# Pub/Sub, Backend -> Frontend
|
||||
EVENT = 3
|
||||
|
||||
# WSMessage with slightly better typings
|
||||
class WSMessageExtra(WSMessage):
|
||||
data: Any
|
||||
type: WSMsgType
|
||||
@dataclass
|
||||
class Message:
|
||||
data: Any
|
||||
type: MessageType
|
||||
|
||||
# @dataclass
|
||||
# class CallMessage
|
||||
|
||||
# see wsrouter.ts for typings
|
||||
|
||||
DataType = TypeVar("DataType")
|
||||
|
||||
Route = Callable[..., Future[Any]]
|
||||
|
||||
class WSRouter:
|
||||
def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
|
||||
self.loop = loop
|
||||
self.ws = None
|
||||
self.req_id = 0
|
||||
self.routes = {}
|
||||
self.running_calls: Dict[int, Future] = {}
|
||||
self.ws: WebSocketResponse | None
|
||||
self.instance_id = 0
|
||||
self.routes: Dict[str, Route] = {}
|
||||
# self.subscriptions: Dict[str, Callable[[Any]]] = {}
|
||||
self.logger = getLogger("WSRouter")
|
||||
|
||||
@@ -39,22 +53,38 @@ class WSRouter:
|
||||
get("/ws", self.handle)
|
||||
])
|
||||
|
||||
async def write(self, dta: Dict[str, any]):
|
||||
await self.ws.send_json(dta)
|
||||
async def write(self, data: Dict[str, Any]):
|
||||
if self.ws != None:
|
||||
await self.ws.send_json(data)
|
||||
else:
|
||||
self.logger.warn("Dropping message as there is no connected socket: %s", data)
|
||||
|
||||
def add_route(self, name: str, route):
|
||||
def add_route(self, name: str, route: Route):
|
||||
self.routes[name] = route
|
||||
|
||||
def remove_route(self, name: str):
|
||||
del self.routes[name]
|
||||
|
||||
async def _call_route(self, route: str, args: ..., call_id: int):
|
||||
instance_id = self.instance_id
|
||||
res = await self.routes[route](*args)
|
||||
if instance_id != self.instance_id:
|
||||
try:
|
||||
self.logger.warn("Ignoring %s reply from stale instance %d with args %s and response %s", route, instance_id, args, res)
|
||||
except:
|
||||
self.logger.warn("Ignoring %s reply from stale instance %d (failed to log event data)", route, instance_id)
|
||||
finally:
|
||||
return
|
||||
await self.write({"type": MessageType.REPLY.value, "id": call_id, "result": res})
|
||||
|
||||
async def handle(self, request: Request):
|
||||
# Auth is a query param as JS WebSocket doesn't support headers
|
||||
if request.rel_url.query["auth"] != get_csrf_token():
|
||||
return Response(text='Forbidden', status='403')
|
||||
return Response(text='Forbidden', status=403)
|
||||
self.logger.debug('Websocket connection starting')
|
||||
ws = WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
self.instance_id += 1
|
||||
self.logger.debug('Websocket connection ready')
|
||||
|
||||
if self.ws != None:
|
||||
@@ -68,6 +98,8 @@ class WSRouter:
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
msg = cast(WSMessageExtra, msg)
|
||||
|
||||
self.logger.debug(msg)
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
self.logger.debug(msg.data)
|
||||
@@ -81,25 +113,13 @@ class WSRouter:
|
||||
# do stuff with the message
|
||||
if data["route"] in self.routes:
|
||||
try:
|
||||
res = await self.routes[data["route"]](*data["args"])
|
||||
await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
|
||||
self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
|
||||
create_task(self._call_route(data["route"], data["args"], data["id"]))
|
||||
except:
|
||||
await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
|
||||
create_task(self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()}))
|
||||
else:
|
||||
# Dunno why but fstring doesnt work here
|
||||
await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."})
|
||||
case MessageType.REPLY.value:
|
||||
if self.running_calls[data["id"]]:
|
||||
self.running_calls[data["id"]].set_result(data["result"])
|
||||
del self.running_calls[data["id"]]
|
||||
self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
|
||||
case MessageType.ERROR.value:
|
||||
if self.running_calls[data["id"]]:
|
||||
self.running_calls[data["id"]].set_exception(data["error"])
|
||||
del self.running_calls[data["id"]]
|
||||
self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')
|
||||
|
||||
create_task(self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."}))
|
||||
case _:
|
||||
self.logger.error("Unknown message type", data)
|
||||
finally:
|
||||
@@ -112,17 +132,7 @@ class WSRouter:
|
||||
self.logger.debug('Websocket connection closed')
|
||||
return ws
|
||||
|
||||
async def call(self, route: str, *args):
|
||||
future = Future()
|
||||
async def emit(self, event: str, data: DataType | None = None, data_type: Type[DataType] = Any):
|
||||
self.logger.debug('Firing frontend event %s with args %s', data)
|
||||
|
||||
self.req_id += 1
|
||||
|
||||
id = self.req_id
|
||||
|
||||
self.running_calls[id] = future
|
||||
|
||||
self.logger.debug(f'Calling JS method {route} with args {str(args)}')
|
||||
|
||||
self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })
|
||||
|
||||
return await future
|
||||
await self.write({ "type": MessageType.EVENT.value, "event": event, "data": data })
|
||||
@@ -7,14 +7,12 @@ declare global {
|
||||
}
|
||||
|
||||
enum MessageType {
|
||||
// Call-reply
|
||||
CALL,
|
||||
REPLY,
|
||||
ERROR,
|
||||
// Pub/sub
|
||||
// SUBSCRIBE,
|
||||
// UNSUBSCRIBE,
|
||||
// PUBLISH
|
||||
ERROR = -1,
|
||||
// Call-reply, Frontend -> Backend
|
||||
CALL = 0,
|
||||
REPLY = 1,
|
||||
// Pub/Sub, Backend -> Frontend
|
||||
EVENT = 3,
|
||||
}
|
||||
|
||||
interface CallMessage {
|
||||
|
||||
Reference in New Issue
Block a user