Source code for flowno.utilities.coroutine_wrapper

"""
Coroutine wrapper for Flowno's asynchronous execution logging.

This module provides a wrapper for Python coroutines that adds detailed logging
for debugging purposes. The wrapper logs coroutine execution events including
creation, resumption, yielding, and completion.

For examples and more detailed information, see the :py:mod:`flowno.utilities.logging` module
and its `log_async` decorator which uses this wrapper.
"""

import logging
from collections.abc import Coroutine, Generator
from types import TracebackType
from typing import overload, TypeVar
from typing_extensions import override

Yield = TypeVar("Yield")
T = TypeVar("T")

logger = logging.getLogger(__name__)


[docs] class CoroutineWrapper(Coroutine[Yield, object, T]): """ Wrapper for coroutines to add detailed logging. This wrapper intercepts coroutine operations and logs execution events, making it easier to debug complex asynchronous workflows in Flowno. Used internally by the `log_async` decorator. Logs when the coroutine is: - Created - Started/Resumed (via send or throw) - Yielding commands to the event loop - Returning awaited values Ensures that exceptions are propagated correctly. """ def __init__(self, coro: Coroutine[Yield, object, T], func_name: str, arg_str: str): self._coro = coro # The underlying coroutine self._func_name = func_name # Name of the coroutine function self._arg_str = arg_str # String representation of arguments
[docs] @override def send(self, value: object) -> Yield: try: logger.debug(f"Resuming coroutine: {self._func_name}({self._arg_str}) with send({value!r})") result = self._coro.send(value) logger.debug(f"Coroutine {self._func_name}({self._arg_str}) yielded {result!r}") return result except StopIteration as e: # Coroutine has finished execution # the return type of an async function gets wrapped in a StopIteration final_result: T = e.value # pyright: ignore[reportAny] logger.debug(f"Finished coroutine: {self._func_name}({self._arg_str}) with result {final_result!r}") raise except BaseException as e: # Log any exception raised logger.debug(f"Coroutine {self._func_name}({self._arg_str}) raised exception {e!r}") raise
@overload def throw( self, typ: type[BaseException], val: BaseException | object = None, tb: TracebackType | None = None, / ) -> Yield: ... @overload def throw(self, typ: BaseException, val: None = None, tb: TracebackType | None = None, /) -> Yield: ...
[docs] @override def throw( self, typ: type[BaseException] | BaseException, val: BaseException | object = None, tb: TracebackType | None = None, ) -> Yield: try: typ_name = getattr(typ, "__name__", str(typ)) logger.debug(f"Throwing into coroutine: {self._func_name}({self._arg_str}) exception {typ_name}({val})") if val is None: assert isinstance(typ, BaseException), f"Expected BaseException, got {typ!r}" result = self._coro.throw(typ) else: assert isinstance(typ, type), f"Expected type, got {typ!r}" result = self._coro.throw(typ, val, tb) logger.debug(f"Coroutine {self._func_name}({self._arg_str}) yielded {result!r} after throw") return result except StopIteration as e: final_result: T = e.value # pyright: ignore[reportAny] logger.debug( f"Finished coroutine after throw: {self._func_name}({self._arg_str}) with result {final_result!r}" ) raise except BaseException as e: logger.debug(f"Coroutine {self._func_name}({self._arg_str}) raised exception after throw: {e!r}") raise
[docs] @override def close(self): self._coro.close() logger.debug(f"Coroutine {self._func_name}({self._arg_str}) closed")
@override def __await__(self) -> Generator[Yield, object, T]: # Implement __await__ to return an iterator that drives the coroutine return self._wrap_awaitable(self._coro.__await__()) def _wrap_awaitable(self, awaitable: Generator[Yield, object, T]) -> Generator[Yield, object, T]: try: # Start the coroutine via __await__ logger.debug(f"Starting coroutine: {self._func_name}({self._arg_str}) via __await__") value = yield from awaitable # Coroutine completed logger.debug(f"Coroutine {self._func_name}({self._arg_str}) completed via __await__ with result {value!r}") return value except StopIteration as e: final_result: T = e.value # pyright: ignore[reportAny] logger.debug( f"Finished coroutine via __await__: {self._func_name}({self._arg_str}) with result {final_result!r}" ) raise except BaseException as e: logger.debug(f"Coroutine {self._func_name}({self._arg_str}) raised exception via __await__: {e!r}") raise
__all__ = ["CoroutineWrapper"]