"""
Flow execution and graph resolution module for Flowno.
This module contains the Flow class, which is the core execution engine for dataflow graphs.
It manages node scheduling, dependency resolution, cycle breaking, and concurrent execution.
Key components:
- Flow: The main dataflow graph execution engine
- FlowEventLoop: A custom event loop for handling Flow-specific commands
- NodeTaskStatus: State tracking for node execution
"""
import logging
from collections import defaultdict
from collections.abc import AsyncGenerator, Awaitable, Coroutine, Generator
from dataclasses import dataclass
from types import coroutine
from typing import Any, NamedTuple, TypeAlias, cast
from flowno.core.event_loop.commands import Command
from flowno.core.event_loop.event_loop import EventLoop
from flowno.core.event_loop.queues import AsyncSetQueue
from flowno.core.event_loop.types import RawTask, TaskHandlePacket
from flowno.core.flow.instrumentation import get_current_flow_instrument
from flowno.core.node_base import (
DraftInputPortRef,
DraftNode,
FinalizedInputPort,
FinalizedInputPortRef,
FinalizedNode,
MissingDefaultError,
StalledNodeRequestCommand,
SuperNode,
)
from flowno.core.types import Generation, InputPortIndex
from flowno.utilities.helpers import cmp_generation
from flowno.utilities.logging import log_async
from typing_extensions import Never, Unpack, override
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Alias for a finalized node whose input and output types are left unconstrained (``Any``).
AnyFinalizedNode: TypeAlias = FinalizedNode[Unpack[tuple[Any, ...]], tuple[Any, ...]]
# Same shape as ``AnyFinalizedNode`` but parameterized with ``object`` instead of ``Any``.
ObjectFinalizedNode: TypeAlias = FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]
[docs]
@dataclass
class WaitForStartNextGenerationCommand(Command):
    """Event-loop command yielded by a node task that is waiting to begin its next generation.

    Attributes:
        node: The node whose task is waiting.
        run_level: The run level the node is waiting at; defaults to 0.
    """

    node: ObjectFinalizedNode
    run_level: int = 0
[docs]
@coroutine
def _wait_for_start_next_generation(
    node: ObjectFinalizedNode,
    run_level: int = 0,
) -> Generator[WaitForStartNextGenerationCommand, None, None]:
    """Yield a ``WaitForStartNextGenerationCommand`` for ``node`` to the event loop.

    Args:
        node: The node that is waiting for its next generation.
        run_level: The run level at which the node waits; defaults to 0.

    Returns:
        Whatever value the event loop sends back when the task is resumed.
    """
    command = WaitForStartNextGenerationCommand(node, run_level)
    reply = yield command
    return reply
[docs]
@dataclass
class TerminateWithExceptionCommand(Command):
    """Event-loop command asking that the flow be terminated because a node raised.

    Attributes:
        node: The node in which the exception occurred.
        exception: The exception to surface from the event loop.
    """

    node: ObjectFinalizedNode
    exception: Exception
[docs]
@coroutine
def _terminate_with_exception(
    node: ObjectFinalizedNode,
    exception: Exception,
) -> Generator[TerminateWithExceptionCommand, None, None]:
    """Yield a ``TerminateWithExceptionCommand`` for ``node`` to the event loop.

    Args:
        node: The node in which the exception occurred.
        exception: The exception that should terminate the flow.

    Returns:
        Whatever value the event loop sends back, if the task is ever resumed.
    """
    command = TerminateWithExceptionCommand(node, exception)
    reply = yield command
    return reply
[docs]
@dataclass
class TerminateReachedLimitCommand(Command):
    """Field-less event-loop command signalling that a node hit its generation limit."""
[docs]
@coroutine
def _terminate_reached_limit() -> Generator[TerminateReachedLimitCommand, None, None]:
    """Yield a ``TerminateReachedLimitCommand`` to the event loop.

    Returns:
        Whatever value the event loop sends back, if the task is ever resumed.
    """
    reply = yield TerminateReachedLimitCommand()
    return reply
[docs]
class TerminateLimitReached(Exception):
    """Raised by the flow's event loop when a node reports reaching its generation limit."""
[docs]
class NodeExecutionError(Exception):
    """Error describing a failure inside a specific node.

    Attributes:
        node: The node whose execution failed.
    """

    node: ObjectFinalizedNode

    def __init__(self, node: ObjectFinalizedNode):
        """Record the failing node and build the message from it."""
        message = f"Exception in node {node}"
        super().__init__(message)
        self.node = node
[docs]
@dataclass
class ResumeNodeCommand(Command):
    """Event-loop command asking that a node's persistent task be scheduled.

    Attributes:
        node: The node whose task should be resumed.
    """

    node: ObjectFinalizedNode
[docs]
@coroutine
def _resume_node(
    node: ObjectFinalizedNode,
) -> Generator[ResumeNodeCommand, None, None]:
    """Ask the event loop to resume the concurrent task of ``node``.

    The request is a no-op for nodes that are already running.

    Args:
        node: The node whose task should be resumed.

    Returns:
        Whatever value the event loop sends back when the caller is resumed.
    """
    command = ResumeNodeCommand(node)
    reply = yield command
    return reply
[docs]
class NodeTaskStatus:
    """
    Namespace for the states a node's task can be in while a flow executes.

    States:
    - Running: the task is executing.
    - Ready: the task may execute but is not running yet.
    - Error: the task failed with an exception.
    - Stalled: the task is blocked on one of its inputs.

    ``Type`` is the union of all four state classes.
    """

    @dataclass(frozen=True)
    class Running:
        """The node's task is currently executing."""

    @dataclass(frozen=True)
    class Ready:
        """The node's task can be executed but has not been started."""

    @dataclass(frozen=True)
    class Error:
        """The node's task raised an exception."""

    @dataclass(frozen=True)
    class Stalled:
        """The node's task is blocked until ``stalling_input`` receives data."""

        # Reference to the input port the node is waiting on.
        stalling_input: FinalizedInputPortRef[object]

    Type: TypeAlias = Ready | Running | Error | Stalled
[docs]
class NodeTaskAndStatus(NamedTuple):
    """Pair of a node's persistent task and the status currently recorded for it."""

    # The raw coroutine task driving the node (created from ``Flow.evaluate_node``).
    task: RawTask[Command, object, Never]
    # One of the ``NodeTaskStatus`` state instances.
    status: NodeTaskStatus.Type
[docs]
class Flow:
    """
    Dataflow graph execution engine.

    The Flow class manages the execution of a dataflow graph, handling dependency
    resolution, node scheduling, and cycle breaking. It uses a custom event loop
    to execute nodes concurrently while respecting data dependencies.

    Key features:
    - Automatic dependency-based scheduling
    - Cycle detection and resolution
    - Support for streaming data (run levels)
    - Concurrency management

    Attributes:
        unvisited: List of nodes that have not yet been visited during execution
        visited: Set of nodes that have been visited
        node_tasks: Dictionary mapping nodes to their tasks and status
        running_nodes: Set of nodes currently running
        resolution_queue: Queue of nodes waiting to be resolved
    """

    # Classvar as instance init: the class-level value counts flows created so far;
    # __init__ copies the current value onto the instance and then increments it.
    counter: int = 0
    # Instance attribute types
    unvisited: list[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]]
    visited: set[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]]
    # Generation limit(s) checked by _terminate_if_reached_limit: either one limit
    # applied to every node or a per-node mapping.
    _stop_at_node_generation: (
        dict[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], Generation] | Generation
    )
    node_tasks: dict[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], NodeTaskAndStatus]
    running_nodes: set[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]]
    resolution_queue: AsyncSetQueue[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]]
    # Presumably records, per node, which input ports are currently using default values;
    # the helpers that maintain it are not visible in this excerpt -- TODO confirm.
    _defaulted_inputs: defaultdict[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], list[InputPortIndex]]
    # Set to False in __init__; not otherwise used in the visible code.
    resumable: bool
    # The FlowEventLoop instance that drives this flow.
    event_loop: "FlowEventLoop"
def __init__(self, is_finalized: bool = True):
    """
    Create an empty flow with its own event loop.

    Args:
        is_finalized: Whether the nodes in this flow are already finalized.
            The value is not read by this constructor.
    """
    self.resumable = False
    self.event_loop = FlowEventLoop(self)

    # Take the next sequential id from the class-level counter.
    flow_id = Flow.counter
    self.counter = flow_id
    Flow.counter = flow_id + 1

    # Bookkeeping for graph traversal and scheduling starts out empty.
    self.unvisited = []
    self.visited = set()
    self._stop_at_node_generation = None
    self.node_tasks = {}
    self.running_nodes = set()
    self.resolution_queue = AsyncSetQueue()
    self._defaulted_inputs = defaultdict(list)
[docs]
def set_node_status(self, node: ObjectFinalizedNode, status: NodeTaskStatus.Type) -> None:
    """
    Record a new status for ``node`` and report the transition to instrumentation.

    ``running_nodes`` is kept in sync: the node is added when the new status is
    ``Running`` and removed otherwise.

    Args:
        node: The node whose status is being updated
        status: The new status to set
    """
    previous = self.node_tasks[node].status
    get_current_flow_instrument().on_node_status_change(self, node, previous, status)
    self.node_tasks[node] = self.node_tasks[node]._replace(status=status)

    if isinstance(status, NodeTaskStatus.Running):
        self.running_nodes.add(node)
        return
    # discard() is a no-op for nodes that were not running.
    self.running_nodes.discard(node)
[docs]
async def _terminate_if_reached_limit(self, node: ObjectFinalizedNode):
    """
    Request termination of the flow if ``node`` has reached its generation limit.

    The limit is looked up per node when ``_stop_at_node_generation`` is a dict
    (falling back to ``()`` for nodes without an entry) and is the shared value
    otherwise. When ``cmp_generation`` reports the node's generation as greater
    than or equal to the limit, instrumentation is notified and a
    ``TerminateReachedLimitCommand`` is yielded to the event loop.

    Args:
        node: The node to check
    """
    limit = self._stop_at_node_generation
    if isinstance(limit, dict):
        limit = limit.get(node, ())

    if cmp_generation(node.generation, limit) < 0:
        return

    get_current_flow_instrument().on_node_generation_limit(self, node, limit)
    await _terminate_reached_limit()
[docs]
async def _handle_coroutine_node(
    self,
    node: ObjectFinalizedNode,
    returned: Awaitable[tuple[object, ...]],
):
    """
    Finish a node whose call produced a plain coroutine (single output).

    The coroutine is awaited and its result is pushed as the node's run level 0
    data, after which instrumentation is told about the emitted data.

    Args:
        node: The node to handle
        returned: The coroutine returned by the node's call
    """
    # evaluate_node already wraps this call in the run_level 0 lifecycle
    # instrumentation context.
    output = await returned
    # TODO: wait for barrier0
    node.push_data(output, 0)
    # TODO: set count for barrier 0
    instrument = get_current_flow_instrument()
    instrument.on_node_emitted_data(self, node, output, 0)
[docs]
async def _handle_async_generator_node(
    self,
    node: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]],
    returned: AsyncGenerator[tuple[object, ...], None],
):
    """
    Handle a node that returns an async generator (streaming output).

    Each yielded item is pushed as run level 1 data and folded into an
    accumulator. When the generator finishes, the accumulator (or an explicit
    final value) is pushed as the run level 0 result.

    Args:
        node: The node to handle
        returned: The async generator returned by the node's call

    Raises:
        ValueError: If the generator finishes with an explicit
            ``raise StopAsyncIteration(value)`` whose value is missing or is
            not a tuple.
        Exception: Any other exception raised by the generator is re-raised.
    """
    # NOTE(review): the source this block was recovered from had lost its
    # indentation; the extent of the two ``with`` blocks below is the narrowest
    # consistent reading -- confirm against the repository.
    acc: tuple[object, ...] | None = None
    try:
        while True:
            # already part of run_level 0 lifecycle
            with get_current_flow_instrument().node_lifecycle(self, node, run_level=1):
                result = await anext(returned)

            if acc is None:
                acc = result
            else:
                try:
                    acc = tuple(node._draft_node.accumulate_streamed_data(acc, result))
                except NotImplementedError:
                    # The node does not define how to accumulate streamed data.
                    acc = None

            # wait for the last output data to have been read before overwriting.
            with get_current_flow_instrument().on_barrier_node_write(self, node, result, 1):
                await node._barrier1.wait()
            node.push_data(result, 1)
            # remember how many times output data must be read
            node._barrier1.set_count(len(node.get_output_nodes_by_run_level(1)))
            get_current_flow_instrument().on_node_emitted_data(self, node, result, 1)

            await self._terminate_if_reached_limit(node)
            await self._enqueue_output_nodes(node)
            await _wait_for_start_next_generation(node, 1)
    except StopAsyncIteration:
        # normal implicit streaming node return
        # acc is a run level 0 value (None when nothing could be accumulated)
        # TODO: wait for barrier0
        node.push_data(acc, 0)
        # TODO: set barrier0
        get_current_flow_instrument().on_node_emitted_data(self, node, acc, 0)
    except Exception as e:
        # python reraises a StopAsyncIteration raised inside the async generator
        # as RuntimeError; `Exception.__cause__` is the original exception
        cause = e.__cause__
        if not isinstance(cause, StopAsyncIteration):
            raise

        # completion with explicit `raise StopAsyncIteration("final value")`
        # A bare `raise StopAsyncIteration()` carries no args; treat that as a
        # missing final value instead of failing with IndexError.
        final_value = cause.args[0] if cause.args else None
        if not isinstance(final_value, tuple):
            # Adjacent literals are concatenated into one message string.
            raise ValueError(
                "The final value of a node async generator must "
                f"be a tuple. Got: {final_value}. If "
                "you use the @node.tuple decorator you are "
                "responsible for wrapping the final value in "
                "a tuple."
            ) from e

        data: tuple[object, ...] = final_value
        # TODO: wait for barrier0
        node.push_data(data, 0)
        # TODO: set barrier0
        get_current_flow_instrument().on_node_emitted_data(self, node, data, 0)
[docs]
@log_async
async def evaluate_node(self, node: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]) -> Never:
    """
    The persistent task that evaluates a node.

    This is the main execution function for a node. It:
    1. Waits for the node to be ready to run
    2. Gathers inputs and handles defaulted values
    3. Calls the node with its inputs
    4. Processes the result (either coroutine or async generator)
    5. Propagates outputs to dependent nodes
    6. Repeats

    Any exception raised while calling the node or processing its result is
    reported to instrumentation and handed to the event loop through
    ``_terminate_with_exception``. That includes the ``NotImplementedError``
    raised here when the node returns neither a coroutine nor an async generator.

    Args:
        node: The node to evaluate

    Returns:
        Never returns; runs as a persistent coroutine
    """
    # NOTE(review): the source this block was recovered from had lost its
    # indentation. The nesting below (try/except inside the lifecycle context,
    # limit check and enqueue after it) is a reconstruction -- confirm against
    # the repository.
    while True:
        # Suspend until the event loop resumes this node's task.
        await _wait_for_start_next_generation(node, 0)
        with get_current_flow_instrument().node_lifecycle(self, node, run_level=0):
            positional_arg_values, defaulted_inputs = node.gather_inputs()
            await node.count_down_upstream_latches(defaulted_inputs)
            try:
                # Record which inputs used defaults for the duration of the call.
                self.set_defaulted_inputs(node, defaulted_inputs)
                returned = node.call(*positional_arg_values)
                # make sure the user used async def.
                if not isinstance(returned, (Coroutine, AsyncGenerator)):
                    raise NotImplementedError(
                        "Node must be a coroutine (async def) or an AsyncGenerator (async def with yield)"
                    )
                if isinstance(returned, Coroutine):
                    await self._handle_coroutine_node(node, returned)
                else:
                    await self._handle_async_generator_node(node, returned)
            except Exception as e:
                get_current_flow_instrument().on_node_error(self, node, e)
                # if self.node_unhandled_exception_terminates:
                await _terminate_with_exception(node, e)
            finally:
                # Always drop the defaulted-input record, whether the call succeeded or not.
                self.clear_defaulted_inputs(node)
        await self._terminate_if_reached_limit(node)
        await self._enqueue_output_nodes(node)
[docs]
def add_node(self, node: AnyFinalizedNode):
    """
    Register ``node`` with the flow unless it is already waiting in ``unvisited``.

    A newly added node is reported to instrumentation, appended to
    ``unvisited`` and given its persistent task.

    Args:
        node: The node to add
    """
    if node not in self.unvisited:
        get_current_flow_instrument().on_node_registered(self, node)
        self.unvisited.append(node)
        self._register_node(node)
[docs]
def _register_node(self, node: ObjectFinalizedNode):
    """
    Create the persistent task for ``node`` and store it with status ``Ready``.

    Args:
        node: The node to register
    """
    node_task: RawTask[Command, object, Never] = self.evaluate_node(node)
    # The evaluate_node loop is structured so that the coroutine has to be
    # primed once; doing it here also avoids the "coroutine was never awaited"
    # warning.
    node_task.send(None)
    self.node_tasks[node] = NodeTaskAndStatus(task=node_task, status=NodeTaskStatus.Ready())
[docs]
def _mark_node_as_visited(self, node: ObjectFinalizedNode):
    """
    Move ``node`` into the visited set, registering it first if necessary.

    Args:
        node: The node to mark as visited
    """
    get_current_flow_instrument().on_node_visited(self, node)

    if node in self.unvisited:
        # Being reached during resolution proves the node is connected to the graph.
        self.unvisited.remove(node)
        self.visited.add(node)
        return

    if node in self.visited:
        return

    # The node was never passed to .add_node(), so it has no task yet.
    self.visited.add(node)
    self._register_node(node)
[docs]
def add_nodes(self, nodes: list[AnyFinalizedNode]):
    """
    Register every node in ``nodes`` with the flow, in order, via ``add_node``.

    Args:
        nodes: The nodes to add
    """
    for candidate in nodes:
        self.add_node(candidate)
[docs]
async def _enqueue_output_nodes(self, out_node: ObjectFinalizedNode):
    """
    Put every node that consumes ``out_node``'s output on the resolution queue.

    Nothing is enqueued (and no instrumentation event is emitted) when the
    resolution queue has already been closed.

    Args:
        out_node: The node whose dependents should be enqueued
    """
    dependents = out_node.get_output_nodes()
    if self.resolution_queue.closed:
        return

    for dependent in dependents:
        get_current_flow_instrument().on_resolution_queue_put(self, dependent)
    await self.resolution_queue.putAll(dependents)
[docs]
async def _enqueue_node(self, node: ObjectFinalizedNode):
    """
    Report ``node`` to instrumentation and put it on the resolution queue.

    Args:
        node: The node to enqueue
    """
    instrument = get_current_flow_instrument()
    instrument.on_resolution_queue_put(self, node)
    await self.resolution_queue.put(node)
[docs]
def run_until_complete(
    self,
    stop_at_node_generation: (
        dict[
            FinalizedNode[Unpack[tuple[Any, ...]], tuple[Any, ...]]
            | DraftNode[Unpack[tuple[Any, ...]], tuple[Any, ...]],
            Generation,
        ]
        | Generation
    ) = (),
    terminate_on_node_error: bool = False,
    _debug_max_wait_time: float | None = None,
):
    """
    Run the flow on its event loop until it finishes or is terminated.

    This is the main entry point for running a flow. The node resolution loop
    is handed to the flow's event loop, which runs (joining outstanding tasks)
    until the loop completes or a termination condition such as a generation
    limit or a node error is hit.

    Args:
        stop_at_node_generation: Generation limit for nodes, either as a global
            limit or as a dict mapping nodes to their individual limits
        terminate_on_node_error: Whether to terminate the flow if a node raises an exception
        _debug_max_wait_time: Maximum time in seconds to wait for I/O operations
            (useful for debugging)

    Raises:
        Exception: Any exception raised by nodes and not caught
        TerminateLimitReached: When a node reaches its generation limit
    """
    # The cast only narrows the static type; the value is passed through unchanged.
    generation_limits = cast(
        dict[ObjectFinalizedNode, Generation] | Generation,
        stop_at_node_generation,
    )
    resolve_loop = self._node_resolve_loop(generation_limits, terminate_on_node_error)
    self.event_loop.run_until_complete(
        resolve_loop,
        join=True,
        _debug_max_wait_time=_debug_max_wait_time,
    )
[docs]
@log_async
async def _node_resolve_loop(
    self,
    stop_at_node_generation: (
        dict[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], Generation] | Generation
    ),
    terminate_on_node_error: bool,
):
    """
    Main resolution loop for the flow.

    This function implements the core algorithm for resolving node dependencies
    and executing nodes in the correct order. It:
    1. Picks an initial node (the first registered, still-unvisited node)
    2. For each node in the resolution queue:
        a. Finds the set of nodes that must be executed first
        b. Marks those nodes as visited
        c. Resumes their execution
    3. Continues until the resolution queue is closed

    If no nodes have been registered, a warning is logged and the loop ends
    immediately.

    Args:
        stop_at_node_generation: Generation limit for nodes
        terminate_on_node_error: Whether to terminate on node errors
            (not read by this function).
    """
    get_current_flow_instrument().on_flow_start(self)
    self._stop_at_node_generation = stop_at_node_generation

    if not self.unvisited:
        # Nothing to schedule: popping from the empty list would raise
        # IndexError, so report the end of the flow and return.
        logger.warning("No nodes to run.")
        get_current_flow_instrument().on_flow_end(self)
        return

    initial_node = self.unvisited.pop(0)
    get_current_flow_instrument().on_resolution_queue_put(self, initial_node)
    await self.resolution_queue.put(initial_node)

    # blocks until a node is available or the queue is closed
    async for current_node in self.resolution_queue:
        get_current_flow_instrument().on_resolution_queue_get(self, current_node)
        solution_nodes = self._find_node_solution(current_node)
        get_current_flow_instrument().on_solving_nodes(self, current_node, solution_nodes)
        for leaf_node in solution_nodes:
            self._mark_node_as_visited(leaf_node)
            await _resume_node(leaf_node)

    # self.event_loop.clean_up()
    get_current_flow_instrument().on_flow_end(self)
[docs]
def _find_node_solution(self, node: ObjectFinalizedNode) -> list[ObjectFinalizedNode]:
    """
    Determine which nodes must run first so that ``node`` can make progress.

    This method is key to Flowno's cycle resolution algorithm. It:
    1. Builds a condensed graph of strongly connected components (SCCs)
    2. Finds the leaf SCCs in this condensed graph
    3. For each leaf SCC, picks a node to force evaluate based on default values

    Args:
        node: The node whose dependencies need to be resolved

    Returns:
        A list of nodes that should be forced to evaluate to unblock the given node

    Raises:
        MissingDefaultError: If a cycle is detected with no default values to break it
    """
    condensed_root = self._condensed_tree(node)
    # NOTE(review): the generated charts are never used here; the call is kept
    # only to preserve existing behaviour.
    _unused_charts = condensed_root.generate_mermaid_charts_for_condensed_graph()

    forced: list[ObjectFinalizedNode] = [
        self._pick_node_to_force_evaluate(leaf) for leaf in self._find_leaf_supernodes(condensed_root)
    ]
    return forced
[docs]
def _condensed_tree(self, head: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]) -> SuperNode:
    """
    Build a condensed graph of strongly connected components (SCCs) from stale connections.

    This method implements Tarjan's algorithm to find strongly connected components
    (cycles) in the dependency graph, but only following connections that are "stale"
    (where the input's generation is <= the node's generation).

    Only nodes reachable from ``head`` through such edges are explored. Each SCC
    becomes one ``SuperNode``; afterwards the supernodes are linked through their
    ``dependencies`` lists (and ``dependent`` back-references).

    Args:
        head: The starting point for building the condensed graph

    Returns:
        A SuperNode representing the root of the condensed graph
    """
    # --- Tarjan bookkeeping ---
    visited: set[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]] = set()
    # Stack of nodes belonging to SCCs that have not been closed yet.
    current_scc_stack: list[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]] = []
    on_stack: set[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]] = set()
    id_counter = 0
    # Discovery index of each node, and the lowest index reachable from it.
    ids: dict[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], int] = {}
    low_links: dict[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], int] = {}
    # --- Results ---
    all_sccs: list[SuperNode] = []
    scc_for_node: dict[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], SuperNode] = {}

    def get_subgraph_edges(
        node: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]],
    ) -> Generator[FinalizedInputPort[object], None, None]:
        """
        Return the inputs (edges) from `node` to its upstream dependencies that
        belong in the stale subgraph.

        1) Gather all inputs that are stale according to
        get_inputs_with_le_generation_clipped_to_minimum_run_level().
        2) If the node is stalled, we only yield the single stalled input
        (if and only if it is also stale and not defaulted).
        3) Otherwise, we yield all stale, non-defaulted inputs.
        """
        # 1) Collect all stale inputs
        stale_inputs = node.get_inputs_with_le_generation_clipped_to_minimum_run_level()
        # 2) Check node's status
        match self.node_tasks[node].status:
            case NodeTaskStatus.Stalled(stalling_input):
                # logger.debug(f"{node} is stalled on input port {stalling_input.port_index}")
                assert stalling_input.node == node
                # Grab exactly that one input port:
                single_port = node._input_ports[stalling_input.port_index]
                # Only yield it if:
                # - it's in the stale set
                # - it's not defaulted
                if single_port in stale_inputs and not self.is_input_defaulted(node, single_port.port_index):
                    yield single_port
            case _:
                # 3) Normal case: yield all stale, non-defaulted inputs
                for port in stale_inputs:
                    if self.is_input_defaulted(node, port.port_index):
                        continue
                    yield port

    def tarjan_dfs(v: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]):
        """
        Tarjan's algorithm for finding strongly connected components.

        This is a depth-first search that identifies strongly connected
        components (cycles) in the graph.

        Args:
            v: The current node being processed
        """
        nonlocal id_counter
        ids[v] = low_links[v] = id_counter
        id_counter += 1
        current_scc_stack.append(v)
        on_stack.add(v)
        visited.add(v)
        for v_input_ports in get_subgraph_edges(v):
            # Unconnected inputs contribute no edge.
            if v_input_ports.connected_output is None:
                continue
            dependency: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]] = (
                v_input_ports.connected_output.node
            )
            if dependency not in visited:
                # Tree edge: recurse, then propagate the child's low-link.
                tarjan_dfs(dependency)
                low_links[v] = min(low_links[v], low_links[dependency])
            elif dependency in on_stack:
                # Back edge into an SCC that is still open.
                low_links[v] = min(low_links[v], ids[dependency])
        if low_links[v] == ids[v]:
            # v is the root of an SCC: pop the stack down to v to collect its members.
            scc_nodes: set[FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]] = set()
            while True:
                w = current_scc_stack.pop()
                on_stack.remove(w)
                scc_nodes.add(w)
                if w == v:
                    break
            # For every member, list the input ports fed from inside the same SCC.
            members_dict = {
                node: [
                    port.port_index
                    for port in get_subgraph_edges(node)
                    if port.connected_output and port.connected_output.node in scc_nodes
                ]
                for node in scc_nodes
            }
            super_node = SuperNode(head=v, members=members_dict, dependencies=[])
            for member in scc_nodes:
                scc_for_node[member] = super_node
            all_sccs.append(super_node)

    tarjan_dfs(head)
    # build the condensed graph
    for super_node in all_sccs:
        for member in super_node.members:
            for port in get_subgraph_edges(member):
                if not port.connected_output:
                    continue
                dependency: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]] = (
                    port.connected_output.node
                )
                # Edges that leave the SCC become edges of the condensed graph.
                # The same dependency may be appended more than once (one per edge).
                if scc_for_node[dependency] != super_node:
                    super_node.dependencies.append(scc_for_node[dependency])
                    scc_for_node[dependency].dependent = super_node
    return scc_for_node[head]
[docs]
def _find_leaf_supernodes(self, root: SuperNode) -> list[SuperNode]:
"""
Identify all leaf supernodes in the condensed DAG.
Leaf supernodes are those with no dependencies.
Returns:
list[SuperNode]: A list of all leaf supernodes in the graph.
"""
final_leaves: list[SuperNode] = []
def dfs(current: SuperNode):
if not current.dependencies:
final_leaves.append(current)
return
for dep in current.dependencies:
dfs(dep)
dfs(root)
return final_leaves
[docs]
def _pick_node_to_force_evaluate(
    self, leaf_supernode: SuperNode
) -> "FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]":
    """Choose the member of a leaf supernode that can be run to break the cycle.

    The first member (in ``members`` iteration order) for which every listed
    input port has a default value is selected.

    Args:
        leaf_supernode (SuperNode): The leaf Super-Node of the Condensed subgraph.

    Returns:
        FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]: The node to force evaluate.

    Raises:
        MissingDefaultError: If no member has defaults for all of its listed ports.

    Undefined Behavior:
        If the argument is not a leaf in the condensed graph, the behavior is undefined.
    """
    chosen = next(
        (
            member
            for member, internal_ports in leaf_supernode.members.items()
            if all(member.has_default_for_input(port) for port in internal_ports)
        ),
        None,
    )
    if chosen is None:
        raise MissingDefaultError(leaf_supernode)
    return chosen
@override
def __repr__(self):
    """Return a short tag identifying this flow by its sequential counter value."""
    flow_id = self.counter
    return f"<Flow#{flow_id}>"
[docs]
class FlowEventLoop(EventLoop):
def __init__(self, flow: Flow):
super().__init__()
self.flow = flow
[docs]
@override
def _handle_command(
self,
current_task_packet: TaskHandlePacket[Command, Any, Any, Exception],
command: Command,
) -> bool:
if super()._handle_command(current_task_packet, command):
return True
if isinstance(command, WaitForStartNextGenerationCommand):
node = command.node
self.flow.set_node_status(node, NodeTaskStatus.Ready())
if not self.flow.running_nodes and not self.flow.resolution_queue:
# close the resolution queue, allowing the main loop to exit
# we can't await the .close() method because we are outside a coroutine
self.handle_queue_close(queue=self.flow.resolution_queue)
elif isinstance(command, TerminateWithExceptionCommand):
node = command.node
self.flow.set_node_status(node, NodeTaskStatus.Error())
if not self.flow.running_nodes and not self.flow.resolution_queue:
# close the resolution queue, allowing the main loop to exit
# we can't await the .close() method because we are outside a coroutine
self.handle_queue_close(queue=self.flow.resolution_queue)
raise command.exception
elif isinstance(command, TerminateReachedLimitCommand):
raise TerminateLimitReached()
elif isinstance(command, ResumeNodeCommand):
node = command.node
current_task = current_task_packet[0]
if node not in self.flow.running_nodes:
self.flow.set_node_status(node, NodeTaskStatus.Running())
self.tasks.append((self.flow.node_tasks[node][0], None, None))
self.tasks.append((current_task, None, None))
else:
self.tasks.append((current_task, None, None))
elif isinstance(command, StalledNodeRequestCommand):
stalled_input = command.stalled_input
stalling_node = command.stalling_node
self.flow.set_node_status(stalled_input.node, NodeTaskStatus.Stalled(stalled_input))
get_current_flow_instrument().on_node_stalled(self.flow, stalling_node, stalled_input)
self.tasks.insert(0, (self.flow._enqueue_node(stalling_node), None, None))
else:
return False
return True