finalize api for plugin events in backend

just need frontend impl now
This commit is contained in:
AAGaming
2023-12-31 20:29:19 -05:00
parent db96121304
commit c5ea95a787
8 changed files with 93 additions and 50 deletions
+15 -6
View File
@@ -1,6 +1,5 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Queue, sleep
from json.decoder import JSONDecodeError
from logging import getLogger
from os import listdir, path
from pathlib import Path
@@ -9,6 +8,7 @@ from typing import Any, Tuple, Dict
from aiohttp import web
from os.path import exists
from attr import dataclass
from watchdog.events import RegexMatchingEventHandler, DirCreatedEvent, DirModifiedEvent, FileCreatedEvent, FileModifiedEvent # type: ignore
from watchdog.observers import Observer # type: ignore
@@ -23,10 +23,6 @@ from .wsrouter import WSRouter
Plugins = dict[str, PluginWrapper]
ReloadQueue = Queue[Tuple[str, str, bool | None] | Tuple[str, str]]
#TODO: Remove placeholder method
async def log_plugin_emitted_message(message: Any):
getLogger().debug(f"EMITTED MESSAGE: " + str(message))
class FileChangeHandler(RegexMatchingEventHandler):
def __init__(self, queue: ReloadQueue, plugin_path: str) -> None:
super().__init__(regexes=[r'^.*?dist\/index\.js$', r'^.*?main\.py$']) # type: ignore
@@ -70,10 +66,17 @@ class FileChangeHandler(RegexMatchingEventHandler):
self.logger.debug(f"file modified: {src_path}")
self.maybe_reload(src_path)
@dataclass
class PluginEvent:
plugin_name: str
event: str
data: str
class Loader:
def __init__(self, server_instance: PluginManager, ws: WSRouter, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None:
self.loop = loop
self.logger = getLogger("Loader")
self.ws = ws
self.plugin_path = plugin_path
self.logger.info(f"plugin_path: {self.plugin_path}")
self.plugins: Plugins = {}
@@ -149,8 +152,14 @@ class Loader:
self.plugins.pop(plugin.name, None)
if plugin.passive:
self.logger.info(f"Plugin {plugin.name} is passive")
async def plugin_emitted_event(event: str, data: Any):
self.logger.debug(f"PLUGIN EMITTED EVENT: {str(event)} {data}")
event_data = PluginEvent(plugin_name=plugin.name, event=event, data=data)
await self.ws.emit("plugin_event", event_data)
self.plugins[plugin.name].set_emitted_event_callback(plugin_emitted_event)
self.plugins[plugin.name] = plugin.start()
self.plugins[plugin.name].set_emitted_message_callback(log_plugin_emitted_message)
self.logger.info(f"Loaded {plugin.name}")
if not batch:
self.loop.create_task(self.dispatch_plugin(plugin.name, plugin.version))
+4 -3
View File
@@ -19,7 +19,7 @@ import subprocess
import logging
import time
from typing import Dict, Any
from typing import TypeVar, Type
"""
Constants
@@ -213,9 +213,10 @@ logger.setLevel(logging.INFO)
"""
Event handling
"""
DataType = TypeVar("DataType")
# TODO better docstring im lazy
async def emit_message(message: Dict[Any, Any]) -> None:
async def emit(event: str, data: DataType | None = None, data_type: Type[DataType] | None = None) -> None:
"""
Send a message to the frontend.
Send an event to the frontend.
"""
pass
@@ -16,7 +16,7 @@ __version__ = '0.1.0'
import logging
from typing import Dict, Any
from typing import TypeVar, Type
"""
Constants
@@ -177,8 +177,9 @@ logger: logging.Logger
"""
Event handling
"""
DataType = TypeVar("DataType")
# TODO better docstring im lazy
async def emit_message(message: Dict[Any, Any]) -> None:
async def emit(event: str, data: DataType | None = None, data_type: Type[DataType] | None = None) -> None:
"""
Send a message to the frontend.
Send an event to the frontend.
"""
@@ -1,8 +1,15 @@
from typing import Any, TypedDict
from enum import IntEnum
from uuid import uuid4
from asyncio import Event
class SocketMessageType(IntEnum):
CALL = 0
RESPONSE = 1
EVENT = 2
class SocketResponseDict(TypedDict):
type: SocketMessageType
id: str
success: bool
res: Any
+14 -11
View File
@@ -6,11 +6,13 @@ from multiprocessing import Process
from .sandboxed_plugin import SandboxedPlugin
from .method_call_request import MethodCallRequest
from .messages import MethodCallRequest, SocketMessageType
from ..localplatform.localsocket import LocalSocket
from typing import Any, Callable, Coroutine, Dict, List
EmittedEventCallbackType = Callable[[str, Any], Coroutine[Any, Any, Any]]
class PluginWrapper:
def __init__(self, file: str, plugin_directory: str, plugin_path: str) -> None:
self.file = file
@@ -27,18 +29,19 @@ class PluginWrapper:
self.name = json["name"]
self.author = json["author"]
self.flags = json["flags"]
self.api_version = json["api_version"] if "api_version" in json else 0
self.passive = not path.isfile(self.file)
self.log = getLogger("plugin")
self.sandboxed_plugin = SandboxedPlugin(self.name, self.passive, self.flags, self.file, self.plugin_directory, self.plugin_path, self.version, self.author)
#TODO: Maybe make LocalSocket not require on_new_message to make this cleaner
self.sandboxed_plugin = SandboxedPlugin(self.name, self.passive, self.flags, self.file, self.plugin_directory, self.plugin_path, self.version, self.author, self.api_version)
# TODO: Maybe make LocalSocket not require on_new_message to make this cleaner
self._socket = LocalSocket(self.sandboxed_plugin.on_new_message)
self._listener_task: Task[Any]
self._method_call_requests: Dict[str, MethodCallRequest] = {}
self.emitted_message_callback: Callable[[Dict[Any, Any]], Coroutine[Any, Any, Any]]
self.emitted_event_callback: EmittedEventCallbackType
self.legacy_method_warning = False
@@ -51,15 +54,15 @@ class PluginWrapper:
line = await self._socket.read_single_line()
if line != None:
res = loads(line)
if res["id"] == "0":
create_task(self.emitted_message_callback(res["payload"]))
else:
if res["type"] == SocketMessageType.EVENT.value:
create_task(self.emitted_event_callback(res["event"], res["data"]))
elif res["type"] == SocketMessageType.RESPONSE.value:
self._method_call_requests.pop(res["id"]).set_result(res)
except:
pass
def set_emitted_message_callback(self, callback: Callable[[Dict[Any, Any]], Coroutine[Any, Any, Any]]):
self.emitted_message_callback = callback
def set_emitted_event_callback(self, callback: EmittedEventCallbackType):
self.emitted_event_callback = callback
async def execute_legacy_method(self, method_name: str, kwargs: Dict[Any, Any]):
if not self.legacy_method_warning:
@@ -70,7 +73,7 @@ class PluginWrapper:
request = MethodCallRequest()
await self._socket.get_socket_connection()
await self._socket.write_single_line(dumps({ "method": method_name, "args": kwargs, "id": request.id, "legacy": True }, ensure_ascii=False))
await self._socket.write_single_line(dumps({ "type": SocketMessageType.CALL, "method": method_name, "args": kwargs, "id": request.id, "legacy": True }, ensure_ascii=False))
self._method_call_requests[request.id] = request
return await request.wait_for_result()
@@ -81,7 +84,7 @@ class PluginWrapper:
request = MethodCallRequest()
await self._socket.get_socket_connection()
await self._socket.write_single_line(dumps({ "method": method_name, "args": args, "id": request.id }, ensure_ascii=False))
await self._socket.write_single_line(dumps({ "type": SocketMessageType.CALL, "method": method_name, "args": args, "id": request.id }, ensure_ascii=False))
self._method_call_requests[request.id] = request
return await request.wait_for_result()
+36 -13
View File
@@ -8,13 +8,17 @@ from traceback import format_exc
from asyncio import (get_event_loop, new_event_loop,
set_event_loop, sleep)
from .method_call_request import SocketResponseDict
from backend.decky_loader.plugin.messages import SocketMessageType
from .messages import SocketResponseDict, SocketMessageType
from ..localplatform.localsocket import LocalSocket
from ..localplatform.localplatform import setgid, setuid, get_username, get_home_path
from ..customtypes import UserType
from .. import helpers
from typing import Any, Dict, List
from typing import List, TypeVar, Type
DataType = TypeVar("DataType")
class SandboxedPlugin:
def __init__(self,
@@ -25,7 +29,8 @@ class SandboxedPlugin:
plugin_directory: str,
plugin_path: str,
version: str|None,
author: str) -> None:
author: str,
api_version: int) -> None:
self.name = name
self.passive = passive
self.flags = flags
@@ -34,6 +39,7 @@ class SandboxedPlugin:
self.plugin_directory = plugin_directory
self.version = version
self.author = author
self.api_version = api_version
self.log = getLogger("plugin")
@@ -79,10 +85,11 @@ class SandboxedPlugin:
sysmodules[key.replace("decky_loader.", "")] = sysmodules[key]
from .imports import decky
async def emit_message(message: Dict[Any, Any]):
async def emit_message(event: str, data: DataType | None = None, data_type: Type[DataType] | None = None) -> None:
await self._socket.write_single_line_server(dumps({
"id": "0",
"payload": message
"type": SocketMessageType.EVENT,
"event": event,
"data": data
}))
# copy the docstring over so we don't have to duplicate it
emit_message.__doc__ = decky.emit_message.__doc__
@@ -97,12 +104,21 @@ class SandboxedPlugin:
assert spec.loader is not None
spec.loader.exec_module(module)
# TODO fix self weirdness once plugin.json versioning is done. need this before WS release!
self.Plugin = module.Plugin
if self.api_version > 0:
self.Plugin = module.Plugin()
else:
self.Plugin = module.Plugin
if hasattr(self.Plugin, "_migration"):
get_event_loop().run_until_complete(self.Plugin._migration(self.Plugin))
if self.api_version > 0:
get_event_loop().run_until_complete(self.Plugin._migration())
else:
get_event_loop().run_until_complete(self.Plugin._migration(self.Plugin))
if hasattr(self.Plugin, "_main"):
get_event_loop().create_task(self.Plugin._main(self.Plugin))
if self.api_version > 0:
get_event_loop().create_task(self.Plugin._main())
else:
get_event_loop().create_task(self.Plugin._main(self.Plugin))
get_event_loop().create_task(socket.setup_server())
get_event_loop().run_forever()
except:
@@ -113,7 +129,10 @@ class SandboxedPlugin:
try:
self.log.info("Attempting to unload with plugin " + self.name + "'s \"_unload\" function.\n")
if hasattr(self.Plugin, "_unload"):
await self.Plugin._unload(self.Plugin)
if self.api_version > 0:
await self.Plugin._unload()
else:
await self.Plugin._unload(self.Plugin)
self.log.info("Unloaded " + self.name + "\n")
else:
self.log.info("Could not find \"_unload\" in " + self.name + "'s main.py" + "\n")
@@ -121,7 +140,7 @@ class SandboxedPlugin:
self.log.error("Failed to unload " + self.name + "!\n" + format_exc())
exit(0)
async def on_new_message(self, message : str) -> str|None:
async def on_new_message(self, message : str) -> str | None:
data = loads(message)
if "stop" in data:
@@ -133,14 +152,18 @@ class SandboxedPlugin:
await self._unload()
raise Exception("Closing message listener")
d: SocketResponseDict = {"res": None, "success": True, "id": data["id"]}
d: SocketResponseDict = {"type": SocketMessageType.RESPONSE, "res": None, "success": True, "id": data["id"]}
try:
if data["legacy"]:
if self.api_version > 0:
raise Exception("Legacy methods may not be used on api_version > 0")
# Legacy kwargs
d["res"] = await getattr(self.Plugin, data["method"])(self.Plugin, **data["args"])
else:
if self.api_version < 1 :
raise Exception("api_version 1 or newer is required to call methods with index-based arguments")
# New args
d["res"] = await getattr(self.Plugin, data["method"])(self.Plugin, *data["args"])
d["res"] = await getattr(self.Plugin, data["method"])(*data["args"])
except Exception as e:
d["res"] = str(e)
d["success"] = False
+3 -4
View File
@@ -1,14 +1,13 @@
from __future__ import annotations
from os import stat_result
import uuid
from json.decoder import JSONDecodeError
from os.path import splitext
import re
from traceback import format_exc
from stat import FILE_ATTRIBUTE_HIDDEN # type: ignore
from asyncio import StreamReader, StreamWriter, start_server, gather, open_connection
from aiohttp import ClientSession, web
from aiohttp import ClientSession
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Any, List, TypedDict
from logging import getLogger
@@ -30,7 +29,7 @@ class FilePickerObj(TypedDict):
class Utilities:
def __init__(self, context: PluginManager) -> None:
self.context = context
self.util_methods: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {
self.legacy_util_methods: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {
"ping": self.ping,
"http_request": self.http_request,
"install_plugin": self.install_plugin,
@@ -84,7 +83,7 @@ class Utilities:
self.logger.debug(f"Calling utility {method_name} with legacy kwargs");
res: Dict[Any, Any] = {}
try:
r = await self.util_methods[method_name](**kwargs)
r = await self.legacy_util_methods[method_name](**kwargs)
res["result"] = r
res["success"] = True
except Exception as e:
+10 -10
View File
@@ -1,3 +1,4 @@
from _typeshed import DataclassInstance
from logging import getLogger
from asyncio import AbstractEventLoop, create_task
@@ -8,7 +9,8 @@ from aiohttp.web import Application, WebSocketResponse, Request, Response, get
from enum import IntEnum
from typing import Callable, Coroutine, Dict, Any, cast, TypeVar, Type
from dataclasses import dataclass
from dataclasses import asdict, is_dataclass
from traceback import format_exc
@@ -24,15 +26,9 @@ class MessageType(IntEnum):
# WSMessage with slightly better typings
class WSMessageExtra(WSMessage):
# TODO message typings here too
data: Any
type: WSMsgType
@dataclass
class Message:
data: Any
type: MessageType
# @dataclass
# class CallMessage
# see wsrouter.ts for typings
@@ -133,7 +129,11 @@ class WSRouter:
return ws
# DataType defaults to None so that if a plugin opts in to strict pyright checking and attempts to pass data witbout specifying the type (or any), the type check fails
async def emit(self, event: str, data: DataType | None = None, data_type: Type[DataType]|None = None):
async def emit(self, event: str, data: DataType | None = None, data_type: Type[DataType] | None = None):
self.logger.debug('Firing frontend event %s with args %s', data)
sent_data: Dict[Any, Any] | None = cast(Dict[Any, Any], data)
if is_dataclass(data):
data_as_dataclass = cast(DataclassInstance, data)
sent_data = asdict(data_as_dataclass)
await self.write({ "type": MessageType.EVENT.value, "event": event, "data": data })
await self.write({ "type": MessageType.EVENT.value, "event": event, "data": sent_data })