-
Notifications
You must be signed in to change notification settings - Fork 548
feat(modules): async modules #1920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| # Copyright 2026 Dimensional Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import asyncio | ||
| from concurrent.futures import Future | ||
| from typing import Protocol | ||
|
|
||
| from dimos.core.coordination.blueprints import autoconnect | ||
| from dimos.core.coordination.module_coordinator import ModuleCoordinator | ||
| from dimos.core.core import arpc, rpc | ||
| from dimos.core.module import Module | ||
| from dimos.core.stream import In, Out | ||
| from dimos.spec.utils import Spec | ||
|
|
||
|
|
||
class Dubler(Module):
    """Demo module: doubles every value arriving on `a` and republishes it.

    `handle_a` is auto-bound to the `a` input by Module.start(); `find_duble`
    is exposed as an async RPC via @arpc.
    """

    a: In[int]
    double_a: Out[int]

    async def handle_a(self, x: int) -> None:
        # Runs on this module's event loop (auto-subscribed by the base class).
        self.double_a.publish(2 * x)

    @arpc
    async def find_duble(self, x: int) -> int:
        # Simulated latency so callers actually exercise the async RPC path.
        await asyncio.sleep(0.5)
        return 2 * x
|
|
||
|
|
||
class DublerSpec(Spec, Protocol):
    """Structural spec for any module exposing the async `find_duble` RPC."""

    async def find_duble(self, x: int) -> int: ...
|
|
||
|
|
||
class StartModule(Module):
    """Demo driver module: periodically calls `find_duble` on a DublerSpec peer.

    NOTE(review): this is scratch/demo code — `_timer_loop` deliberately raises
    after two iterations to exercise error logging. It should live in an
    examples/test location, not the main package.
    """

    # Peer resolved by autoconnect against DublerSpec.
    _dubler: DublerSpec
    # Future for the running timer task; None when not started.
    _timer_future: Future | None = None

    @rpc
    def start(self) -> None:
        # Base start() auto-binds handle_* methods; then kick off the timer.
        super().start()
        self._timer_future = self.spawn(self._timer_loop())

    async def _timer_loop(self) -> None:
        i = 1
        import time

        while True:
            await asyncio.sleep(1.0)
            print("Finding duble of", i, "time=", time.time())
            # Awaiting an @arpc on a peer: same-loop call returns the coroutine.
            ret = await self._dubler.find_duble(i)
            print("Found duble of", ret, "time=", time.time())
            i += 1
            if i == 3:
                # Deliberate failure: demonstrates unhandled-task-exception logging.
                raise Exception("asdf")

    @rpc
    def stop(self) -> None:
        # Cancel the timer task before the base class tears down the loop.
        if self._timer_future is not None:
            self._timer_future.cancel()
            self._timer_future = None
        super().stop()
|
|
||
|
|
||
# Wire the two demo modules together: autoconnect matches stream names
# (a / double_a) and spec types (DublerSpec -> Dubler).
blueprint = autoconnect(StartModule.blueprint(), Dubler.blueprint())

if __name__ == "__main__":
    # Build the coordinator from the blueprint and block in its run loop.
    ModuleCoordinator.build(blueprint).loop()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,13 +15,17 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import functools | ||
| import inspect | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| TypeVar, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Callable | ||
| from collections.abc import Callable, Coroutine | ||
|
|
||
| T = TypeVar("T") | ||
|
|
||
|
|
@@ -34,3 +38,36 @@ | |
def rpc(fn: Callable[P, R]) -> Callable[P, R]:
    """Mark *fn* as an RPC endpoint; the function itself is returned unchanged."""
    setattr(fn, "__rpc__", True)
    return fn
|
|
||
|
|
||
def arpc(fn: Callable[..., Coroutine[Any, Any, Any]]) -> Callable[..., Any]:
    """Mark an async method as an RPC body that runs on the module's self._loop.

    Dual-mode dispatch:
      * Caller is on self._loop (another @arpc, a handle_*, or a
        process_observable callback): returns the coroutine so the caller can
        ``await`` it normally.
      * Caller is on any other thread (RPC dispatcher, sync test, sync @rpc on
        the same module): schedules the coroutine onto self._loop and blocks
        until done.

    Discovery is shared with @rpc — sets ``__rpc__ = True`` so the method
    appears in ``Module.rpcs`` and is served by the existing RPC machinery
    without changes. The raw coroutine function remains reachable as
    ``wrapper.aio``.

    Raises:
        TypeError: at decoration time, if *fn* is not ``async def``.
        RuntimeError: at call time, if the module's ``_loop`` is not yet set.
    """
    if not inspect.iscoroutinefunction(fn):
        raise TypeError("@arpc requires an `async def` method")

    @functools.wraps(fn)
    def wrapper(self, *args: Any, **kwargs: Any) -> Any:
        loop = getattr(self, "_loop", None)
        # Guard before the identity check: with loop=None and no running loop,
        # `running is loop` would be True and the coroutine would be returned
        # to a sync caller that never awaits it — a silent no-op.
        if loop is None:
            raise RuntimeError(
                f"@arpc method called before {type(self).__name__}._loop is initialised"
            )
        try:
            running = asyncio.get_running_loop()
        except RuntimeError:
            running = None
        if running is loop:
            # Same loop: hand back the coroutine for a normal `await`.
            return fn(self, *args, **kwargs)
        # Foreign thread: schedule onto the module loop and block for the result.
        future = asyncio.run_coroutine_threadsafe(fn(self, *args, **kwargs), loop)
        return future.result()

    wrapper.__rpc__ = True  # type: ignore[attr-defined]
    wrapper.__arpc__ = True  # type: ignore[attr-defined]
    wrapper.aio = fn  # type: ignore[attr-defined]
    return wrapper
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,8 +45,14 @@ | |
| from dimos.protocol.tf.tf import LCMTF, TFSpec | ||
| from dimos.utils import colors | ||
| from dimos.utils.generic import classproperty | ||
| from dimos.utils.logging_config import setup_logger | ||
|
|
||
| logger = setup_logger() | ||
|
|
||
| if TYPE_CHECKING: | ||
| from reactivex import Observable | ||
| from reactivex.abc import DisposableBase | ||
|
|
||
| from dimos.core.coordination.blueprints import Blueprint | ||
| from dimos.core.introspection.module.info import ModuleInfo | ||
| from dimos.core.rpc_client import RPCClient | ||
|
|
@@ -64,13 +70,46 @@ class SkillInfo: | |
| args_schema: str | ||
|
|
||
|
|
||
| def _log_task_exception(task: asyncio.Task[Any]) -> None: | ||
| if task.cancelled(): | ||
| return | ||
| try: | ||
| exc = task.exception() | ||
| except asyncio.InvalidStateError: | ||
| return | ||
| if exc is None or isinstance(exc, asyncio.CancelledError): | ||
| return | ||
| # Calling task.exception() above marks the exception as retrieved, so | ||
| # asyncio's GC-time logger won't fire — we must log here. | ||
| name = task.get_name() | ||
| logger.error( | ||
| f"Unhandled exception in async task {name!r}: {type(exc).__name__}: {exc}", | ||
| exc_info=exc, | ||
| ) | ||
|
|
||
|
|
||
def _logging_task_factory(
    loop: asyncio.AbstractEventLoop, coro: Any, **kwargs: Any
) -> asyncio.Task[Any]:
    """Task factory for module-owned loops.

    Every created task gets a done-callback that logs unhandled exceptions.
    Without it, a coroutine scheduled through bare
    ``asyncio.run_coroutine_threadsafe`` — where the returned
    ``concurrent.futures.Future`` is never read — would fail silently.
    Independent of and complementary to ``Module.spawn``.
    """
    new_task = asyncio.Task(coro, loop=loop, **kwargs)
    new_task.add_done_callback(_log_task_exception)
    return new_task
|
|
||
|
|
||
| def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: | ||
| try: | ||
| running_loop = asyncio.get_running_loop() | ||
| return running_loop, None | ||
| except RuntimeError: | ||
| loop = asyncio.new_event_loop() | ||
| asyncio.set_event_loop(loop) | ||
| loop.set_task_factory(_logging_task_factory) | ||
|
|
||
| thr = threading.Thread(target=loop.run_forever, daemon=True) | ||
| thr.start() | ||
|
|
@@ -150,7 +189,90 @@ def build(self) -> None: | |
|
|
||
    @rpc
    def start(self) -> None:
        # Base start: subscribe any `async def handle_<input>` methods declared
        # on the subclass (see _auto_bind_handlers). Subclasses that override
        # start() opt in by calling super().start().
        self._auto_bind_handlers()
|
|
||
    def _auto_bind_handlers(self) -> None:
        """For each declared In[T] x, if `async def handle_x` exists, subscribe
        it via process_observable so it runs on self._loop.

        Called from ModuleBase.start(). Subclasses opt in by calling
        super().start() — modules that don't are unaffected.

        Raises:
            TypeError: if a handle_<input> method exists but is not async.
        """
        for input_name, in_stream in self.inputs.items():
            handler = getattr(self, f"handle_{input_name}", None)
            if handler is None:
                # No handler declared for this input — nothing to bind.
                continue
            # @arpc wraps an async fn in a sync dispatcher; unwrap it so we
            # subscribe the raw coroutine function instead of the wrapper
            # (which would block on run_coroutine_threadsafe from the rx thread).
            if getattr(handler, "__arpc__", False):
                # Re-bind the raw coroutine function to this instance.
                handler = handler.aio.__get__(self, type(self))
            if not inspect.iscoroutinefunction(handler):
                raise TypeError(
                    f"{type(self).__name__}.handle_{input_name} must be `async def` "
                    "(use a manual self.<input>.subscribe(...) for sync handlers)"
                )
            # observable() is backpressured/latest — slow handlers coalesce
            # bursts instead of growing an unbounded queue on the loop.
            self.process_observable(in_stream.observable(), handler)
|
|
||
| def process_observable( | ||
| self, | ||
| observable: "Observable[Any]", | ||
| async_cb: Callable[[Any], Any], | ||
| ) -> "DisposableBase": | ||
| """Subscribe `async_cb` (an async function) to `observable`, dispatching | ||
| each emitted value onto self._loop. The subscription is registered for | ||
| cleanup on stop().""" | ||
| if not inspect.iscoroutinefunction(async_cb): | ||
| raise TypeError("process_observable requires an `async def` callback") | ||
| sub = observable.subscribe(self._make_async_dispatch(async_cb)) | ||
| return self.register_disposable(sub) | ||
|
|
||
| def _make_async_dispatch(self, async_handler: Callable[[Any], Any]) -> Callable[[Any], None]: | ||
| """Build a sync callback that schedules `async_handler(msg)` onto self._loop.""" | ||
|
|
||
| def on_msg(msg: Any) -> None: | ||
| loop = self._loop | ||
| if loop is None or not loop.is_running(): | ||
| return | ||
| future = asyncio.run_coroutine_threadsafe(async_handler(msg), loop) | ||
| future.add_done_callback(self._log_async_handler_error) | ||
|
|
||
| return on_msg | ||
|
|
||
| def spawn(self, coro: Any) -> Any: | ||
| """ | ||
| Schedule a coroutine on self._loop from any thread. | ||
|
|
||
| Use this instead of bare `asyncio.run_coroutine_threadsafe(coro, | ||
| self._loop)` when scheduling a long-running async task sync context like | ||
| start(). | ||
|
|
||
| Unhandled exceptions are routed to the module logger instead of being | ||
| silently stored in the returned Future, which is the common pitfall when | ||
| nothing ever reads `.result()`. | ||
| """ | ||
|
|
||
| loop = self._loop | ||
| if loop is None or not loop.is_running(): | ||
| raise RuntimeError(f"{type(self).__name__}._loop is not running") | ||
| future = asyncio.run_coroutine_threadsafe(coro, loop) | ||
| future.add_done_callback(self._log_async_handler_error) | ||
| return future | ||
|
|
||
| def _log_async_handler_error(self, fut: Any) -> None: | ||
| try: | ||
| fut.result() | ||
| except (asyncio.CancelledError, RuntimeError): | ||
| pass # loop stopped or task cancelled during shutdown | ||
| except BaseException as e: | ||
| # Include exception type+message in the event string so it is | ||
| # visible on consoles whose formatters strip exc_info/traceback. | ||
| logger.exception( | ||
| f"Unhandled error in async task on {type(self).__name__}._loop: " | ||
| f"{type(e).__name__}: {e}" | ||
| ) | ||
|
Comment on lines
+264
to
+275
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The handler silently discards any import concurrent.futures
def _log_async_handler_error(self, fut: Any) -> None:
try:
fut.result()
except (asyncio.CancelledError, concurrent.futures.CancelledError):
pass # task cancelled during shutdown
except RuntimeError as e:
if "event loop" in str(e).lower() or "loop is closed" in str(e).lower():
pass # loop shut down before task completed
else:
logger.exception(
f"Unhandled error in async task on {type(self).__name__}._loop: "
f"{type(e).__name__}: {e}"
)
except BaseException as e:
logger.exception(
f"Unhandled error in async task on {type(self).__name__}._loop: "
f"{type(e).__name__}: {e}"
)As written, any |
||
|
|
||
| @rpc | ||
| def stop(self) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dimos/asdf.pyis clearly a temporary scratch file: the filename is a keyboard-mash, and the_timer_loopmethod deliberately raisesraise Exception("asdf")after two iterations. This file is in the maindimos/package directory (importable by any consumer), not in a test directory, and the PR description does not mention it as intended shipped code. It should be removed or moved to a dedicated example/test location before merging.