"""
Core node types and base classes for the Flowno dataflow system.
This module defines the fundamental building blocks of the dataflow graph:
- DraftNode: Base class for all nodes before finalization
- FinalizedNode: Runtime node with active connections
- Stream: For handling streaming data between nodes
- Various port and connection types for graph construction
These classes are primarily internal implementation details.
"""
import inspect
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import AsyncGenerator, AsyncIterator, Coroutine, Generator
from dataclasses import dataclass, field
from functools import cmp_to_key
from types import CodeType, coroutine
from typing import (
Any,
ClassVar,
Generic,
NamedTuple,
TypeAlias,
TypeVar,
cast,
final,
overload,
)
from flowno.core.event_loop.commands import Command
from flowno.core.event_loop.synchronization import CountdownLatch
from flowno.core.flow.instrumentation import get_current_flow_instrument
from flowno.core.types import (
DataGeneration,
Generation,
InputPortIndex,
OutputPortIndex,
RunLevel,
)
from flowno.utilities.helpers import (
clip_generation,
cmp_generation,
inc_generation,
parent_generation,
stitched_generation,
)
from typing_extensions import TypeVarTuple, Unpack, deprecated, override
logger = logging.getLogger(__name__)
_T = TypeVar("_T")
_Ts = TypeVarTuple("_Ts")
#: The return type of a single-output or multiple-output node.
ReturnTupleT_co = TypeVar("ReturnTupleT_co", covariant=True, bound=tuple[object, ...])
_InputType = TypeVar("_InputType")
_ReturnT = TypeVar("_ReturnT")
_Tout = TypeVar("_Tout", covariant=True)
ObjectFinalizedNode: TypeAlias = "FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]"
@dataclass(eq=False)
class SuperNode:
"""
    A SuperNode is a set of nodes that are strongly connected to each other
    (a strongly connected component). It is used by the cycle detection
    algorithm.
"""
head: "FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]"
members: dict[
"FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]",
list[InputPortIndex],
] # used to track the internal connections of the super-node
dependent: "SuperNode | None" = None # used to build the condensed graph/forest
dependencies: list["SuperNode"] = field(default_factory=list) # used to build the condensed graph/forest
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, SuperNode):
return NotImplemented
return self.head == other.head
@override
def __hash__(self) -> int:
return id(self.head)
@deprecated("Not used")
def generate_mermaid_charts_for_condensed_graph(self) -> str:
"""
Generates two ouputs:
1) A mermaid diagram of the condensed graph.
2) A dictionary mapping each supernode to a mermaid diagram of its internal members.
"""
# 1. Gather all supernodes in the DAG starting from root
all_supernodes = self.gather_supernodes()
# 2. Assign an ID to each supernode for the diagram
# Example: S0, S1, S2, ...
supernode_ids = {snode: f"S{idx}" for idx, snode in enumerate(all_supernodes)}
# 3. Build one big mermaid graph for the condensed DAG
condensed_mermaid = SuperNode.build_condensed_mermaid(all_supernodes, supernode_ids)
return condensed_mermaid
def gather_supernodes(self) -> list["SuperNode"]:
"""
Simple DFS to collect every SuperNode reachable from `root`.
"""
visited: set["SuperNode"] = set()
stack: list["SuperNode"] = [self]
while stack:
current = stack.pop()
if current not in visited:
visited.add(current)
# Each supernode can have multiple dependencies
for dep in current.dependencies:
if dep not in visited:
stack.append(dep)
return list(visited)
@staticmethod
@deprecated("Not used")
def build_condensed_mermaid(all_snodes: list["SuperNode"], supernode_ids: dict["SuperNode", str]) -> str:
"""
Produce a single Mermaid diagram that shows all SuperNodes as nodes
and draws edges for each supernode.dependencies entry.
"""
lines = [
"```mermaid",
"---",
"title: Condensed Graph of Considered Stale Connections",
"---",
"flowchart LR",
]
for snode in all_snodes:
for dep in snode.dependencies:
lines.append(f" {supernode_ids[dep]} --> {supernode_ids[snode]}")
for snode in all_snodes:
lines.append(f" subgraph {supernode_ids[snode]}")
for node in snode.members.keys():
lines.append(f" {_node_id(node)}[{_node_label(node)}]")
for input_port_index in node._input_ports:
inport = node._input_ports[input_port_index]
upstream_connection = inport.connected_output
if upstream_connection is None:
continue
upstream_node = upstream_connection.node
if upstream_node in snode.members:
lines.append(f" {_node_id(node)} --> {_node_id(upstream_node)}")
lines.append(" end")
lines.append("```")
return "\n".join(lines)
@deprecated("Not used")
def _node_id(
node: "FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]",
) -> str:
"""
Return a stable identifier for a node to use in Mermaid.
"""
return f"N_{node._draft_node.__class__.__name__}{node._instance_id}"
@deprecated("Not used")
def _node_label(
node: "FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]",
) -> str:
"""
Return a short label to display for each node.
"""
return f'"{node}"'
class OriginalCall(NamedTuple):
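    """Metadata captured from the user's original function or ``call`` method.
    Stores the signature and code object so that error messages and debug
    output can point back at the original source (file name and line number).
    """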
call_signature: inspect.Signature
call_code: CodeType
func_name: str
class_name: str | None = None
@property
def signature(self) -> inspect.Signature:
return self.call_signature
@property
def filename(self) -> str:
return self.call_code.co_filename
@property
def lineno(self) -> int:
return self.call_code.co_firstlineno
class DraftNode(ABC, Generic[Unpack[_Ts], ReturnTupleT_co]):
"""Abstract Base class for all connectable draft nodes in the flow graph.
DraftNode subclass instance represents a node that can be connected to
other nodes to form a computational graph. It only handles input/output
connections. The node must be wrapped in a FinalizedNode to be used in a
running flow.
.. warning::
Do not use this class directly, use the :py:deco:`node <flowno.node>`
decorator instead
Examples:
>>> from flowno import node, FlowHDL
>>> @node
... async def add(a: int, b: int) -> int:
... return a + b
>>> with FlowHDL() as f:
... f.result = add(1, f.result)
"""
#: Used by subclasses to set the minimum run level required for each input port
_minimum_run_level: ClassVar[list[RunLevel]]
# TODO: See about removing and just reading from _original_call
#: Used by dynamically generated subclasses in the @node factory
_default_values: ClassVar[dict[int, object]]
_original_call: ClassVar[OriginalCall]
#: Used for debugging and logging
_instance_counter: ClassVar[int] = 0
_instance_id: int
def __init__(self, *args: Unpack[tuple[object, ...]]) -> None:
self._instance_id = self.increment_instance_counter()
logger.debug(f"Initializing {self.__class__.__name__}#{self._instance_id} with args {args}")
self._input_ports: defaultdict[InputPortIndex, DraftInputPort[object]] = defaultdict(DraftInputPort[object])
self._connected_output_nodes: defaultdict[
OutputPortIndex,
list[DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]]],
] = defaultdict(list)
minimum_run_level = self._minimum_run_level or [0] * len(args)
for input_port_index in (InputPortIndex(index) for index in range(len(args))):
self._input_ports[input_port_index].minimum_run_level = minimum_run_level[input_port_index]
# loop over each argument and set up the corresponding input port
for input_connection, arg in (
(
DraftInputPortRef[object](
cast(DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]], self),
InputPortIndex(index),
),
arg,
)
for index, arg in enumerate(args)
):
logger.debug(f"Initializing {input_connection} with arg {arg} (type={type(arg).__name__})")
            # arg can be None, a DraftNode, an OutputPortRefPlaceholder, a
            # NodePlaceholder, a DraftOutputPortRef, or a plain constant value
if arg is None:
# no connection on this input port. Passing an explicit None
# indicates that there is a port here, but we want it to be
# disconnected.
self._input_ports[input_connection.port_index].connected_output = None
            elif isinstance(arg, DraftNode):
# if a node is passed as an argument, connect the first output
# port to this input port
arg.output(0).connect(input_connection)
elif isinstance(arg, NodePlaceholder):
# if a node placeholder is passed as an argument, then we can't connect it yet
# we'll wait until the FlowHDL context is closed to finalize the connections
self._input_ports[input_connection.port_index].connected_output = arg.output(0)
elif isinstance(arg, DraftOutputPortRef):
                # f.mynode = MyNode(f.other_node.output(0))
                # A DraftOutputPortRef is a fully specified connection to
                # another node's output port. We don't know here what type the
                # connected value will be; subclasses or the @node class-factory
                # overloads can do the typechecking.
cast(DraftOutputPortRef[object], arg).connect(input_connection)
elif isinstance(arg, OutputPortRefPlaceholder):
# a node output placeholder is a partially specified connection to another node's output port
self._input_ports[input_connection.port_index].connected_output = arg
else:
# otherwise, it's a constant. Wrap it in a Constant node and connect it.
constant_node = Constant[object](arg)
constant_node.output(0).connect(input_connection)
        # DraftNode can be used as a base class, where `.call()` is overridden,
        # or via the @node decorator, which builds a dynamically generated subclass.
        # * A subclass with a custom `call` method takes defaults from the
        #   call method's signature.
        # * The @node decorator assigns default values to the
        #   cls._default_values class attribute.
        # Copy the default values from the source (_default_values or .call
        # parameters) to the input ports for uniform access.
if self.__class__._default_values:
logger.debug(f"{self.__class__.__name__} has _default_values: {self._default_values}")
for input_port_index, default_value in ((InputPortIndex(i), v) for i, v in self._default_values.items()):
logger.debug(
f"Setting default, {default_value}, for {DraftInputPortRef(cast(DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]], self), input_port_index)}"
)
self._input_ports[input_port_index].default_value = default_value
else:
logger.debug(f"{self.__class__.__name__} has no _default_values, extracting from call method signature")
# If the _default_values class attribute is not set, use the default values from the call method signature
call_sig = inspect.signature(self.call)
call_parameters = call_sig.parameters
for index, param in enumerate(call_parameters.values()):
logger.debug(f"Checking {self} parameter {param} for default value")
default_value: object = param.default # pyright: ignore[reportAny]
if default_value is not inspect.Parameter.empty:
logger.debug(f"Setting default value for {self} input {index} to {default_value}")
self._input_ports[InputPortIndex(index)].default_value = default_value
for port_index, port in self._input_ports.items():
input_port_ref = DraftInputPortRef[object](
cast(DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]], self),
port_index,
)
logger.debug(
(
f"Initialized {input_port_ref} "
f"with connected_output={port.connected_output}, "
f"minimum_run_level={port.minimum_run_level}, "
f"default_value={port.default_value}"
)
)
def get_data(self, run_level: RunLevel = 0) -> ReturnTupleT_co:
raise NotImplementedError
def increment_instance_counter(self) -> int:
current_count = self.__class__._instance_counter
self.__class__._instance_counter += 1
return current_count
def output(self, output_port: int) -> "DraftOutputPortRef[object]":
# TODO: Validate that we are not referencing a port that does not exist. Not sure how.
return DraftOutputPortRef(
cast(DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]], self),
OutputPortIndex(output_port),
)
def input(self, port_index: int) -> "DraftInputPortRef[object]":
return DraftInputPortRef(
cast(DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]], self),
InputPortIndex(port_index),
)
@abstractmethod
def call(
self, *args: Unpack[_Ts]
) -> (
Coroutine[Any, Any, ReturnTupleT_co]
| AsyncGenerator[ReturnTupleT_co, None] # pyright: ignore[reportExplicitAny]
):
raise NotImplementedError
def accumulate_streamed_data(self, accumulator: ReturnTupleT_co, partial: ReturnTupleT_co) -> ReturnTupleT_co:
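        """Fold one streamed partial result into the running accumulator, element-wise.
        The default implementation only knows how to combine simple element
        types (int, float, str) by addition/concatenation; nodes that stream
        richer types are expected to override this method.
        """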
        # TODO: add the ability to 'override' this in stateful node definitions
        # TODO: think about allowing different types for different run levels
new_acc = []
for output_port, output_data in enumerate(partial):
acc_slice = accumulator[output_port]
if isinstance(acc_slice, int) and isinstance(output_data, int):
new_acc.append(acc_slice + output_data) # pyright: ignore[reportUnknownMemberType]
elif isinstance(acc_slice, float) and isinstance(output_data, float):
new_acc.append(acc_slice + output_data) # pyright: ignore[reportUnknownMemberType]
elif isinstance(acc_slice, str) and isinstance(output_data, str):
new_acc.append(acc_slice + output_data) # pyright: ignore[reportUnknownMemberType]
else:
raise NotImplementedError(
"You must implement accumulate_streamed_data for non-simple data types, or manually raise StopAsyncIteration(final_value) when the stream is complete."
)
new_acc = cast(ReturnTupleT_co, tuple(new_acc)) # pyright: ignore[reportUnknownArgumentType]
return new_acc
@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}#{self._instance_id}"
def get_output_nodes(
self,
) -> list["DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]]"]:
"""Get all nodes connected to this node's outputs."""
return [node for nodes in self._connected_output_nodes.values() for node in nodes]
def _blank_finalized(self) -> "FinalizedNode[Unpack[_Ts], ReturnTupleT_co]":
"""Convert this draft node into a finalized node without connections."""
return FinalizedNode(
self._original_call,
self._instance_id,
dict(),
dict(),
self,
)
class FinalizedNode(Generic[Unpack[_Ts], ReturnTupleT_co]):
"""The finalized node with valid connections to other nodes. Do not
explicitly create or subclass `FinalizedNode`.
"""
_original_call: OriginalCall
_instance_id: int
_input_ports: dict[InputPortIndex, "FinalizedInputPort[object]"]
_connected_output_nodes: dict[
OutputPortIndex,
list["FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]"],
]
_barrier0: CountdownLatch
_barrier1: CountdownLatch
_draft_node: DraftNode[Unpack[_Ts], ReturnTupleT_co]
    # TODO: We only ever read the highest clipped generation data for each run
    # level. Replace with a forgetful data structure.
_data: dict[DataGeneration, ReturnTupleT_co]
def __init__(
self,
original_call: OriginalCall,
instance_id: int,
input_ports: dict[InputPortIndex, "FinalizedInputPort[object]"],
connected_output_nodes: dict[
OutputPortIndex,
list["FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]"],
],
draft: DraftNode[Unpack[_Ts], ReturnTupleT_co],
):
self._original_call = original_call
self._instance_id = instance_id
self._input_ports = input_ports
self._connected_output_nodes = connected_output_nodes
self._draft_node = draft
self._barrier0 = CountdownLatch(0)
self._barrier1 = CountdownLatch(0)
self._data = dict()
def __getattr__(self, name: str) -> Any:
"""
Delegate attribute lookups to the underlying draft node if this
class does not have the attribute itself. This allows code like
`f.toggle.count` to reach the user-defined class attribute.
"""
# If FinalizedNode itself has the attribute, return it:
if name in self.__dict__:
return self.__dict__[name]
# Otherwise, delegate to the draft node instance:
return getattr(self._draft_node, name)
def __setattr__(self, name: str, value: Any) -> None:
"""
Delegate attribute writes to the underlying draft node if it's
not part of FinalizedNode's own internal attributes.
"""
# If it's one of our known fields or starts with underscore, set it here:
if name.startswith("_") or name in self.__dict__ or hasattr(type(self), name):
super().__setattr__(name, value)
else:
setattr(self._draft_node, name, value)
def output(self, output_port: int) -> "FinalizedOutputPortRef[object]":
# TODO: Validate that we are not referencing a port that does not exist. Not sure how.
return FinalizedOutputPortRef(
cast(FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], self),
OutputPortIndex(output_port),
)
def input(self, port_index: int) -> "FinalizedInputPortRef[object]":
return FinalizedInputPortRef(
cast(FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]], self),
InputPortIndex(port_index),
)
def get_output_nodes(
self,
) -> list["FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]"]:
"""Get all nodes connected to this node's outputs."""
return [node for nodes in self._connected_output_nodes.values() for node in nodes]
def get_output_nodes_by_run_level(self, run_level: RunLevel) -> list["FinalizedInputPort[object]"]:
"""
Returns a list of input ports (FinalizedInputPort objects) from consumer nodes that are connected
to this node, filtered by the specified run level.
Each input port is included once per connection. That is, if the same consumer is connected
multiple times (or via different output ports) with a minimum_run_level equal to the specified run_level,
it will appear once for each such connection.
Args:
run_level (RunLevel): The desired run level (for example, 0 for final data, 1 for partial/streaming updates).
Returns:
list[FinalizedInputPort[object]]: A list of input ports on consumer nodes that are connected to self and
have the specified minimum_run_level.
"""
ret: list[FinalizedInputPort[object]] = []
# Iterate over all output ports and their connected consumer nodes.
for output_port, consumer_nodes in self._connected_output_nodes.items():
# For each consumer node connected on this output port...
for consumer in consumer_nodes:
# Iterate over each input port on the consumer.
for input_port_index, consumer_input in consumer._input_ports.items():
# Skip this input port if it is not connected.
if consumer_input.connected_output is None:
continue
# Check that this input port is connected to self.
if consumer_input.connected_output.node is self:
# Check if this input port expects data at the specified run level.
if consumer_input.minimum_run_level == run_level:
ret.append(consumer_input)
return ret
def has_default_for_input(self, input_port_index: InputPortIndex) -> bool:
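        """Return True if the given input port has a default value."""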
input_port = self._input_ports[input_port_index]
if hasattr(input_port, "default_value") and input_port.default_value is not inspect.Parameter.empty:
logger.debug(f"{self} input port {input_port_index} has a default value: {input_port.default_value}")
return True
logger.debug(f"{self} input port {input_port_index} does NOT have a default value")
return False
@property
def generation(self) -> Generation:
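        """The most recent generation for which this node has stored data, or None if no data has been produced."""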
return max(self._data.keys(), key=cmp_to_key(cmp_generation), default=None)
def get_data(self, run_level: RunLevel = 0) -> ReturnTupleT_co | None:
"""Get the most recent data produced by the node at the specified run level.
Args:
run_level: The run level to retrieve data for (default: 0)
Returns:
The tuple of data produced by this node, or None if no data is available
"""
if self.generation is None:
return None
if len(self.generation) - 1 > run_level:
requested_generation = list(self.generation[: run_level + 1])
            requested_generation[run_level] -= 1  # plain truncation would point at data in the future
if tuple(requested_generation) not in self._data:
return None
return self._data[tuple(requested_generation)]
if len(self.generation) - 1 < run_level:
return None
return self._data[self.generation]
async def count_down_upstream_latches(self, defaulted_inputs: list[InputPortIndex]) -> None:
"""Count down upstream node barriers for non-defaulted connected inputs."""
for input_port_index, input_port in self._input_ports.items():
if input_port.connected_output is None or input_port_index in defaulted_inputs:
continue
upstream_node = input_port.connected_output.node
await upstream_node._barrier0.count_down()
@final
def push_data(self, data: ReturnTupleT_co | None, run_level: RunLevel = 0) -> None:
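        """Store ``data`` as this node's output for the next generation at ``run_level``."""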
logger.debug(f"push_data({self}, {data}, run_level={run_level})")
new_generation = inc_generation(self.generation, run_level)
logger.info(
f"{self} advanced to generation {new_generation} with data={data}",
extra={"tag": "flow"},
)
logger.debug(f"{self}.data[{new_generation}] = {repr(data)}")
self._data[tuple(new_generation)] = data
def __repr__(self) -> str:
name: str
if self._original_call.class_name:
name = self._original_call.class_name
else:
name = self._original_call.func_name
return f"{name}#{self._instance_id}"
def call(
self, *args: Unpack[_Ts]
) -> (
Coroutine[Any, Any, ReturnTupleT_co]
| AsyncGenerator[ReturnTupleT_co, None] # pyright: ignore[reportExplicitAny]
):
"""Delegate call to the draft node implementation"""
return self._draft_node.call(*args)
class Stream(Generic[_InputType], AsyncIterator[_InputType]):
"""A stream of values from one node to another.
Streams connect nodes that produce multiple values over time (run_level > 0)
to consuming nodes. They act as async iterators that yield values as they
become available.
Type Parameters:
_InputType: The type of data being streamed
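    Example:
        A minimal sketch of a consumer node that folds a stream of strings
        into a single value. This assumes the ``@node`` decorator accepts
        ``Stream``-typed parameters for streaming inputs; ``Collect`` is a
        hypothetical node used only for illustration.
        >>> from flowno import node
        >>> @node
        ... async def Collect(words: Stream[str]) -> str:
        ...     total = ""
        ...     async for word in words:
        ...         total += word
        ...     return total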
"""
input: "FinalizedInputPortRef[_InputType]"
output: "FinalizedOutputPortRef[_InputType]"
def __init__(
self,
input: "FinalizedInputPortRef[_InputType]",
output: "FinalizedOutputPortRef[_InputType]",
) -> None:
super().__init__()
self.input = input
self.output = output
self._last_consumed_generation: Generation = None
self.run_level: RunLevel = 1
self._last_consumed_parent_generation: Generation = None
@override
def __aiter__(self) -> AsyncIterator[_InputType]:
get_current_flow_instrument().on_stream_start(self)
return self
@override
def __repr__(self) -> str:
return f"Stream({self.output}->{self.input}, last_consumed={self._last_consumed_generation}, last_consumed_parent={self._last_consumed_parent_generation}, run_level={self.run_level})"
@override
async def __anext__(self) -> _InputType:
logger.debug(f"calling __anext__({self})")
def get_clipped_stitched_gen():
stitch_0 = self.input.node._input_ports[self.input.port_index].stitch_level_0
return clip_generation(
stitched_generation(self.output.node.generation, stitch_0), run_level=self.run_level
)
while (
cmp_generation(
get_clipped_stitched_gen(),
self._last_consumed_generation,
)
<= 0
):
logger.debug(
(
f"{self.output.node}'s generation, "
f"when clipped {clip_generation(self.output.node.generation, run_level=self.run_level)}, "
f"and stitched {get_clipped_stitched_gen()}, "
f"is less than or equal to the last consumed generation {self._last_consumed_generation}"
f", requesting new data."
)
)
await _node_stalled(self.input, self.output.node)
logger.debug(
f"{self.output.node}'s generation, when clipped and stitched, is greater than the last consumed generation, continuing."
)
logger.debug(f"__anext__({self}): continuing")
if (
self._last_consumed_generation is not None
and parent_generation(self.output.node.generation) != self._last_consumed_parent_generation
):
logger.debug(
(
f"Parent generation changed from {self._last_consumed_parent_generation} "
f"to {parent_generation(self.output.node.generation)} "
f"indicating the stream is complete or restarted."
)
)
logger.info(f"Stream {self} is complete or restarted.", extra={"tag": "flow"})
get_current_flow_instrument().on_stream_end(self)
raise StopAsyncIteration
self._last_consumed_generation = get_clipped_stitched_gen()
self._last_consumed_parent_generation = parent_generation(self.output.node.generation)
data_tuple = self.output.node.get_data(run_level=self.run_level)
assert data_tuple is not None
data = cast(_InputType, data_tuple[self.output.port_index])
logger.debug(f"__anext__({self}) returning {repr(data)}")
logger.info(
f"Stream {self} consumed data {repr(data)}",
extra={"tag": "flow"},
)
with get_current_flow_instrument().on_barrier_node_read(self.output.node, 1):
await self.output.node._barrier1.count_down()
# Instrumentation: Stream processed next item
get_current_flow_instrument().on_stream_next(self, data)
return data
@dataclass
class StalledNodeRequestCommand(Command):
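    """Event loop command issued when ``stalled_input`` is waiting on fresh data from ``stalling_node``."""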
stalled_input: "FinalizedInputPortRef[object]"
stalling_node: "FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]"
@coroutine
def _node_stalled(
stalled_input: "FinalizedInputPortRef[Any]",
stalling_node: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]],
) -> Generator["StalledNodeRequestCommand", None, None]:
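    """Yield a :class:`StalledNodeRequestCommand` to the Flowno event loop,
    requesting that ``stalling_node`` produce new data for ``stalled_input``.
    """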
logger.info(
f"{stalled_input} is stalled. Requesting new data from {stalling_node}.",
extra={"tag": "flow"},
)
return (yield StalledNodeRequestCommand(stalled_input, stalling_node))
@dataclass
class NodePlaceholder:
"""Placeholder for a node defined on a FlowHDL context instance.
Examples:
>>> with FlowHDL() as flow:
... flow.node1 = DummyNode()
... assert isinstance(flow.node1, Node)
... assert isinstance(flow.node2, NodePlaceholder)
... print(flow.node2.name) # prints "node2"
... assert isinstance(flow.node3[OutputPortIndex(0)], NodeOutputPlaceholder)
"""
name: str
def output(self, output_port: OutputPortIndex | int) -> "OutputPortRefPlaceholder[object]":
return OutputPortRefPlaceholder[object](self, OutputPortIndex(output_port))
@dataclass
class OutputPortRefPlaceholder(Generic[_ReturnT]):
"""Placeholder for a specific output port of a NodePlaceholder.
"""
node: NodePlaceholder
port_index: OutputPortIndex
@override
def __repr__(self) -> str:
return f"{self.node}.output({self.port_index})"
@dataclass
class DraftOutputPortRef(Generic[_Tout]):
"""Represents the an output port on a node (the data producer).
Type Parameters:
_T : A phantom type. The `port_index` output of `node.call(...)` should be `_T`.
"""
node: DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]]
port_index: OutputPortIndex
@override
def __repr__(self) -> str:
return f"{self.node}.output({self.port_index})"
def connect(
self,
target: "DraftInputPortRef[_Tout]",
):
"""
Wires this output to the given target's input.
Internally, record that data flows
source_node.output(port_index) -> target_node.input(port_index)
"""
logger.info(f"Connecting {self} -> {target}")
self.node._connected_output_nodes[self.port_index].append(target.node)
target.node._input_ports[target.port_index].connected_output = cast(DraftOutputPortRef[object], self)
def _finalize(
self,
final_by_draft: dict[
DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]],
FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]],
],
):
return FinalizedOutputPortRef[_Tout](node=final_by_draft[self.node], port_index=self.port_index)
@dataclass
class FinalizedOutputPortRef(Generic[_T]):
node: FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]
port_index: OutputPortIndex
@override
def __repr__(self) -> str:
return f"{self.node}.output({self.port_index})"
class Constant(DraftNode[(), tuple[_T]]):
"""
A node that produces a constant value of type _T.
Parameters
----------
value : _T
The constant value produced by the node.
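    Examples
    --------
    ``Constant`` nodes are normally created implicitly when a plain value is
    passed as an argument to another node; direct construction is shown here
    only for illustration.
    >>> import asyncio
    >>> answer = Constant(42)
    >>> asyncio.run(answer.call())
    (42,)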
"""
value: _T
# Since we have zero inputs, the “minimum_run_level” list is empty
_minimum_run_level: ClassVar[list[RunLevel]] = []
# No default values are needed for a constant’s inputs (it has none)
_default_values: ClassVar[dict[int, object]] = {}
# Provide a class-level OriginalCall so that FinalizedNode
# can figure out signature info, filename, lineno, etc.
_original_call: ClassVar[OriginalCall] = OriginalCall(
inspect.signature(lambda: None), # a dummy signature; or use call’s
(lambda: None).__code__, # a dummy code object; or use call’s
func_name="call",
class_name="Constant",
)
def __init__(self, value: _T) -> None:
# Our node has no inputs, so just call super().__init__() with no args
super().__init__()
self.value = value
        # Override the class-level placeholder so that _original_call matches
        # the real call() method's signature and code object.
self.__class__._original_call = OriginalCall(
inspect.signature(self.call),
self.call.__code__,
func_name="call",
class_name=self.__class__.__name__,
)
@override
async def call(self) -> tuple[_T]:
"""
Produce the single-value tuple containing our constant.
"""
return (self.value,)
# --- Helpers for rendering call signatures in error messages ---
def build_signature_display(
oc: "OriginalCall",
required_input_ports: list["InputPortIndex"],
) -> str:
"""
    Return a multi-line string that shows either:
    1) ``ClassName.call(...)`` in an abbreviated form, or
    2) ``func_name(...)`` in a traditional inline form,
    with underlines marking which parameters are missing defaults.
"""
# Pull out the parameters
parameters = list(oc.signature.parameters.items())
# Decide if 'self' is the first parameter
skip_self = len(parameters) > 0 and parameters[0][0] == "self"
# Distinguish the "class's call method" case from a normal function
if oc.class_name and oc.func_name == "call":
# e.g. MyNode.call
# (x: int, y: str=...)
line_1 = f" class {oc.class_name}:"
line_2, underline_2 = build_abbreviated_signature(parameters, required_input_ports, skip_self)
# Indent 'async def call' by four spaces
line_2 = f" async def {oc.func_name}(self, {line_2}): ..."
return f"{line_1}\n{line_2}\n{' '*(len(oc.func_name)+23)}{underline_2}"
else:
# Normal function style: func_name(x: int, y: str=...)
signature_line, param_positions = build_inline_signature(
oc.func_name, parameters, required_input_ports, skip_self
)
underline_line = build_underline_line(signature_line, param_positions)
return f" async def {signature_line}\n{' ' * 12}{underline_line}"
def build_abbreviated_signature(
parameters: list[tuple[str, inspect.Parameter]],
required_input_ports: list["InputPortIndex"],
skip_self: bool,
) -> tuple[str, str]:
"""
Construct just the parenthesized parameter list, skipping self if needed.
Returns the line plus a line of dashes/underlines for missing-default params.
"""
    # When skipping self, offset the parameter index by one (via real_index
    # below) so that it lines up with the indices in required_input_ports.
param_strings = []
param_positions = []
signature_text = ""
current_length = len(signature_text)
for i, (pname, pval) in enumerate(parameters):
# If skip_self and this is the 0th parameter, omit it.
if skip_self and i == 0:
continue
# Figure out the real input-port index
real_index = i if not skip_self else (i - 1)
underline_needed = real_index in required_input_ports
# Build text for this parameter
param_text = build_param_text(pname, pval)
start_pos = current_length
signature_text += param_text
end_pos = start_pos + len(param_text)
param_positions.append((start_pos, end_pos, underline_needed))
current_length += len(param_text)
# Add comma if not last visible param
if i < len(parameters) - 1:
signature_text += ", "
current_length += 2
signature_text += ""
underline_line = build_underline_line(signature_text, param_positions)
return signature_text, underline_line
def build_inline_signature(
func_name: str,
parameters: list[tuple[str, inspect.Parameter]],
required_input_ports: list["InputPortIndex"],
skip_self: bool,
) -> tuple[str, list[tuple[int, int, bool]]]:
"""
Build something like: func_name(x: int=1, y: str)
Returns the signature string and param_positions for underlining.
"""
signature_line = f"{func_name}("
param_positions = []
current_length = len(signature_line)
for i, (pname, pval) in enumerate(parameters):
        if skip_self and i == 0:
            # Omit 'self' from the displayed signature.
            continue
real_index = i if not skip_self else (i - 1)
underline_needed = real_index in required_input_ports
param_text = build_param_text(pname, pval)
start_pos = current_length
signature_line += param_text
end_pos = start_pos + len(param_text)
param_positions.append((start_pos, end_pos, underline_needed))
current_length = len(signature_line)
if i < len(parameters) - 1:
signature_line += ", "
current_length += 2
signature_line += ")"
return signature_line, param_positions
def build_param_text(param_name: str, param: inspect.Parameter) -> str:
"""
    Return a string representation of one parameter, e.g. ``x: int = 5``.
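    Example (illustrative doctest):
        >>> import inspect
        >>> def f(x: int = 5): ...
        >>> build_param_text("x", inspect.signature(f).parameters["x"])
        'x: int = 5'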
"""
    annotation = param.annotation if param.annotation is not inspect.Parameter.empty else "Any"
if annotation == str:
annotation = "str"
elif annotation == int:
annotation = "int"
elif annotation == float:
annotation = "float"
if param.default == inspect.Parameter.empty:
return f"{param_name}: {annotation}"
else:
return f"{param_name}: {annotation} = {param.default}"
def build_underline_line(signature_line: str, param_positions: list[tuple[int, int, bool]]) -> str:
"""
Create a line of spaces, with dashes marking the parameters that need defaults.
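    Example (each tuple is a ``(start, end, needs_underline)`` character span):
        >>> build_underline_line("x: int, y: str", [(0, 6, False), (8, 14, True)])
        '        ------'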
"""
underline_chars = [" "] * len(signature_line)
for start, end, needed in param_positions:
if needed:
for idx in range(start, end):
underline_chars[idx] = "-"
return "".join(underline_chars)
# --- End of signature-rendering helpers ---
class MissingDefaultError(Exception):
@overload
def __init__(
self,
node: "FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]",
input_index: InputPortIndex,
) -> None: ...
@overload
def __init__(self, node: "SuperNode") -> None: ...
def __init__(
self,
node: "FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]] | SuperNode",
input_index: InputPortIndex | None = None,
):
if not isinstance(node, FinalizedNode):
assert input_index is None
missing_info: dict[
"FinalizedNode[Unpack[tuple[object, ...]], tuple[object, ...]]",
list[InputPortIndex],
] = {}
            for member, internal_input_ports in node.members.items():
                for internal_input_port in internal_input_ports:
                    if member.has_default_for_input(internal_input_port):
                        continue
                    if member not in missing_info:
                        missing_info[member] = []
                    missing_info[member].append(internal_input_port)
full_message = (
"Detected a cycle without default values. You must add defaults to the indicated arguments for at least ONE of the following nodes:\n"
+ "\nOR\n".join(
format_missing_defaults(node, input_ports) for node, input_ports in missing_info.items()
)
)
super().__init__(full_message)
else:
assert input_index is not None
super().__init__(format_missing_defaults(node, [input_index]))
# Type aliases for common node types
ObjectDraftNode = DraftNode[Unpack[tuple[object, ...]], tuple[object, ...]]
AnyNode = DraftNode[Unpack[tuple[Any, ...]], tuple[Any, ...]]
__all__ = [
"DraftNode",
"ObjectDraftNode",
"AnyNode",
"NodePlaceholder",
"DraftInputPortRef",
"DraftOutputPortRef",
"Stream",
"OutputPortRefPlaceholder",
]