Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions core/framework/graph/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,11 @@ class GraphSpec(BaseModel):
default_model: str = "claude-haiku-4-5-20251001"
max_tokens: int = 1024

# Cleanup LLM for JSON extraction fallback (fast/cheap model preferred)
# If not set, uses CEREBRAS_API_KEY -> cerebras/llama-3.3-70b or
# ANTHROPIC_API_KEY -> claude-3-5-haiku as fallback
cleanup_llm_model: str | None = None

# Execution limits
max_steps: int = Field(default=100, description="Maximum node executions before timeout")
max_retries_per_node: int = 3
Expand Down Expand Up @@ -449,6 +454,44 @@ def get_incoming_edges(self, node_id: str) -> list[EdgeSpec]:
"""Get all edges entering a node."""
return [e for e in self.edges if e.target == node_id]

def detect_fan_out_nodes(self) -> dict[str, list[str]]:
"""
Detect nodes that fan-out to multiple targets.

A fan-out occurs when a node has multiple outgoing edges with the same
condition (typically ON_SUCCESS) that should execute in parallel.

Returns:
Dict mapping source_node_id -> list of parallel target_node_ids
"""
fan_outs: dict[str, list[str]] = {}
for node in self.nodes:
outgoing = self.get_outgoing_edges(node.id)
# Fan-out: multiple edges with ON_SUCCESS condition
success_edges = [
e for e in outgoing if e.condition == EdgeCondition.ON_SUCCESS
]
if len(success_edges) > 1:
fan_outs[node.id] = [e.target for e in success_edges]
return fan_outs

def detect_fan_in_nodes(self) -> dict[str, list[str]]:
"""
Detect nodes that receive from multiple sources (fan-in / convergence).

A fan-in occurs when a node has multiple incoming edges, meaning
it should wait for all predecessor branches to complete.

Returns:
Dict mapping target_node_id -> list of source_node_ids
"""
fan_ins: dict[str, list[str]] = {}
for node in self.nodes:
incoming = self.get_incoming_edges(node.id)
if len(incoming) > 1:
fan_ins[node.id] = [e.source for e in incoming]
return fan_ins

def get_entry_point(self, session_state: dict | None = None) -> str:
"""
Get the appropriate entry point based on session state.
Expand Down
Loading
Loading