Experimental support for async method calls

This commit is contained in:
marios8543
2023-10-17 23:52:18 +03:00
parent 949c5e73c4
commit 321242b0d9
3 changed files with 123 additions and 66 deletions
+29
View File
@@ -0,0 +1,29 @@
from typing import Any, TypedDict
from uuid import uuid4
from asyncio import Event
class SocketResponseDict(TypedDict):
id: str
success: bool
res: Any
class MethodCallResponse:
def __init__(self, success: bool, result: Any) -> None:
self.success = success
self.result = result
class MethodCallRequest:
def __init__(self) -> None:
self.id = str(uuid4())
self.event = Event()
self.response: MethodCallResponse
def set_result(self, dc: SocketResponseDict):
self.response = MethodCallResponse(dc["success"], dc["res"])
self.event.set()
async def wait_for_result(self):
await self.event.wait()
if not self.response.success:
raise Exception(self.response.result)
return self.response
+55
View File
@@ -0,0 +1,55 @@
from json import dumps, load, loads
from logging import getLogger
from os import path
from .sandboxed_plugin import SandboxedPlugin
from .method_call_request import MethodCallRequest
from ..localplatform.localsocket import LocalSocket
from typing import Any, Dict
class PluginWrapper:
def __init__(self, file: str, plugin_directory: str, plugin_path: str) -> None:
self.file = file
self.plugin_path = plugin_path
self.plugin_directory = plugin_directory
self.version = None
json = load(open(path.join(plugin_path, plugin_directory, "plugin.json"), "r", encoding="utf-8"))
if path.isfile(path.join(plugin_path, plugin_directory, "package.json")):
package_json = load(open(path.join(plugin_path, plugin_directory, "package.json"), "r", encoding="utf-8"))
self.version = package_json["version"]
self.name = json["name"]
self.author = json["author"]
self.flags = json["flags"]
self.passive = not path.isfile(self.file)
self.log = getLogger("plugin")
self.method_call_requests: Dict[str, MethodCallRequest] = {}
self.sandboxed_plugin = SandboxedPlugin(self.name, self.passive, self.flags, self.file, self.plugin_directory, self.plugin_path, self.version, self.author)
#TODO: Maybe somehow make LocalSocket not require on_new_message to make this more clear
self.socket = LocalSocket(self.sandboxed_plugin.on_new_message)
self.sandboxed_plugin.start(self.socket)
def __str__(self) -> str:
return self.name
async def response_listener(self):
while True:
line = await self.socket.read_single_line()
if line != None:
res = loads(line)
self.method_call_requests.pop(res["id"]).set_result(res)
async def execute_method(self, method_name: str, kwargs: Dict[Any, Any]):
if self.passive:
raise RuntimeError("This plugin is passive (aka does not implement main.py)")
request = MethodCallRequest()
await self.socket.get_socket_connection()
await self.socket.write_single_line(dumps({ "method": method_name, "args": kwargs, "id": request.id }, ensure_ascii=False))
self.method_call_requests[request.id] = request
return await request.wait_for_result()
@@ -1,46 +1,44 @@
import multiprocessing
from asyncio import (Lock, get_event_loop, new_event_loop,
set_event_loop, sleep)
from importlib.util import module_from_spec, spec_from_file_location
from json import dumps, load, loads
from logging import getLogger
from traceback import format_exc
from os import path, environ
from signal import SIGINT, signal
from sys import exit, path as syspath, modules as sysmodules
from typing import Any, Dict
from .localsocket import LocalSocket
from .localplatform import setgid, setuid, get_username, get_home_path
from .customtypes import UserType
from . import helpers
from importlib.util import module_from_spec, spec_from_file_location
from json import dumps, loads
from logging import getLogger
import multiprocessing
from sys import exit, path as syspath
from traceback import format_exc
from asyncio import (get_event_loop, new_event_loop,
set_event_loop, sleep)
class PluginWrapper:
def __init__(self, file: str, plugin_directory: str, plugin_path: str) -> None:
from .method_call_request import SocketResponseDict
from ..localplatform.localsocket import LocalSocket
from ..localplatform.localplatform import setgid, setuid, get_username, get_home_path
from ..customtypes import UserType
from .. import helpers
from typing import List
class SandboxedPlugin:
def __init__(self,
name: str,
passive: bool,
flags: List[str],
file: str,
plugin_directory: str,
plugin_path: str,
version: str|None,
author: str) -> None:
self.name = name
self.passive = passive
self.flags = flags
self.file = file
self.plugin_path = plugin_path
self.plugin_directory = plugin_directory
self.method_call_lock = Lock()
self.socket: LocalSocket = LocalSocket(self._on_new_message)
self.version = None
json = load(open(path.join(plugin_path, plugin_directory, "plugin.json"), "r", encoding="utf-8"))
if path.isfile(path.join(plugin_path, plugin_directory, "package.json")):
package_json = load(open(path.join(plugin_path, plugin_directory, "package.json"), "r", encoding="utf-8"))
self.version = package_json["version"]
self.name = json["name"]
self.author = json["author"]
self.flags = json["flags"]
self.version = version
self.author = author
self.log = getLogger("plugin")
self.passive = not path.isfile(self.file)
def __str__(self) -> str:
return self.name
def _init(self):
def _init(self, socket: LocalSocket):
try:
signal(SIGINT, lambda s, f: exit(0))
@@ -90,7 +88,7 @@ class PluginWrapper:
get_event_loop().run_until_complete(self.Plugin._migration(self.Plugin))
if hasattr(self.Plugin, "_main"):
get_event_loop().create_task(self.Plugin._main(self.Plugin))
get_event_loop().create_task(self.socket.setup_server())
get_event_loop().create_task(socket.setup_server())
get_event_loop().run_forever()
except:
self.log.error("Failed to start " + self.name + "!\n" + format_exc())
@@ -108,7 +106,7 @@ class PluginWrapper:
self.log.error("Failed to unload " + self.name + "!\n" + format_exc())
exit(0)
async def _on_new_message(self, message : str) -> str|None:
async def on_new_message(self, message : str) -> str|None:
data = loads(message)
if "stop" in data:
@@ -121,7 +119,7 @@ class PluginWrapper:
raise Exception("Closing message listener")
# TODO there is definitely a better way to type this
d: Dict[str, Any] = {"res": None, "success": True}
d: SocketResponseDict = {"res": None, "success": True, "id": data["id"]}
try:
d["res"] = await getattr(self.Plugin, data["method"])(self.Plugin, **data["args"])
except Exception as e:
@@ -129,35 +127,10 @@ class PluginWrapper:
d["success"] = False
finally:
return dumps(d, ensure_ascii=False)
def start(self):
def start(self, socket: LocalSocket):
if self.passive:
return self
multiprocessing.Process(target=self._init).start()
return self
def stop(self):
if self.passive:
return
async def _(self: PluginWrapper):
await self.socket.write_single_line(dumps({ "stop": True }, ensure_ascii=False))
await self.socket.close_socket_connection()
get_event_loop().create_task(_(self))
async def execute_method(self, method_name: str, kwargs: Dict[Any, Any]):
if self.passive:
raise RuntimeError("This plugin is passive (aka does not implement main.py)")
async with self.method_call_lock:
# reader, writer =
await self.socket.get_socket_connection()
await self.socket.write_single_line(dumps({ "method": method_name, "args": kwargs }, ensure_ascii=False))
line = await self.socket.read_single_line()
if line != None:
res = loads(line)
if not res["success"]:
raise Exception(res["res"])
return res["res"]
multiprocessing.Process(target=self._init, args=[socket]).start()
return self