Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions src/loafer/dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias

from .exceptions import DeleteMessage, TerminateTaskGroup
from asyncio_extensions import STOP, iterate_queue

from .exceptions import DeleteMessage
from .routes import Route
from .types import ExcInfo, Message

Expand All @@ -13,7 +15,7 @@

logger: logging.Logger = logging.getLogger(__name__)

ProcessingQueue: TypeAlias = asyncio.Queue[tuple[Route, Message]]
ProcessingQueue: TypeAlias = asyncio.Queue[tuple[Message, Route]]


class LoaferDispatcher:
Expand Down Expand Up @@ -77,41 +79,33 @@ async def _fetch_messages(
break

async def _consume_messages(self, processing_queue: ProcessingQueue, tg: asyncio.TaskGroup) -> None:
while True:
message, route = await processing_queue.get()

async for message, route in iterate_queue(processing_queue):
task = tg.create_task(self._process_message(message, route))
try:
async with asyncio.timeout(self.worker_timeout):
await task
except TimeoutError:
logger.exception("message processing timed out, route=%s\n%r\n", route, message)
task.cancel()
finally:
processing_queue.task_done()

async def dispatch_providers(self, *, forever: bool = True) -> None:
processing_queue: ProcessingQueue = ProcessingQueue(self.queue_size)

try:
async with asyncio.TaskGroup() as tg:
provider_tasks: list[asyncio.Task[None]] = [
tg.create_task(self._fetch_messages(processing_queue, route, forever=forever))
for route in self.routes
]
async with asyncio.TaskGroup() as tg:
provider_tasks: list[asyncio.Task[None]] = [
tg.create_task(self._fetch_messages(processing_queue, route, forever=forever)) for route in self.routes
]

for _ in range(self.workers):
tg.create_task(self._consume_messages(processing_queue, tg))

async def join() -> None:
await asyncio.wait(provider_tasks)
await processing_queue.join()
for _ in range(self.workers):
tg.create_task(self._consume_messages(processing_queue, tg))

raise TerminateTaskGroup # noqa: TRY301
await asyncio.wait(provider_tasks)

tg.create_task(join())
except* TerminateTaskGroup:
pass
if sys.version_info >= (3, 13):
processing_queue.shutdown()
else:
for _ in range(self.workers):
await processing_queue.put(STOP) # type: ignore[arg-type]

def stop(self) -> None:
for route in self.routes:
Expand Down
4 changes: 0 additions & 4 deletions src/loafer/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,3 @@ class ProviderError(Exception):

class DeleteMessage(BaseException): # technically not an Exception
pass


class TerminateTaskGroup(BaseException):
"""Exception raised to terminate a task group."""