diff --git a/core/examples/manual_agent.py b/core/examples/manual_agent.py index da01e2335e..c4dfd047b9 100644 --- a/core/examples/manual_agent.py +++ b/core/examples/manual_agent.py @@ -13,11 +13,12 @@ """ import asyncio -import logging -from framework.graph import Goal, NodeSpec, EdgeSpec, GraphSpec, EdgeCondition + +from framework.graph import EdgeCondition, EdgeSpec, Goal, GraphSpec, NodeSpec from framework.graph.executor import GraphExecutor from framework.runtime.core import Runtime + # 1. Define Node Logic (Pure Python Functions) def greet(name: str) -> str: """Generate a simple greeting.""" @@ -38,7 +39,7 @@ async def main(): description="Generate a friendly uppercase greeting", success_criteria=[ { - "id": "greeting_generated", + "id": "greeting_generated", "description": "Greeting produced", "metric": "custom", "target": "any" @@ -63,7 +64,7 @@ async def main(): name="Uppercaser", description="Converts greeting to uppercase", node_type="function", - function="uppercase", + function="uppercase", input_keys=["greeting"], output_keys=["final_greeting"] ) @@ -100,8 +101,8 @@ async def main(): executor.register_function("uppercaser", uppercase) # 8. Execute Agent - print(f"ā–¶ Executing agent with input: name='Alice'...") - + print("ā–¶ Executing agent with input: name='Alice'...") + result = await executor.execute( graph=graph, goal=goal, diff --git a/core/framework/__init__.py b/core/framework/__init__.py index 4c0088e8a5..4bc274eeaa 100644 --- a/core/framework/__init__.py +++ b/core/framework/__init__.py @@ -22,22 +22,22 @@ See `framework.testing` for details. """ -from framework.schemas.decision import Decision, Option, Outcome, DecisionEvaluation -from framework.schemas.run import Run, RunSummary, Problem -from framework.runtime.core import Runtime from framework.builder.query import BuilderQuery -from framework.llm import LLMProvider, AnthropicProvider -from framework.runner import AgentRunner, AgentOrchestrator +from framework.llm import AnthropicProvider, LLMProvider +from framework.runner import AgentOrchestrator, AgentRunner +from framework.runtime.core import Runtime +from framework.schemas.decision import Decision, DecisionEvaluation, Option, Outcome +from framework.schemas.run import Problem, Run, RunSummary # Testing framework from framework.testing import ( + ApprovalStatus, + DebugTool, + ErrorCategory, Test, TestResult, - TestSuiteResult, TestStorage, - ApprovalStatus, - ErrorCategory, - DebugTool, + TestSuiteResult, ) __all__ = [ diff --git a/core/framework/builder/__init__.py b/core/framework/builder/__init__.py index 7a3c4a3e09..5e17b1c526 100644 --- a/core/framework/builder/__init__.py +++ b/core/framework/builder/__init__.py @@ -2,12 +2,12 @@ from framework.builder.query import BuilderQuery from framework.builder.workflow import ( - GraphBuilder, - BuildSession, BuildPhase, - ValidationResult, + BuildSession, + GraphBuilder, TestCase, TestResult, + ValidationResult, ) __all__ = [ diff --git a/core/framework/builder/query.py b/core/framework/builder/query.py index aeffc98538..40930d5af2 100644 --- a/core/framework/builder/query.py +++ b/core/framework/builder/query.py @@ -8,12 +8,12 @@ 4. What should we change? (suggestions) """ -from typing import Any from collections import defaultdict from pathlib import Path +from typing import Any from framework.schemas.decision import Decision -from framework.schemas.run import Run, RunSummary, RunStatus +from framework.schemas.run import Run, RunStatus, RunSummary from framework.storage.backend import FileStorage @@ -476,7 +476,7 @@ def _find_differences(self, run1: Run, run2: Run) -> list[str]: ) # Find first divergence point - for i, (d1, d2) in enumerate(zip(run1.decisions, run2.decisions)): + for i, (d1, d2) in enumerate(zip(run1.decisions, run2.decisions, strict=False)): if d1.chosen_option_id != d2.chosen_option_id: differences.append( f"Diverged at decision {i}: chose '{d1.chosen_option_id}' vs '{d2.chosen_option_id}'" diff --git a/core/framework/builder/workflow.py b/core/framework/builder/workflow.py index baf1e5b5ac..29702ae2cc 100644 --- a/core/framework/builder/workflow.py +++ b/core/framework/builder/workflow.py @@ -13,16 +13,17 @@ You cannot skip steps or bypass validation. """ +from collections.abc import Callable +from datetime import datetime from enum import Enum from pathlib import Path -from datetime import datetime -from typing import Any, Callable +from typing import Any from pydantic import BaseModel, Field +from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec from framework.graph.goal import Goal from framework.graph.node import NodeSpec -from framework.graph.edge import EdgeSpec, EdgeCondition, GraphSpec class BuildPhase(str, Enum): @@ -630,69 +631,69 @@ def _generate_code(self, graph: GraphSpec) -> str: """Generate Python code for the graph.""" lines = [ '"""', - f'Generated agent: {self.session.name}', - f'Generated at: {datetime.now().isoformat()}', + f"Generated agent: {self.session.name}", + f"Generated at: {datetime.now().isoformat()}", '"""', - '', - 'from framework.graph import (', - ' Goal, SuccessCriterion, Constraint,', - ' NodeSpec, EdgeSpec, EdgeCondition,', - ')', - 'from framework.graph.edge import GraphSpec', - 'from framework.graph.goal import GoalStatus', - '', - '', - '# Goal', + "", + "from framework.graph import (", + " Goal, SuccessCriterion, Constraint,", + " NodeSpec, EdgeSpec, EdgeCondition,", + ")", + "from framework.graph.edge import GraphSpec", + "from framework.graph.goal import GoalStatus", + "", + "", + "# Goal", ] if self.session.goal: goal_json = self.session.goal.model_dump_json(indent=4) - lines.append('GOAL = Goal.model_validate_json(\'\'\'') + lines.append("GOAL = Goal.model_validate_json('''") lines.append(goal_json) lines.append("''')") else: - lines.append('GOAL = None') + lines.append("GOAL = None") lines.extend([ - '', - '', - '# Nodes', - 'NODES = [', + "", + "", + "# Nodes", + "NODES = [", ]) for node in self.session.nodes: node_json = node.model_dump_json(indent=4) - lines.append(' NodeSpec.model_validate_json(\'\'\'') + lines.append(" NodeSpec.model_validate_json('''") lines.append(node_json) lines.append(" '''),") lines.extend([ - ']', - '', - '', - '# Edges', - 'EDGES = [', + "]", + "", + "", + "# Edges", + "EDGES = [", ]) for edge in self.session.edges: edge_json = edge.model_dump_json(indent=4) - lines.append(' EdgeSpec.model_validate_json(\'\'\'') + lines.append(" EdgeSpec.model_validate_json('''") lines.append(edge_json) lines.append(" '''),") lines.extend([ - ']', - '', - '', - '# Graph', + "]", + "", + "", + "# Graph", ]) graph_json = graph.model_dump_json(indent=4) - lines.append('GRAPH = GraphSpec.model_validate_json(\'\'\'') + lines.append("GRAPH = GraphSpec.model_validate_json('''") lines.append(graph_json) lines.append("''')") - return '\n'.join(lines) + return "\n".join(lines) # ========================================================================= # SESSION MANAGEMENT diff --git a/core/framework/graph/__init__.py b/core/framework/graph/__init__.py index f01f870681..620a93b383 100644 --- a/core/framework/graph/__init__.py +++ b/core/framework/graph/__init__.py @@ -1,32 +1,32 @@ """Graph structures: Goals, Nodes, Edges, and Flexible Execution.""" -from framework.graph.goal import Goal, SuccessCriterion, Constraint, GoalStatus -from framework.graph.node import NodeSpec, NodeContext, NodeResult, NodeProtocol -from framework.graph.edge import EdgeSpec, EdgeCondition, GraphSpec +from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec +from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec from framework.graph.executor import GraphExecutor +from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor +from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion +from framework.graph.judge import HybridJudge, create_default_judge +from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec # Flexible execution (Worker-Judge pattern) from framework.graph.plan import ( - Plan, - PlanStep, ActionSpec, ActionType, - StepStatus, - Judgment, - JudgmentAction, - EvaluationRule, - PlanExecutionResult, - ExecutionStatus, - load_export, # HITL (Human-in-the-loop) ApprovalDecision, ApprovalRequest, ApprovalResult, + EvaluationRule, + ExecutionStatus, + Judgment, + JudgmentAction, + Plan, + PlanExecutionResult, + PlanStep, + StepStatus, + load_export, ) -from framework.graph.judge import HybridJudge, create_default_judge -from framework.graph.worker_node import WorkerNode, StepExecutionResult -from framework.graph.flexible_executor import FlexibleGraphExecutor, ExecutorConfig -from framework.graph.code_sandbox import CodeSandbox, safe_exec, safe_eval +from framework.graph.worker_node import StepExecutionResult, WorkerNode __all__ = [ # Goal diff --git a/core/framework/graph/code_sandbox.py b/core/framework/graph/code_sandbox.py index 28a4c231b8..ea8eb8aa47 100644 --- a/core/framework/graph/code_sandbox.py +++ b/core/framework/graph/code_sandbox.py @@ -13,11 +13,11 @@ """ import ast -import sys import signal -from typing import Any -from dataclasses import dataclass, field +import sys from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any # Safe builtins whitelist SAFE_BUILTINS = { @@ -216,7 +216,7 @@ def handler(signum, frame): raise TimeoutError(f"Code execution timed out after {seconds} seconds") # Only works on Unix-like systems - if hasattr(signal, 'SIGALRM'): + if hasattr(signal, "SIGALRM"): old_handler = signal.signal(signal.SIGALRM, handler) signal.alarm(seconds) try: diff --git a/core/framework/graph/edge.py b/core/framework/graph/edge.py index b63607dbdd..61c1d5b8a1 100644 --- a/core/framework/graph/edge.py +++ b/core/framework/graph/edge.py @@ -22,8 +22,8 @@ given the current goal, context, and execution state. """ -from typing import Any from enum import Enum +from typing import Any from pydantic import BaseModel, Field @@ -238,7 +238,7 @@ def _llm_decide( # Parse response import re - json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL) + json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL) if json_match: data = json.loads(json_match.group()) proceed = data.get("proceed", False) diff --git a/core/framework/graph/executor.py b/core/framework/graph/executor.py index dd61e7901c..f84fa6d1a6 100644 --- a/core/framework/graph/executor.py +++ b/core/framework/graph/executor.py @@ -10,25 +10,26 @@ """ import logging -from typing import Any, Callable +from collections.abc import Callable from dataclasses import dataclass, field +from typing import Any -from framework.runtime.core import Runtime +from framework.graph.edge import GraphSpec from framework.graph.goal import Goal from framework.graph.node import ( - NodeSpec, + FunctionNode, + LLMNode, NodeContext, - NodeResult, NodeProtocol, - SharedMemory, - LLMNode, + NodeResult, + NodeSpec, RouterNode, - FunctionNode, + SharedMemory, ) -from framework.graph.edge import GraphSpec +from framework.graph.output_cleaner import CleansingConfig, OutputCleaner from framework.graph.validator import OutputValidator -from framework.graph.output_cleaner import OutputCleaner, CleansingConfig from framework.llm.provider import LLMProvider, Tool +from framework.runtime.core import Runtime @dataclass diff --git a/core/framework/graph/flexible_executor.py b/core/framework/graph/flexible_executor.py index 238b127c50..f34a973410 100644 --- a/core/framework/graph/flexible_executor.py +++ b/core/framework/graph/flexible_executor.py @@ -15,28 +15,29 @@ This keeps planning external while execution/evaluation is internal. """ -from typing import Any, Callable +from collections.abc import Callable from dataclasses import dataclass from datetime import datetime +from typing import Any -from framework.runtime.core import Runtime +from framework.graph.code_sandbox import CodeSandbox from framework.graph.goal import Goal +from framework.graph.judge import HybridJudge, create_default_judge from framework.graph.plan import ( - Plan, - PlanStep, - PlanExecutionResult, + ApprovalDecision, + ApprovalRequest, + ApprovalResult, ExecutionStatus, - StepStatus, Judgment, JudgmentAction, - ApprovalRequest, - ApprovalResult, - ApprovalDecision, + Plan, + PlanExecutionResult, + PlanStep, + StepStatus, ) -from framework.graph.judge import HybridJudge, create_default_judge -from framework.graph.worker_node import WorkerNode, StepExecutionResult -from framework.graph.code_sandbox import CodeSandbox +from framework.graph.worker_node import StepExecutionResult, WorkerNode from framework.llm.provider import LLMProvider, Tool +from framework.runtime.core import Runtime # Type alias for approval callback ApprovalCallback = Callable[[ApprovalRequest], ApprovalResult] diff --git a/core/framework/graph/goal.py b/core/framework/graph/goal.py index bddf7ff72e..839a1108f9 100644 --- a/core/framework/graph/goal.py +++ b/core/framework/graph/goal.py @@ -12,8 +12,8 @@ """ from datetime import datetime -from typing import Any from enum import Enum +from typing import Any from pydantic import BaseModel, Field diff --git a/core/framework/graph/hitl.py b/core/framework/graph/hitl.py index 0f88f8f68c..d81d2da526 100644 --- a/core/framework/graph/hitl.py +++ b/core/framework/graph/hitl.py @@ -170,9 +170,10 @@ def parse_response( # Use Haiku to extract answers try: - import anthropic import json + import anthropic + questions_str = "\n".join([ f"{i+1}. {q.question} (id: {q.id})" for i, q in enumerate(request.questions) @@ -201,7 +202,7 @@ def parse_response( # Parse Haiku's response import re response_text = message.content[0].text.strip() - json_match = re.search(r'\{[^{}]*\}', response_text, re.DOTALL) + json_match = re.search(r"\{[^{}]*\}", response_text, re.DOTALL) if json_match: parsed = json.loads(json_match.group()) diff --git a/core/framework/graph/judge.py b/core/framework/graph/judge.py index ab0c69d440..5191264598 100644 --- a/core/framework/graph/judge.py +++ b/core/framework/graph/judge.py @@ -8,17 +8,17 @@ Escalation path: rules → LLM → human """ -from typing import Any from dataclasses import dataclass, field +from typing import Any +from framework.graph.code_sandbox import safe_eval +from framework.graph.goal import Goal from framework.graph.plan import ( - PlanStep, + EvaluationRule, Judgment, JudgmentAction, - EvaluationRule, + PlanStep, ) -from framework.graph.goal import Goal -from framework.graph.code_sandbox import safe_eval from framework.llm.provider import LLMProvider @@ -136,9 +136,9 @@ def _evaluate_rules( # Build evaluation context eval_context = { - "step": step.model_dump() if hasattr(step, 'model_dump') else step, + "step": step.model_dump() if hasattr(step, "model_dump") else step, "result": result, - "goal": goal.model_dump() if hasattr(goal, 'model_dump') else goal, + "goal": goal.model_dump() if hasattr(goal, "model_dump") else goal, "context": context, "success": isinstance(result, dict) and result.get("success", False), "error": isinstance(result, dict) and result.get("error"), diff --git a/core/framework/graph/node.py b/core/framework/graph/node.py index c2fd7691d4..2956037e54 100644 --- a/core/framework/graph/node.py +++ b/core/framework/graph/node.py @@ -17,13 +17,14 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable +from collections.abc import Callable from dataclasses import dataclass, field +from typing import Any from pydantic import BaseModel, Field -from framework.runtime.core import Runtime from framework.llm.provider import LLMProvider, Tool +from framework.runtime.core import Runtime logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def find_json_object(text: str) -> str | None: This handles nested objects correctly, unlike simple regex like r'\\{[^{}]*\\}'. """ - start = text.find('{') + start = text.find("{") if start == -1: return None @@ -46,7 +47,7 @@ def find_json_object(text: str) -> str | None: escape_next = False continue - if char == '\\' and in_string: + if char == "\\" and in_string: escape_next = True continue @@ -57,9 +58,9 @@ def find_json_object(text: str) -> str | None: if in_string: continue - if char == '{': + if char == "{": depth += 1 - elif char == '}': + elif char == "}": depth -= 1 if depth == 0: return text[start:i + 1] @@ -361,9 +362,10 @@ def to_summary(self, node_spec: Any = None) -> str: # Use Haiku to generate intelligent summary try: - import anthropic import json + import anthropic + node_context = "" if node_spec: node_context = f"\nNode: {node_spec.name}\nPurpose: {node_spec.description}" @@ -482,7 +484,7 @@ def _strip_code_blocks(self, content: str) -> str: import re content = content.strip() # Match ```json or ``` at start and ``` at end (greedy to handle nested) - match = re.match(r'^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$', content, re.DOTALL) + match = re.match(r"^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$", content, re.DOTALL) if match: return match.group(1).strip() return content @@ -532,13 +534,13 @@ async def execute(self, ctx: NodeContext) -> NodeResult: # Log the LLM call details logger.info(" šŸ¤– LLM Call:") logger.info(f" System: {system[:150]}..." if len(system) > 150 else f" System: {system}") - logger.info(f" User message: {messages[-1]['content'][:150]}..." if len(messages[-1]['content']) > 150 else f" User message: {messages[-1]['content']}") + logger.info(f" User message: {messages[-1]['content'][:150]}..." if len(messages[-1]["content"]) > 150 else f" User message: {messages[-1]['content']}") if ctx.available_tools: logger.info(f" Tools available: {[t.name for t in ctx.available_tools]}") # Call LLM if ctx.available_tools and self.tool_executor: - from framework.llm.provider import ToolUse, ToolResult + from framework.llm.provider import ToolResult, ToolUse def executor(tool_use: ToolUse) -> ToolResult: logger.info(f" šŸ”§ Tool call: {tool_use.name}({', '.join(f'{k}={v}' for k, v in tool_use.input.items())})") @@ -701,14 +703,14 @@ def _extract_json(self, raw_response: str, output_keys: list[str]) -> dict[str, if content.startswith("```"): # Try multiple patterns for markdown code blocks # Pattern 1: ```json\n...\n``` or ```\n...\n``` - match = re.search(r'^```(?:json)?\s*\n([\s\S]*?)\n```\s*$', content) + match = re.search(r"^```(?:json)?\s*\n([\s\S]*?)\n```\s*$", content) if match: content = match.group(1).strip() else: # Pattern 2: Just strip the first and last lines if they're ``` - lines = content.split('\n') - if lines[0].startswith('```') and lines[-1].strip() == '```': - content = '\n'.join(lines[1:-1]).strip() + lines = content.split("\n") + if lines[0].startswith("```") and lines[-1].strip() == "```": + content = "\n".join(lines[1:-1]).strip() parsed = json.loads(content) if isinstance(parsed, dict): @@ -718,7 +720,7 @@ def _extract_json(self, raw_response: str, output_keys: list[str]) -> dict[str, # Try to extract JSON from markdown code blocks (greedy match to handle nested blocks) # Use anchored match to capture from first ``` to last ``` - code_block_match = re.match(r'^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$', content, re.DOTALL) + code_block_match = re.match(r"^```(?:json|JSON)?\s*\n?(.*)\n?```\s*$", content, re.DOTALL) if code_block_match: try: parsed = json.loads(code_block_match.group(1).strip()) @@ -779,14 +781,14 @@ def _extract_json(self, raw_response: str, output_keys: list[str]) -> dict[str, cleaned = result.content.strip() # Remove markdown if LLM added it if cleaned.startswith("```"): - match = re.search(r'^```(?:json)?\s*\n([\s\S]*?)\n```\s*$', cleaned) + match = re.search(r"^```(?:json)?\s*\n([\s\S]*?)\n```\s*$", cleaned) if match: cleaned = match.group(1).strip() else: # Fallback: strip first/last lines - lines = cleaned.split('\n') - if lines[0].startswith('```') and lines[-1].strip() == '```': - cleaned = '\n'.join(lines[1:-1]).strip() + lines = cleaned.split("\n") + if lines[0].startswith("```") and lines[-1].strip() == "```": + cleaned = "\n".join(lines[1:-1]).strip() parsed = json.loads(cleaned) logger.info(" āœ“ LLM cleaned JSON output") diff --git a/core/framework/graph/output_cleaner.py b/core/framework/graph/output_cleaner.py index 5a2b9e3959..1bbda02024 100644 --- a/core/framework/graph/output_cleaner.py +++ b/core/framework/graph/output_cleaner.py @@ -88,9 +88,10 @@ def __init__(self, config: CleansingConfig, llm_provider=None): elif config.enabled: # Create dedicated fast LLM provider for cleaning try: - from framework.llm.litellm import LiteLLMProvider import os + from framework.llm.litellm import LiteLLMProvider + api_key = os.environ.get("CEREBRAS_API_KEY") if api_key: self.llm = LiteLLMProvider( @@ -318,7 +319,7 @@ def _build_schema_description(self, node_spec: Any) -> str: line = f' "{key}": {type_hint}' if description: - line += f' // {description}' + line += f" // {description}" if required: line += " (required)" lines.append(line + ",") diff --git a/core/framework/graph/plan.py b/core/framework/graph/plan.py index 81515ceab4..f5144cc791 100644 --- a/core/framework/graph/plan.py +++ b/core/framework/graph/plan.py @@ -10,9 +10,9 @@ - If replanning needed, returns feedback to external planner """ -from typing import Any -from enum import Enum from datetime import datetime +from enum import Enum +from typing import Any from pydantic import BaseModel, Field @@ -421,6 +421,7 @@ def load_export(data: str | dict) -> tuple["Plan", Any]: result = await executor.execute_plan(plan, goal, context) """ import json as json_module + from framework.graph.goal import Goal if isinstance(data, str): diff --git a/core/framework/graph/safe_eval.py b/core/framework/graph/safe_eval.py index 079460efc6..a3199f3cfc 100644 --- a/core/framework/graph/safe_eval.py +++ b/core/framework/graph/safe_eval.py @@ -1,6 +1,6 @@ import ast import operator -from typing import Any, Container, Dict, Optional +from typing import Any # Safe operators whitelist SAFE_OPERATORS = { @@ -53,7 +53,7 @@ } class SafeEvalVisitor(ast.NodeVisitor): - def __init__(self, context: Dict[str, Any]): + def __init__(self, context: dict[str, Any]): self.context = context def visit(self, node: ast.AST) -> Any: @@ -80,7 +80,7 @@ def visit_Num(self, node: ast.Num) -> Any: def visit_Str(self, node: ast.Str) -> Any: return node.s - + def visit_NameConstant(self, node: ast.NameConstant) -> Any: return node.value @@ -94,10 +94,10 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple: def visit_Dict(self, node: ast.Dict) -> dict: return { self.visit(k): self.visit(v) - for k, v in zip(node.keys, node.values) + for k, v in zip(node.keys, node.values, strict=False) if k is not None } - + # --- Operations --- def visit_BinOp(self, node: ast.BinOp) -> Any: op_func = SAFE_OPERATORS.get(type(node.op)) @@ -113,7 +113,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: def visit_Compare(self, node: ast.Compare) -> Any: left = self.visit(node.left) - for op, comparator in zip(node.ops, node.comparators): + for op, comparator in zip(node.ops, node.comparators, strict=False): op_func = SAFE_OPERATORS.get(type(op)) if op_func is None: raise ValueError(f"Operator {type(op).__name__} is not allowed") @@ -157,9 +157,9 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: # STIRCT CHECK: No access to private attributes (starting with _) if node.attr.startswith("_"): raise ValueError(f"Access to private attribute '{node.attr}' is not allowed") - + val = self.visit(node.value) - + # Safe attribute access: only allow if it's in the dict (if val is dict) # or it's a safe property of a basic type? # Actually, for flexibility, people often use dot access for dicts in these expressions. @@ -168,37 +168,37 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: # If the user context provides objects, we might want to allow attribute access. # BUT we must be careful not to allow access to dangerous things like __class__ etc. # The check starts_with("_") covers __class__, __init__, etc. - + try: return getattr(val, node.attr) except AttributeError: - # Fallback: maybe it's a dict and they want dot access? + # Fallback: maybe it's a dict and they want dot access? # (Only if we want to support that sugar, usually not standard python) # Let's stick to standard python behavior + strict private check. pass - + raise AttributeError(f"Object has no attribute '{node.attr}'") def visit_Call(self, node: ast.Call) -> Any: # Only allow calling whitelisted functions func = self.visit(node.func) - + # Check if the function object itself is in our whitelist values - # This is tricky because `func` is the actual function object, + # This is tricky because `func` is the actual function object, # but we also want to verify it came from a safe place. # Easier: Check if node.func is a Name and that name is in SAFE_FUNCTIONS. - + is_safe = False if isinstance(node.func, ast.Name): if node.func.id in SAFE_FUNCTIONS: is_safe = True - - # Also allow methods on objects if they are safe? + + # Also allow methods on objects if they are safe? # E.g. "somestring".lower() or list.append() (if we allowed mutation, but we don't for now) # For now, restrict to SAFE_FUNCTIONS whitelist for global calls and deny method calls # unless we explicitly add safe methods. # Actually, allowing method calls on strings/lists (like split, join, get) is commonly needed. - + if isinstance(node.func, ast.Attribute): # Method call. # Allow basic safe methods? @@ -207,13 +207,13 @@ def visit_Call(self, node: ast.Call) -> Any: method_name = node.func.attr if method_name in ["get", "keys", "values", "items", "lower", "upper", "strip", "split"]: is_safe = True - + if not is_safe and func not in SAFE_FUNCTIONS.values(): - raise ValueError(f"Call to function/method is not allowed") + raise ValueError("Call to function/method is not allowed") args = [self.visit(arg) for arg in node.args] keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords} - + return func(*args, **keywords) def visit_Index(self, node: ast.Index) -> Any: @@ -221,32 +221,32 @@ def visit_Index(self, node: ast.Index) -> Any: return self.visit(node.value) -def safe_eval(expr: str, context: Optional[Dict[str, Any]] = None) -> Any: +def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any: """ Safely evaluate a python expression string. - + Args: expr: The expression string to evaluate. context: Dictionary of variables available in the expression. - + Returns: The result of the evaluation. - + Raises: ValueError: If unsafe operations or syntax are detected. SyntaxError: If the expression is invalid Python. """ if context is None: context = {} - + # Add safe builtins to context full_context = context.copy() full_context.update(SAFE_FUNCTIONS) - + try: - tree = ast.parse(expr, mode='eval') + tree = ast.parse(expr, mode="eval") except SyntaxError as e: raise SyntaxError(f"Invalid syntax in expression: {e}") - + visitor = SafeEvalVisitor(full_context) return visitor.visit(tree) diff --git a/core/framework/graph/test_output_cleaner_live.py b/core/framework/graph/test_output_cleaner_live.py index 0545821f49..0b456353df 100644 --- a/core/framework/graph/test_output_cleaner_live.py +++ b/core/framework/graph/test_output_cleaner_live.py @@ -6,8 +6,9 @@ import json import os -from framework.graph.output_cleaner import OutputCleaner, CleansingConfig + from framework.graph.node import NodeSpec +from framework.graph.output_cleaner import CleansingConfig, OutputCleaner from framework.llm.litellm import LiteLLMProvider diff --git a/core/framework/graph/worker_node.py b/core/framework/graph/worker_node.py index 835933db50..9eeaae390c 100644 --- a/core/framework/graph/worker_node.py +++ b/core/framework/graph/worker_node.py @@ -10,20 +10,21 @@ - Code execution (sandboxed) """ -from typing import Any, Callable -from dataclasses import dataclass, field -import time import json import re +import time +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any +from framework.graph.code_sandbox import CodeSandbox from framework.graph.plan import ( - PlanStep, ActionSpec, ActionType, + PlanStep, ) -from framework.graph.code_sandbox import CodeSandbox -from framework.runtime.core import Runtime from framework.llm.provider import LLMProvider, Tool +from framework.runtime.core import Runtime def parse_llm_json_response(text: str) -> tuple[Any | None, str]: @@ -50,7 +51,7 @@ def parse_llm_json_response(text: str) -> tuple[Any | None, str]: # Try to extract JSON from markdown code blocks # Pattern: ```json ... ``` or ``` ... ``` - code_block_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```' + code_block_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```" matches = re.findall(code_block_pattern, cleaned) if matches: @@ -70,7 +71,7 @@ def parse_llm_json_response(text: str) -> tuple[Any | None, str]: pass # Try to find JSON-like content (starts with { or [) - json_start_pattern = r'(\{[\s\S]*\}|\[[\s\S]*\])' + json_start_pattern = r"(\{[\s\S]*\}|\[[\s\S]*\])" json_matches = re.findall(json_start_pattern, cleaned) for match in json_matches: diff --git a/core/framework/llm/anthropic.py b/core/framework/llm/anthropic.py index 0d37ac7078..542775fbdb 100644 --- a/core/framework/llm/anthropic.py +++ b/core/framework/llm/anthropic.py @@ -3,8 +3,8 @@ import os from typing import Any -from framework.llm.provider import LLMProvider, LLMResponse, Tool from framework.llm.litellm import LiteLLMProvider +from framework.llm.provider import LLMProvider, LLMResponse, Tool def _get_api_key_from_credential_manager() -> str | None: @@ -55,7 +55,7 @@ def __init__( ) self.model = model - + self._provider = LiteLLMProvider( model=model, api_key=self.api_key, diff --git a/core/framework/mcp/agent_builder_server.py b/core/framework/mcp/agent_builder_server.py index 6860876c02..3ed52a98b6 100644 --- a/core/framework/mcp/agent_builder_server.py +++ b/core/framework/mcp/agent_builder_server.py @@ -15,7 +15,7 @@ from mcp.server import FastMCP -from framework.graph import Goal, SuccessCriterion, Constraint, NodeSpec, EdgeSpec, EdgeCondition +from framework.graph import Constraint, EdgeCondition, EdgeSpec, Goal, NodeSpec, SuccessCriterion from framework.graph.plan import Plan # Testing framework imports @@ -23,7 +23,6 @@ PYTEST_TEST_FILE_HEADER, ) - # Initialize MCP server mcp = FastMCP("agent-builder") @@ -138,7 +137,7 @@ def _load_session(session_id: str) -> BuildSession: if not session_file.exists(): raise ValueError(f"Session '{session_id}' not found") - with open(session_file, "r") as f: + with open(session_file) as f: data = json.load(f) return BuildSession.from_dict(data) @@ -150,7 +149,7 @@ def _load_active_session() -> BuildSession | None: return None try: - with open(ACTIVE_SESSION_FILE, "r") as f: + with open(ACTIVE_SESSION_FILE) as f: session_id = f.read().strip() if session_id: @@ -201,7 +200,7 @@ def list_sessions() -> str: if SESSIONS_DIR.exists(): for session_file in SESSIONS_DIR.glob("*.json"): try: - with open(session_file, "r") as f: + with open(session_file) as f: data = json.load(f) sessions.append({ "session_id": data["session_id"], @@ -219,7 +218,7 @@ def list_sessions() -> str: active_id = None if ACTIVE_SESSION_FILE.exists(): try: - with open(ACTIVE_SESSION_FILE, "r") as f: + with open(ACTIVE_SESSION_FILE) as f: active_id = f.read().strip() except Exception: pass @@ -282,7 +281,7 @@ def delete_session(session_id: Annotated[str, "ID of the session to delete"]) -> _session = None if ACTIVE_SESSION_FILE.exists(): - with open(ACTIVE_SESSION_FILE, "r") as f: + with open(ACTIVE_SESSION_FILE) as f: active_id = f.read().strip() if active_id == session_id: ACTIVE_SESSION_FILE.unlink() @@ -918,7 +917,7 @@ def validate_graph() -> str: initial_context_keys: set[str] = set() # Compute in topological order - remaining = set(n.id for n in session.nodes) + remaining = {n.id for n in session.nodes} max_iterations = len(session.nodes) * 2 for _ in range(max_iterations): @@ -1093,7 +1092,7 @@ def _generate_readme(session: BuildSession, export_data: dict, all_tools: set) - # Build success criteria section criteria_section = [] for criterion in goal.success_criteria: - crit_dict = criterion.model_dump() if hasattr(criterion, 'model_dump') else criterion.__dict__ + crit_dict = criterion.model_dump() if hasattr(criterion, "model_dump") else criterion.__dict__ criteria_section.append( f"**{crit_dict.get('description', 'N/A')}** (weight {crit_dict.get('weight', 1.0)})\n" f"- Metric: {crit_dict.get('metric', 'N/A')}\n" @@ -1103,7 +1102,7 @@ def _generate_readme(session: BuildSession, export_data: dict, all_tools: set) - # Build constraints section constraints_section = [] for constraint in goal.constraints: - const_dict = constraint.model_dump() if hasattr(constraint, 'model_dump') else constraint.__dict__ + const_dict = constraint.model_dump() if hasattr(constraint, "model_dump") else constraint.__dict__ constraints_section.append( f"**{const_dict.get('description', 'N/A')}** ({const_dict.get('constraint_type', 'hard')})\n" f"- Category: {const_dict.get('category', 'N/A')}" @@ -1268,7 +1267,7 @@ def export_graph() -> str: # Strategy 2: Fallback - pair sequentially if no match found unmatched_pause = [p for p in pause_nodes if p not in pause_to_resume] unmatched_resume = [r for r in resume_entry_points if r not in pause_to_resume.values()] - for pause_id, resume_id in zip(unmatched_pause, unmatched_resume): + for pause_id, resume_id in zip(unmatched_pause, unmatched_resume, strict=False): pause_to_resume[pause_id] = resume_id # Build entry_points dict @@ -1354,10 +1353,10 @@ def export_graph() -> str: } # Add enrichment if present in goal - if hasattr(session.goal, 'success_criteria'): + if hasattr(session.goal, "success_criteria"): enriched_criteria = [] for criterion in session.goal.success_criteria: - crit_dict = criterion.model_dump() if hasattr(criterion, 'model_dump') else criterion + crit_dict = criterion.model_dump() if hasattr(criterion, "model_dump") else criterion enriched_criteria.append(crit_dict) export_data["goal"]["success_criteria"] = enriched_criteria @@ -2567,8 +2566,8 @@ def run_tests( By default, tests run in parallel using pytest-xdist with auto-detected worker count. Returns pass/fail summary with detailed results parsed from pytest output. """ - import subprocess import re + import subprocess tests_dir = Path(agent_path) / "tests" @@ -2740,8 +2739,8 @@ def debug_test( Re-runs the test with pytest -vvs to capture full output. Returns detailed failure information and suggestions. """ - import subprocess import re + import subprocess # Derive agent_path from session if not provided if not agent_path and _session: diff --git a/core/framework/runner/__init__.py b/core/framework/runner/__init__.py index c7c24f4db5..a3e4cac458 100644 --- a/core/framework/runner/__init__.py +++ b/core/framework/runner/__init__.py @@ -1,15 +1,15 @@ """Agent Runner - load and run exported agents.""" -from framework.runner.runner import AgentRunner, AgentInfo, ValidationResult -from framework.runner.tool_registry import ToolRegistry, tool from framework.runner.orchestrator import AgentOrchestrator from framework.runner.protocol import ( AgentMessage, - MessageType, CapabilityLevel, CapabilityResponse, + MessageType, OrchestratorResult, ) +from framework.runner.runner import AgentInfo, AgentRunner, ValidationResult +from framework.runner.tool_registry import ToolRegistry, tool __all__ = [ # Single agent diff --git a/core/framework/runner/cli.py b/core/framework/runner/cli.py index 03f091735a..14f0be5917 100644 --- a/core/framework/runner/cli.py +++ b/core/framework/runner/cli.py @@ -170,15 +170,16 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None: def cmd_run(args: argparse.Namespace) -> int: """Run an exported agent.""" import logging + from framework.runner import AgentRunner # Set logging level (quiet by default for cleaner output) if args.quiet: - logging.basicConfig(level=logging.ERROR, format='%(message)s') - elif getattr(args, 'verbose', False): - logging.basicConfig(level=logging.INFO, format='%(message)s') + logging.basicConfig(level=logging.ERROR, format="%(message)s") + elif getattr(args, "verbose", False): + logging.basicConfig(level=logging.INFO, format="%(message)s") else: - logging.basicConfig(level=logging.WARNING, format='%(message)s') + logging.basicConfig(level=logging.WARNING, format="%(message)s") # Load input context context = {} @@ -333,8 +334,8 @@ def cmd_info(args: argparse.Namespace) -> int: print() print(f"Nodes ({info.node_count}):") for node in info.nodes: - inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get('input_keys') else "" - outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get('output_keys') else "" + inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get("input_keys") else "" + outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get("output_keys") else "" print(f" - {node['id']}: {node['name']}{inputs}{outputs}") print() print(f"Success Criteria ({len(info.success_criteria)}):") @@ -540,7 +541,7 @@ def cmd_dispatch(args: argparse.Namespace) -> int: def _interactive_approval(request): """Interactive approval callback for HITL mode.""" - from framework.graph import ApprovalResult, ApprovalDecision + from framework.graph import ApprovalDecision, ApprovalResult print() print("=" * 60) @@ -607,9 +608,10 @@ def _interactive_approval(request): def _format_natural_language_to_json(user_input: str, input_keys: list[str], agent_description: str, session_context: dict = None) -> dict: """Use Haiku to convert natural language input to JSON based on agent's input schema.""" - import anthropic import os + import anthropic + client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) # Build prompt for Haiku @@ -619,7 +621,7 @@ def _format_natural_language_to_json(user_input: str, input_keys: list[str], age main_field = input_keys[0] if input_keys else "objective" existing_value = session_context.get(main_field, "") - session_info = f"\n\nExisting {main_field}: \"{existing_value}\"\n\nThe user is providing ADDITIONAL information. Append this new information to the existing {main_field} to create an enriched, more detailed version." + session_info = f'\n\nExisting {main_field}: "{existing_value}"\n\nThe user is providing ADDITIONAL information. Append this new information to the existing {main_field} to create an enriched, more detailed version.' prompt = f"""You are formatting user input for an agent that requires specific input fields. @@ -661,12 +663,13 @@ def _format_natural_language_to_json(user_input: str, input_keys: list[str], age def cmd_shell(args: argparse.Namespace) -> int: """Start an interactive agent session.""" import logging + from framework.runner import AgentRunner # Configure logging to show runtime visibility logging.basicConfig( level=logging.INFO, - format='%(message)s', # Simple format for clean output + format="%(message)s", # Simple format for clean output ) agents_dir = Path(args.agents_dir) @@ -690,7 +693,7 @@ def cmd_shell(args: argparse.Namespace) -> int: return 1 # Set up approval callback by default (unless --no-approve is set) - if not getattr(args, 'no_approve', False): + if not getattr(args, "no_approve", False): runner.set_approval_callback(_interactive_approval) print("\nšŸ”” Human-in-the-loop mode enabled") print(" Steps marked for approval will pause for your review") @@ -748,8 +751,8 @@ def cmd_shell(args: argparse.Namespace) -> int: if user_input == "/nodes": print("\nAgent nodes:") for node in info.nodes: - inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get('input_keys') else "" - outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get('output_keys') else "" + inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get("input_keys") else "" + outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get("output_keys") else "" print(f" {node['id']}: {node['name']}{inputs}{outputs}") print(f" {node['description']}") print() diff --git a/core/framework/runner/mcp_client.py b/core/framework/runner/mcp_client.py index 8cb1eb79a8..f8a9e245f4 100644 --- a/core/framework/runner/mcp_client.py +++ b/core/framework/runner/mcp_client.py @@ -146,6 +146,7 @@ def _connect_stdio(self) -> None: try: import threading + from mcp import StdioServerParameters # Create server parameters @@ -353,9 +354,9 @@ async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any] if len(result.content) > 0: content_item = result.content[0] # Check if it's a text content item - if hasattr(content_item, 'text'): + if hasattr(content_item, "text"): return content_item.text - elif hasattr(content_item, 'data'): + elif hasattr(content_item, "data"): return content_item.data return result.content diff --git a/core/framework/runner/orchestrator.py b/core/framework/runner/orchestrator.py index 23c0f9fb12..a053c90e5b 100644 --- a/core/framework/runner/orchestrator.py +++ b/core/framework/runner/orchestrator.py @@ -205,7 +205,7 @@ async def dispatch( responses = await asyncio.gather(*tasks, return_exceptions=True) - for agent_name, response in zip(routing.selected_agents, responses): + for agent_name, response in zip(routing.selected_agents, responses, strict=False): if isinstance(response, Exception): results[agent_name] = {"error": str(response)} else: @@ -326,7 +326,7 @@ async def broadcast( results = await asyncio.gather(*tasks, return_exceptions=True) - for name, result in zip(agent_names, results): + for name, result in zip(agent_names, results, strict=False): if isinstance(result, Exception): responses[name] = AgentMessage( type=MessageType.RESPONSE, @@ -355,7 +355,7 @@ async def _check_all_capabilities( results = await asyncio.gather(*tasks, return_exceptions=True) capabilities = {} - for name, result in zip(agent_names, results): + for name, result in zip(agent_names, results, strict=False): if isinstance(result, Exception): capabilities[name] = CapabilityResponse( agent_name=name, @@ -463,7 +463,7 @@ async def _llm_route( ) import re - json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL) + json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL) if json_match: data = json.loads(json_match.group()) selected = data.get("selected", []) diff --git a/core/framework/runner/protocol.py b/core/framework/runner/protocol.py index 8592cd9db8..44df72a686 100644 --- a/core/framework/runner/protocol.py +++ b/core/framework/runner/protocol.py @@ -1,10 +1,10 @@ """Message protocol for multi-agent communication.""" +import uuid from dataclasses import dataclass, field from datetime import datetime from enum import Enum from typing import Any -import uuid class MessageType(Enum): diff --git a/core/framework/runner/runner.py b/core/framework/runner/runner.py index 0e1fdd70af..ed1e11d5b9 100644 --- a/core/framework/runner/runner.py +++ b/core/framework/runner/runner.py @@ -2,24 +2,25 @@ import json import os +from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Callable, Any +from typing import TYPE_CHECKING, Any from framework.graph import Goal -from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition, AsyncEntryPointSpec +from framework.graph.edge import AsyncEntryPointSpec, EdgeCondition, EdgeSpec, GraphSpec +from framework.graph.executor import ExecutionResult, GraphExecutor from framework.graph.node import NodeSpec -from framework.graph.executor import GraphExecutor, ExecutionResult from framework.llm.provider import LLMProvider, Tool from framework.runner.tool_registry import ToolRegistry -from framework.runtime.core import Runtime # Multi-entry-point runtime imports -from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime +from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime +from framework.runtime.core import Runtime from framework.runtime.execution_stream import EntryPointSpec if TYPE_CHECKING: - from framework.runner.protocol import CapabilityResponse, AgentMessage + from framework.runner.protocol import AgentMessage, CapabilityResponse @dataclass @@ -131,7 +132,7 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]: ) # Build Goal - from framework.graph.goal import SuccessCriterion, Constraint + from framework.graph.goal import Constraint, SuccessCriterion success_criteria = [] for sc_data in goal_data.get("success_criteria", []): @@ -810,7 +811,7 @@ def validate(self) -> ValidationResult: # Check tool credentials (Tier 2) missing_creds = cred_manager.get_missing_for_tools(info.required_tools) - for cred_name, spec in missing_creds: + for _cred_name, spec in missing_creds: missing_credentials.append(spec.env_var) affected_tools = [t for t in info.required_tools if t in spec.tools] tools_str = ", ".join(affected_tools) @@ -820,9 +821,9 @@ def validate(self) -> ValidationResult: warnings.append(warning_msg) # Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes) - node_types = list(set(node.node_type for node in self.graph.nodes)) + node_types = list({node.node_type for node in self.graph.nodes}) missing_node_creds = cred_manager.get_missing_for_node_types(node_types) - for cred_name, spec in missing_node_creds: + for _cred_name, spec in missing_node_creds: if spec.env_var not in missing_credentials: # Avoid duplicates missing_credentials.append(spec.env_var) affected_types = [t for t in node_types if t in spec.node_types] @@ -867,7 +868,7 @@ async def can_handle(self, request: dict, llm: LLMProvider | None = None) -> "Ca Returns: CapabilityResponse with level, confidence, and reasoning """ - from framework.runner.protocol import CapabilityResponse, CapabilityLevel + from framework.runner.protocol import CapabilityLevel, CapabilityResponse # Use provided LLM or set up our own eval_llm = llm @@ -924,7 +925,7 @@ async def can_handle(self, request: dict, llm: LLMProvider | None = None) -> "Ca # Parse response import re - json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL) + json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL) if json_match: data = json.loads(json_match.group()) level_map = { @@ -948,7 +949,7 @@ async def can_handle(self, request: dict, llm: LLMProvider | None = None) -> "Ca def _keyword_capability_check(self, request: dict) -> "CapabilityResponse": """Simple keyword-based capability check (fallback when no LLM).""" - from framework.runner.protocol import CapabilityResponse, CapabilityLevel + from framework.runner.protocol import CapabilityLevel, CapabilityResponse info = self.info() request_str = json.dumps(request).lower() diff --git a/core/framework/runner/tool_registry.py b/core/framework/runner/tool_registry.py index a4ba691fc2..709480b7f2 100644 --- a/core/framework/runner/tool_registry.py +++ b/core/framework/runner/tool_registry.py @@ -4,11 +4,12 @@ import inspect import json import logging +from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable +from typing import Any -from framework.llm.provider import Tool, ToolUse, ToolResult +from framework.llm.provider import Tool, ToolResult, ToolUse logger = logging.getLogger(__name__) @@ -142,7 +143,7 @@ def discover_from_module(self, module_path: Path) -> int: # Check for TOOLS dict if hasattr(module, "TOOLS"): - tools_dict = getattr(module, "TOOLS") + tools_dict = module.TOOLS executor_func = getattr(module, "tool_executor", None) for name, tool in tools_dict.items(): diff --git a/core/framework/runtime/agent_runtime.py b/core/framework/runtime/agent_runtime.py index 4bd35b50b2..7b3691d146 100644 --- a/core/framework/runtime/agent_runtime.py +++ b/core/framework/runtime/agent_runtime.py @@ -7,15 +7,16 @@ import asyncio import logging -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, TYPE_CHECKING +from typing import TYPE_CHECKING, Any from framework.graph.executor import ExecutionResult -from framework.runtime.shared_state import SharedStateManager -from framework.runtime.outcome_aggregator import OutcomeAggregator from framework.runtime.event_bus import EventBus -from framework.runtime.execution_stream import ExecutionStream, EntryPointSpec +from framework.runtime.execution_stream import EntryPointSpec, ExecutionStream +from framework.runtime.outcome_aggregator import OutcomeAggregator +from framework.runtime.shared_state import SharedStateManager from framework.storage.concurrent import ConcurrentStorage if TYPE_CHECKING: diff --git a/core/framework/runtime/core.py b/core/framework/runtime/core.py index 70acdde16e..624e6282fa 100644 --- a/core/framework/runtime/core.py +++ b/core/framework/runtime/core.py @@ -6,13 +6,13 @@ handles all the structured logging. """ -from datetime import datetime -from typing import Any -from pathlib import Path import logging import uuid +from datetime import datetime +from pathlib import Path +from typing import Any -from framework.schemas.decision import Decision, Option, Outcome, DecisionType +from framework.schemas.decision import Decision, DecisionType, Option, Outcome from framework.schemas.run import Run, RunStatus from framework.storage.backend import FileStorage diff --git a/core/framework/runtime/event_bus.py b/core/framework/runtime/event_bus.py index 8a2501e271..b8a36bf817 100644 --- a/core/framework/runtime/event_bus.py +++ b/core/framework/runtime/event_bus.py @@ -9,11 +9,11 @@ import asyncio import logging -import time +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Awaitable, Callable +from typing import Any logger = logging.getLogger(__name__) @@ -432,7 +432,7 @@ async def handler(event: AgentEvent) -> None: if timeout: try: await asyncio.wait_for(event_received.wait(), timeout=timeout) - except asyncio.TimeoutError: + except TimeoutError: return None else: await event_received.wait() diff --git a/core/framework/runtime/execution_stream.py b/core/framework/runtime/execution_stream.py index e786a60de6..322222a4d5 100644 --- a/core/framework/runtime/execution_stream.py +++ b/core/framework/runtime/execution_stream.py @@ -10,21 +10,22 @@ import asyncio import logging import uuid +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Callable, TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from framework.graph.executor import GraphExecutor, ExecutionResult +from framework.graph.executor import ExecutionResult, GraphExecutor +from framework.runtime.shared_state import IsolationLevel, SharedStateManager from framework.runtime.stream_runtime import StreamRuntime, StreamRuntimeAdapter -from framework.runtime.shared_state import SharedStateManager, IsolationLevel, StreamMemory if TYPE_CHECKING: from framework.graph.edge import GraphSpec from framework.graph.goal import Goal - from framework.storage.concurrent import ConcurrentStorage - from framework.runtime.outcome_aggregator import OutcomeAggregator - from framework.runtime.event_bus import EventBus from framework.llm.provider import LLMProvider, Tool + from framework.runtime.event_bus import EventBus + from framework.runtime.outcome_aggregator import OutcomeAggregator + from framework.storage.concurrent import ConcurrentStorage logger = logging.getLogger(__name__) @@ -164,7 +165,7 @@ async def start(self) -> None: # Emit stream started event if self._event_bus: - from framework.runtime.event_bus import EventType, AgentEvent + from framework.runtime.event_bus import AgentEvent, EventType await self._event_bus.publish(AgentEvent( type=EventType.STREAM_STARTED, stream_id=self.stream_id, @@ -179,7 +180,7 @@ async def stop(self) -> None: self._running = False # Cancel all active executions - for exec_id, task in self._execution_tasks.items(): + for _exec_id, task in self._execution_tasks.items(): if not task.done(): task.cancel() try: @@ -194,7 +195,7 @@ async def stop(self) -> None: # Emit stream stopped event if self._event_bus: - from framework.runtime.event_bus import EventType, AgentEvent + from framework.runtime.event_bus import AgentEvent, EventType await self._event_bus.publish(AgentEvent( type=EventType.STREAM_STOPPED, stream_id=self.stream_id, @@ -268,7 +269,7 @@ async def _run_execution(self, ctx: ExecutionContext) -> None: ) # Create execution-scoped memory - memory = self._state_manager.create_memory( + self._state_manager.create_memory( execution_id=execution_id, stream_id=self.stream_id, isolation=ctx.isolation_level, @@ -408,7 +409,7 @@ async def wait_for_completion( return self._execution_results.get(execution_id) - except asyncio.TimeoutError: + except TimeoutError: return None def get_result(self, execution_id: str) -> ExecutionResult | None: diff --git a/core/framework/runtime/outcome_aggregator.py b/core/framework/runtime/outcome_aggregator.py index 9075330bac..bd0012870d 100644 --- a/core/framework/runtime/outcome_aggregator.py +++ b/core/framework/runtime/outcome_aggregator.py @@ -9,7 +9,7 @@ import logging from dataclasses import dataclass, field from datetime import datetime -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any from framework.schemas.decision import Decision, Outcome @@ -286,8 +286,8 @@ async def evaluate_goal_progress(self) -> dict[str, Any]: "success_rate": ( self._successful_outcomes / max(1, self._successful_outcomes + self._failed_outcomes) ), - "streams_active": len(set(d.stream_id for d in self._decisions)), - "executions_total": len(set((d.stream_id, d.execution_id) for d in self._decisions)), + "streams_active": len({d.stream_id for d in self._decisions}), + "executions_total": len({(d.stream_id, d.execution_id) for d in self._decisions}), } # Determine recommendation @@ -296,7 +296,7 @@ async def evaluate_goal_progress(self) -> dict[str, Any]: # Publish progress event if self._event_bus: # Get any stream ID for the event - stream_ids = set(d.stream_id for d in self._decisions) + stream_ids = {d.stream_id for d in self._decisions} if stream_ids: await self._event_bus.emit_goal_progress( stream_id=list(stream_ids)[0], @@ -429,7 +429,7 @@ def get_stats(self) -> dict: "failed_outcomes": self._failed_outcomes, "constraint_violations": len(self._constraint_violations), "criteria_tracked": len(self._criterion_status), - "streams_seen": len(set(d.stream_id for d in self._decisions)), + "streams_seen": len({d.stream_id for d in self._decisions}), } # === RESET OPERATIONS === diff --git a/core/framework/runtime/stream_runtime.py b/core/framework/runtime/stream_runtime.py index 3820bc45d5..82137d3702 100644 --- a/core/framework/runtime/stream_runtime.py +++ b/core/framework/runtime/stream_runtime.py @@ -10,9 +10,9 @@ import logging import uuid from datetime import datetime -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from framework.schemas.decision import Decision, Option, Outcome, DecisionType +from framework.schemas.decision import Decision, DecisionType, Option, Outcome from framework.schemas.run import Run, RunStatus from framework.storage.concurrent import ConcurrentStorage diff --git a/core/framework/runtime/tests/test_agent_runtime.py b/core/framework/runtime/tests/test_agent_runtime.py index d46f35f6a1..877eadbf0d 100644 --- a/core/framework/runtime/tests/test_agent_runtime.py +++ b/core/framework/runtime/tests/test_agent_runtime.py @@ -11,21 +11,20 @@ """ import asyncio -import pytest import tempfile from pathlib import Path +import pytest + from framework.graph import Goal -from framework.graph.goal import SuccessCriterion, Constraint -from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition, AsyncEntryPointSpec +from framework.graph.edge import AsyncEntryPointSpec, EdgeCondition, EdgeSpec, GraphSpec +from framework.graph.goal import Constraint, SuccessCriterion from framework.graph.node import NodeSpec -from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime +from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime +from framework.runtime.event_bus import AgentEvent, EventBus, EventType from framework.runtime.execution_stream import EntryPointSpec -from framework.runtime.shared_state import SharedStateManager, IsolationLevel -from framework.runtime.event_bus import EventBus, EventType, AgentEvent from framework.runtime.outcome_aggregator import OutcomeAggregator -from framework.runtime.stream_runtime import StreamRuntime - +from framework.runtime.shared_state import IsolationLevel, SharedStateManager # === Test Fixtures === @@ -175,8 +174,8 @@ async def test_shared_state(self): """Test shared state is visible across executions.""" manager = SharedStateManager() - mem1 = manager.create_memory("exec-1", "stream-1", IsolationLevel.SHARED) - mem2 = manager.create_memory("exec-2", "stream-1", IsolationLevel.SHARED) + manager.create_memory("exec-1", "stream-1", IsolationLevel.SHARED) + manager.create_memory("exec-2", "stream-1", IsolationLevel.SHARED) # Write to global scope await manager.write( diff --git a/core/framework/schemas/__init__.py b/core/framework/schemas/__init__.py index 23c06a6c00..5682c771a8 100644 --- a/core/framework/schemas/__init__.py +++ b/core/framework/schemas/__init__.py @@ -1,7 +1,7 @@ """Schema definitions for runtime data.""" -from framework.schemas.decision import Decision, Option, Outcome, DecisionEvaluation -from framework.schemas.run import Run, RunSummary, Problem +from framework.schemas.decision import Decision, DecisionEvaluation, Option, Outcome +from framework.schemas.run import Problem, Run, RunSummary __all__ = [ "Decision", diff --git a/core/framework/schemas/decision.py b/core/framework/schemas/decision.py index 8bf82a9371..9e06d58867 100644 --- a/core/framework/schemas/decision.py +++ b/core/framework/schemas/decision.py @@ -10,8 +10,8 @@ """ from datetime import datetime -from typing import Any from enum import Enum +from typing import Any from pydantic import BaseModel, Field, computed_field diff --git a/core/framework/schemas/run.py b/core/framework/schemas/run.py index 353f64868c..54d256e859 100644 --- a/core/framework/schemas/run.py +++ b/core/framework/schemas/run.py @@ -6,8 +6,8 @@ """ from datetime import datetime -from typing import Any from enum import Enum +from typing import Any from pydantic import BaseModel, Field, computed_field diff --git a/core/framework/storage/backend.py b/core/framework/storage/backend.py index d56534ff23..9cb94ac31b 100644 --- a/core/framework/storage/backend.py +++ b/core/framework/storage/backend.py @@ -8,7 +8,7 @@ import json from pathlib import Path -from framework.schemas.run import Run, RunSummary, RunStatus +from framework.schemas.run import Run, RunStatus, RunSummary class FileStorage: diff --git a/core/framework/storage/concurrent.py b/core/framework/storage/concurrent.py index 8aac83c586..3b093f6356 100644 --- a/core/framework/storage/concurrent.py +++ b/core/framework/storage/concurrent.py @@ -8,15 +8,14 @@ """ import asyncio -import json import logging import time from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any -from framework.schemas.run import Run, RunSummary, RunStatus +from framework.schemas.run import Run, RunStatus, RunSummary from framework.storage.backend import FileStorage logger = logging.getLogger(__name__) @@ -283,7 +282,7 @@ async def _batch_writer(self) -> None: except asyncio.QueueEmpty: break - except asyncio.TimeoutError: + except TimeoutError: pass # Flush batch if we have items @@ -339,7 +338,7 @@ def invalidate_cache(self, key: str) -> None: def get_cache_stats(self) -> dict: """Get cache statistics.""" - now = time.time() + time.time() expired = sum( 1 for entry in self._cache.values() if entry.is_expired(self._cache_ttl) diff --git a/core/framework/testing/__init__.py b/core/framework/testing/__init__.py index 2a91532d5f..5bb0e6def7 100644 --- a/core/framework/testing/__init__.py +++ b/core/framework/testing/__init__.py @@ -33,20 +33,7 @@ """ # Schemas -from framework.testing.test_case import ( - ApprovalStatus, - TestType, - Test, -) -from framework.testing.test_result import ( - ErrorCategory, - TestResult, - TestSuiteResult, -) - -# Storage -from framework.testing.test_storage import TestStorage - +from framework.testing.approval_cli import batch_approval, interactive_approval # Approval from framework.testing.approval_types import ( @@ -56,19 +43,31 @@ BatchApprovalRequest, BatchApprovalResult, ) -from framework.testing.approval_cli import interactive_approval, batch_approval # Error categorization from framework.testing.categorizer import ErrorCategorizer -# LLM Judge for semantic evaluation -from framework.testing.llm_judge import LLMJudge +# CLI +from framework.testing.cli import register_testing_commands # Debug -from framework.testing.debug_tool import DebugTool, DebugInfo +from framework.testing.debug_tool import DebugInfo, DebugTool -# CLI -from framework.testing.cli import register_testing_commands +# LLM Judge for semantic evaluation +from framework.testing.llm_judge import LLMJudge +from framework.testing.test_case import ( + ApprovalStatus, + Test, + TestType, +) +from framework.testing.test_result import ( + ErrorCategory, + TestResult, + TestSuiteResult, +) + +# Storage +from framework.testing.test_storage import TestStorage __all__ = [ # Schemas diff --git a/core/framework/testing/approval_cli.py b/core/framework/testing/approval_cli.py index 9390ff0de1..d503841853 100644 --- a/core/framework/testing/approval_cli.py +++ b/core/framework/testing/approval_cli.py @@ -6,19 +6,19 @@ """ import json -import tempfile -import subprocess import os -from typing import Callable +import subprocess +import tempfile +from collections.abc import Callable -from framework.testing.test_case import Test -from framework.testing.test_storage import TestStorage from framework.testing.approval_types import ( ApprovalAction, ApprovalRequest, ApprovalResult, BatchApprovalResult, ) +from framework.testing.test_case import Test +from framework.testing.test_storage import TestStorage def interactive_approval( diff --git a/core/framework/testing/approval_types.py b/core/framework/testing/approval_types.py index f1f2ea54be..2a5254dbf4 100644 --- a/core/framework/testing/approval_types.py +++ b/core/framework/testing/approval_types.py @@ -5,8 +5,8 @@ programmatic/MCP-based approval. """ -from enum import Enum from datetime import datetime +from enum import Enum from typing import Any from pydantic import BaseModel, Field diff --git a/core/framework/testing/debug_tool.py b/core/framework/testing/debug_tool.py index 404a683071..c36eb69888 100644 --- a/core/framework/testing/debug_tool.py +++ b/core/framework/testing/debug_tool.py @@ -13,10 +13,10 @@ from pydantic import BaseModel, Field +from framework.testing.categorizer import ErrorCategorizer from framework.testing.test_case import Test -from framework.testing.test_result import TestResult, ErrorCategory +from framework.testing.test_result import ErrorCategory, TestResult from framework.testing.test_storage import TestStorage -from framework.testing.categorizer import ErrorCategorizer class DebugInfo(BaseModel): diff --git a/core/framework/testing/test_storage.py b/core/framework/testing/test_storage.py index e39fabf262..559daa32e4 100644 --- a/core/framework/testing/test_storage.py +++ b/core/framework/testing/test_storage.py @@ -6,10 +6,10 @@ """ import json -from pathlib import Path from datetime import datetime +from pathlib import Path -from framework.testing.test_case import Test, ApprovalStatus, TestType +from framework.testing.test_case import ApprovalStatus, Test, TestType from framework.testing.test_result import TestResult diff --git a/core/setup_mcp.py b/core/setup_mcp.py index 212030d021..c0220ba20a 100755 --- a/core/setup_mcp.py +++ b/core/setup_mcp.py @@ -14,11 +14,11 @@ class Colors: """ANSI color codes for terminal output.""" - GREEN = '\033[0;32m' - YELLOW = '\033[1;33m' - RED = '\033[0;31m' - BLUE = '\033[0;34m' - NC = '\033[0m' # No Color + GREEN = "\033[0;32m" + YELLOW = "\033[1;33m" + RED = "\033[0;31m" + BLUE = "\033[0;34m" + NC = "\033[0m" # No Color def print_step(message: str, color: str = Colors.YELLOW): @@ -100,7 +100,7 @@ def main(): } } - with open(mcp_config_path, 'w') as f: + with open(mcp_config_path, "w") as f: json.dump(config, f, indent=2) print_success("Created .mcp.json") diff --git a/core/tests/test_builder.py b/core/tests/test_builder.py index 1858833926..67aac648ff 100644 --- a/core/tests/test_builder.py +++ b/core/tests/test_builder.py @@ -2,7 +2,7 @@ from pathlib import Path -from framework import Runtime, BuilderQuery +from framework import BuilderQuery, Runtime from framework.schemas.run import RunStatus diff --git a/core/tests/test_executor_max_retries.py b/core/tests/test_executor_max_retries.py index 8b27eb1d50..7e67b654e2 100644 --- a/core/tests/test_executor_max_retries.py +++ b/core/tests/test_executor_max_retries.py @@ -5,31 +5,33 @@ the max_retries field in NodeSpec and using a hardcoded value of 3. """ +from unittest.mock import MagicMock + import pytest -from unittest.mock import AsyncMock, MagicMock -from framework.graph.executor import GraphExecutor, ExecutionResult -from framework.graph.node import NodeSpec, NodeProtocol, NodeContext, NodeResult + from framework.graph.edge import GraphSpec +from framework.graph.executor import GraphExecutor from framework.graph.goal import Goal +from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec from framework.runtime.core import Runtime class FlakyTestNode(NodeProtocol): """A test node that fails a configurable number of times before succeeding.""" - + def __init__(self, fail_times: int = 2): self.fail_times = fail_times self.attempt_count = 0 - + async def execute(self, ctx: NodeContext) -> NodeResult: self.attempt_count += 1 - + if self.attempt_count <= self.fail_times: return NodeResult( success=False, error=f"Transient error (attempt {self.attempt_count})" ) - + return NodeResult( success=True, output={"result": f"succeeded after {self.attempt_count} attempts"} @@ -38,10 +40,10 @@ async def execute(self, ctx: NodeContext) -> NodeResult: class AlwaysFailsNode(NodeProtocol): """A test node that always fails.""" - + def __init__(self): self.attempt_count = 0 - + async def execute(self, ctx: NodeContext) -> NodeResult: self.attempt_count += 1 return NodeResult( @@ -67,7 +69,7 @@ def runtime(): async def test_executor_respects_custom_max_retries_high(runtime): """ Test that executor respects max_retries when set to high value (10). - + Node fails 5 times before succeeding. With max_retries=10, should succeed. """ # Create node with max_retries=10 @@ -79,7 +81,7 @@ async def test_executor_respects_custom_max_retries_high(runtime): node_type="function", output_keys=["result"] ) - + # Create graph graph = GraphSpec( id="test_graph", @@ -97,17 +99,17 @@ async def test_executor_respects_custom_max_retries_high(runtime): name="Test Goal", description="Test that max_retries is respected" ) - + # Create executor and register flaky node (fails 5 times, succeeds on 6th) executor = GraphExecutor(runtime=runtime) flaky_node = FlakyTestNode(fail_times=5) executor.register_node("flaky_node", flaky_node) - + # Execute result = await executor.execute(graph, goal, {}) # Should succeed because 5 failures < 10 max_retries (max_retries=N means N total attempts allowed) - assert result.success == True + assert result.success assert flaky_node.attempt_count == 6 # 5 failures + 1 success @@ -155,7 +157,7 @@ async def test_executor_respects_custom_max_retries_low(runtime): result = await executor.execute(graph, goal, {}) # Should fail after exactly 2 attempts (max_retries=N means N total attempts) - assert result.success == False + assert not result.success assert failing_node.attempt_count == 2 # 2 total attempts assert "failed after 2 attempts" in result.error @@ -202,7 +204,7 @@ async def test_executor_respects_default_max_retries(runtime): result = await executor.execute(graph, goal, {}) # Should fail after default 3 total attempts (max_retries=N means N total attempts) - assert result.success == False + assert not result.success assert failing_node.attempt_count == 3 # 3 total attempts assert "failed after 3 attempts" in result.error @@ -251,7 +253,7 @@ async def test_executor_max_retries_two_succeeds_on_second(runtime): result = await executor.execute(graph, goal, {}) # Should succeed on second attempt (max_retries=2 allows 2 total attempts) - assert result.success == True + assert result.success assert flaky_node.attempt_count == 2 # 1 failure + 1 success @@ -279,7 +281,7 @@ async def test_executor_different_nodes_different_max_retries(runtime): input_keys=["result1"], output_keys=["result2"] ) - + # Note: This test would require more complex graph setup with edges # For now, we've verified that max_retries is read from node_spec correctly # The actual value varies per node as expected diff --git a/core/tests/test_flexible_executor.py b/core/tests/test_flexible_executor.py index ff18520008..65b0fb3768 100644 --- a/core/tests/test_flexible_executor.py +++ b/core/tests/test_flexible_executor.py @@ -10,27 +10,28 @@ """ import asyncio + import pytest +from framework.graph.code_sandbox import ( + CodeSandbox, + safe_eval, + safe_exec, +) +from framework.graph.goal import Goal, SuccessCriterion +from framework.graph.judge import HybridJudge, create_default_judge from framework.graph.plan import ( - Plan, - PlanStep, ActionSpec, ActionType, - StepStatus, + EvaluationRule, + ExecutionStatus, Judgment, JudgmentAction, - EvaluationRule, + Plan, PlanExecutionResult, - ExecutionStatus, -) -from framework.graph.code_sandbox import ( - CodeSandbox, - safe_exec, - safe_eval, + PlanStep, + StepStatus, ) -from framework.graph.judge import HybridJudge, create_default_judge -from framework.graph.goal import Goal, SuccessCriterion class TestPlanDataStructures: @@ -397,8 +398,8 @@ class TestFlexibleExecutorIntegration: def test_executor_creation(self, tmp_path): """Test creating a FlexibleGraphExecutor.""" - from framework.runtime.core import Runtime from framework.graph.flexible_executor import FlexibleGraphExecutor + from framework.runtime.core import Runtime runtime = Runtime(storage_path=tmp_path / "runtime") executor = FlexibleGraphExecutor(runtime=runtime) @@ -409,8 +410,8 @@ def test_executor_creation(self, tmp_path): def test_executor_with_custom_judge(self, tmp_path): """Test executor with custom judge.""" - from framework.runtime.core import Runtime from framework.graph.flexible_executor import FlexibleGraphExecutor + from framework.runtime.core import Runtime runtime = Runtime(storage_path=tmp_path / "runtime") custom_judge = HybridJudge() diff --git a/core/tests/test_graph_executor.py b/core/tests/test_graph_executor.py new file mode 100644 index 0000000000..1113b1e1d1 --- /dev/null +++ b/core/tests/test_graph_executor.py @@ -0,0 +1,132 @@ +""" +Tests for core GraphExecutor execution paths. +Focused on minimal success and failure scenarios. +""" + +import pytest + +from framework.graph.edge import GraphSpec +from framework.graph.executor import GraphExecutor +from framework.graph.goal import Goal +from framework.graph.node import NodeResult, NodeSpec + + +# ---- Dummy runtime (no real logging) ---- +class DummyRuntime: + def start_run(self, **kwargs): + return "run-1" + + def end_run(self, **kwargs): + pass + + def report_problem(self, **kwargs): + pass + + +# ---- Fake node that always succeeds ---- +class SuccessNode: + def validate_input(self, ctx): + return [] + + async def execute(self, ctx): + return NodeResult( + success=True, + output={"result": 123}, + tokens_used=1, + latency_ms=1, + ) + + +@pytest.mark.asyncio +async def test_executor_single_node_success(): + runtime = DummyRuntime() + + graph = GraphSpec( + id="graph-1", + goal_id="g1", + nodes=[ + NodeSpec( + id="n1", + name="node1", + description="test node", + node_type="llm_generate", + input_keys=[], + output_keys=["result"], + max_retries=0, + ) + ], + edges=[], + entry_node="n1", + ) + + executor = GraphExecutor( + runtime=runtime, + node_registry={"n1": SuccessNode()}, + ) + + goal = Goal( + id="g1", + name="test-goal", + description="simple test", + ) + + result = await executor.execute(graph=graph, goal=goal) + + assert result.success is True + assert result.path == ["n1"] + assert result.steps_executed == 1 + + +# ---- Fake node that always fails ---- +class FailingNode: + def validate_input(self, ctx): + return [] + + async def execute(self, ctx): + return NodeResult( + success=False, + error="boom", + output={}, + tokens_used=0, + latency_ms=0, + ) + + +@pytest.mark.asyncio +async def test_executor_single_node_failure(): + runtime = DummyRuntime() + + graph = GraphSpec( + id="graph-2", + goal_id="g2", + nodes=[ + NodeSpec( + id="n1", + name="node1", + description="failing node", + node_type="llm_generate", + input_keys=[], + output_keys=["result"], + max_retries=0, + ) + ], + edges=[], + entry_node="n1", + ) + + executor = GraphExecutor( + runtime=runtime, + node_registry={"n1": FailingNode()}, + ) + + goal = Goal( + id="g2", + name="fail-goal", + description="failure test", + ) + + result = await executor.execute(graph=graph, goal=goal) + + assert result.success is False + assert result.error is not None + assert result.path == ["n1"] diff --git a/core/tests/test_hallucination_detection.py b/core/tests/test_hallucination_detection.py index f36eb5cfec..d4ee6436b1 100644 --- a/core/tests/test_hallucination_detection.py +++ b/core/tests/test_hallucination_detection.py @@ -6,7 +6,8 @@ """ import pytest -from framework.graph.node import SharedMemory, MemoryWriteError + +from framework.graph.node import MemoryWriteError, SharedMemory from framework.graph.validator import OutputValidator, ValidationResult diff --git a/core/tests/test_litellm_provider.py b/core/tests/test_litellm_provider.py index 9f17ee9810..b2e284c91a 100644 --- a/core/tests/test_litellm_provider.py +++ b/core/tests/test_litellm_provider.py @@ -10,11 +10,11 @@ """ import os -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch -from framework.llm.litellm import LiteLLMProvider from framework.llm.anthropic import AnthropicProvider -from framework.llm.provider import LLMProvider, Tool, ToolUse, ToolResult +from framework.llm.litellm import LiteLLMProvider +from framework.llm.provider import LLMProvider, Tool, ToolResult, ToolUse class TestLiteLLMProviderInit: diff --git a/core/tests/test_mcp_server.py b/core/tests/test_mcp_server.py index bbbcd500d4..5aee9cfb10 100644 --- a/core/tests/test_mcp_server.py +++ b/core/tests/test_mcp_server.py @@ -55,9 +55,10 @@ def test_mcp_object_exported(self): if not MCP_AVAILABLE: pytest.skip(MCP_SKIP_REASON) - from framework.mcp.agent_builder_server import mcp from mcp.server import FastMCP + from framework.mcp.agent_builder_server import mcp + assert mcp is not None assert isinstance(mcp, FastMCP) @@ -86,8 +87,9 @@ def test_agent_builder_server_exported(self): if not MCP_AVAILABLE: pytest.skip(MCP_SKIP_REASON) - from framework.mcp import agent_builder_server from mcp.server import FastMCP + from framework.mcp import agent_builder_server + assert agent_builder_server is not None assert isinstance(agent_builder_server, FastMCP) diff --git a/core/tests/test_node_json_extraction.py b/core/tests/test_node_json_extraction.py index f90d50b824..7b1e91b645 100644 --- a/core/tests/test_node_json_extraction.py +++ b/core/tests/test_node_json_extraction.py @@ -6,6 +6,7 @@ """ import pytest + from framework.graph.node import LLMNode diff --git a/core/tests/test_orchestrator.py b/core/tests/test_orchestrator.py index 6229584bfe..f0168f3977 100644 --- a/core/tests/test_orchestrator.py +++ b/core/tests/test_orchestrator.py @@ -7,8 +7,8 @@ from unittest.mock import Mock, patch -from framework.llm.provider import LLMProvider from framework.llm.litellm import LiteLLMProvider +from framework.llm.provider import LLMProvider from framework.runner.orchestrator import AgentOrchestrator @@ -17,7 +17,7 @@ class TestOrchestratorLLMInitialization: def test_auto_creates_litellm_provider_when_no_llm_passed(self): """Test that LiteLLMProvider is auto-created when no llm is passed.""" - with patch.object(LiteLLMProvider, '__init__', return_value=None) as mock_init: + with patch.object(LiteLLMProvider, "__init__", return_value=None) as mock_init: orchestrator = AgentOrchestrator() mock_init.assert_called_once_with(model="claude-haiku-4-5-20251001") @@ -25,14 +25,14 @@ def test_auto_creates_litellm_provider_when_no_llm_passed(self): def test_uses_custom_model_parameter(self): """Test that custom model parameter is passed to LiteLLMProvider.""" - with patch.object(LiteLLMProvider, '__init__', return_value=None) as mock_init: + with patch.object(LiteLLMProvider, "__init__", return_value=None) as mock_init: AgentOrchestrator(model="gpt-4o") mock_init.assert_called_once_with(model="gpt-4o") def test_supports_openai_model_names(self): """Test that OpenAI model names are supported.""" - with patch.object(LiteLLMProvider, '__init__', return_value=None) as mock_init: + with patch.object(LiteLLMProvider, "__init__", return_value=None) as mock_init: orchestrator = AgentOrchestrator(model="gpt-4o-mini") mock_init.assert_called_once_with(model="gpt-4o-mini") @@ -40,7 +40,7 @@ def test_supports_openai_model_names(self): def test_supports_anthropic_model_names(self): """Test that Anthropic model names are supported.""" - with patch.object(LiteLLMProvider, '__init__', return_value=None) as mock_init: + with patch.object(LiteLLMProvider, "__init__", return_value=None) as mock_init: orchestrator = AgentOrchestrator(model="claude-3-haiku-20240307") mock_init.assert_called_once_with(model="claude-3-haiku-20240307") @@ -50,7 +50,7 @@ def test_skips_auto_creation_when_llm_passed(self): """Test that auto-creation is skipped when llm is explicitly passed.""" mock_llm = Mock(spec=LLMProvider) - with patch.object(LiteLLMProvider, '__init__', return_value=None) as mock_init: + with patch.object(LiteLLMProvider, "__init__", return_value=None) as mock_init: orchestrator = AgentOrchestrator(llm=mock_llm) mock_init.assert_not_called() @@ -58,7 +58,7 @@ def test_skips_auto_creation_when_llm_passed(self): def test_model_attribute_stored_correctly(self): """Test that _model attribute is stored correctly.""" - with patch.object(LiteLLMProvider, '__init__', return_value=None): + with patch.object(LiteLLMProvider, "__init__", return_value=None): orchestrator = AgentOrchestrator(model="gemini/gemini-1.5-flash") assert orchestrator._model == "gemini/gemini-1.5-flash" @@ -78,5 +78,5 @@ def test_llm_implements_llm_provider_interface(self): orchestrator = AgentOrchestrator() assert isinstance(orchestrator._llm, LLMProvider) - assert hasattr(orchestrator._llm, 'complete') - assert hasattr(orchestrator._llm, 'complete_with_tools') + assert hasattr(orchestrator._llm, "complete") + assert hasattr(orchestrator._llm, "complete_with_tools") diff --git a/core/tests/test_plan.py b/core/tests/test_plan.py index 158eab1a2c..a7f7eff587 100644 --- a/core/tests/test_plan.py +++ b/core/tests/test_plan.py @@ -1,16 +1,17 @@ """Tests for plan.py - Plan enums and Pydantic models.""" import json + import pytest from framework.graph.plan import ( + ActionSpec, ActionType, - StepStatus, ApprovalDecision, - JudgmentAction, ExecutionStatus, - ActionSpec, - PlanStep, + JudgmentAction, Plan, + PlanStep, + StepStatus, ) diff --git a/core/tests/test_run.py b/core/tests/test_run.py index aff99ca3f6..987b7ef246 100644 --- a/core/tests/test_run.py +++ b/core/tests/test_run.py @@ -2,8 +2,10 @@ Test the run module. """ from datetime import datetime -from framework.schemas.run import RunMetrics, Run, RunStatus, RunSummary -from framework.schemas.decision import Decision, Outcome, Option + +from framework.schemas.decision import Decision, Option, Outcome +from framework.schemas.run import Run, RunMetrics, RunStatus, RunSummary + class TestRuntimeMetrics: """Test the RunMetrics class.""" @@ -14,7 +16,7 @@ def test_success_rate(self): failed_decisions=2, ) assert metrics.success_rate == 0.8 - + def test_success_rate_zero_decisions(self): metrics = RunMetrics( total_decisions=0, @@ -87,7 +89,7 @@ def test_record_outcome(self): assert run.metrics.failed_decisions == 0 assert run.metrics.total_tokens == 10 assert run.metrics.total_latency_ms == 100 - + def test_add_problem(self): run = Run( id="test_run", @@ -96,15 +98,15 @@ def test_add_problem(self): completed_at=datetime.now(), ) problem_id = run.add_problem( - "Test problem", - "Test problem description", - "test_decision", - "Test root cause", + "Test problem", + "Test problem description", + "test_decision", + "Test root cause", "Test suggested fix", ) - + assert problem_id == f"prob_{len(run.problems) - 1}" - + problem = run.problems[0] assert problem.id == f"prob_{len(run.problems) - 1}" assert problem.severity == "Test problem" @@ -112,7 +114,7 @@ def test_add_problem(self): assert problem.decision_id == "test_decision" assert problem.root_cause == "Test root cause" assert problem.suggested_fix == "Test suggested fix" - + def test_complete(self): run = Run( id="test_run", @@ -134,9 +136,9 @@ def test_from_run_basic(self): completed_at=datetime.now(), ) run.complete(RunStatus.COMPLETED, "Test narrative") - + summary = RunSummary.from_run(run) - + assert summary.run_id == "test_run" assert summary.goal_id == "test_goal" assert summary.status == RunStatus.COMPLETED @@ -144,7 +146,7 @@ def test_from_run_basic(self): assert summary.success_rate == 0.0 assert summary.problem_count == 0 assert summary.narrative == "Test narrative" - + def test_from_run_with_decisions(self): run = Run( id="test_run", @@ -152,7 +154,7 @@ def test_from_run_with_decisions(self): started_at=datetime.now(), completed_at=datetime.now(), ) - + successful_decision = Decision( id="decision_1", timestamp=datetime.now(), @@ -173,7 +175,7 @@ def test_from_run_with_decisions(self): latency_ms=100, summary="Successfully greeted user", ) - + failed_decision = Decision( id="decision_2", timestamp=datetime.now(), @@ -194,21 +196,21 @@ def test_from_run_with_decisions(self): tokens_used=5, latency_ms=50, ) - + run.add_decision(successful_decision) run.record_outcome("decision_1", successful_outcome) run.add_decision(failed_decision) run.record_outcome("decision_2", failed_outcome) run.complete(RunStatus.COMPLETED, "Test narrative") - + summary = RunSummary.from_run(run) - + assert summary.decision_count == 2 assert summary.success_rate == 0.5 assert len(summary.key_decisions) == 1 assert len(summary.successes) == 1 assert summary.successes[0] == "Successfully greeted user" - + def test_from_run_with_problems(self): run = Run( id="test_run", @@ -216,7 +218,7 @@ def test_from_run_with_problems(self): started_at=datetime.now(), completed_at=datetime.now(), ) - + run.add_problem( severity="critical", description="API timeout", @@ -224,7 +226,7 @@ def test_from_run_with_problems(self): root_cause="Network issue", suggested_fix="Add retry logic", ) - + run.add_problem( severity="warning", description="High latency", @@ -232,13 +234,13 @@ def test_from_run_with_problems(self): root_cause="Large payload", suggested_fix="Optimize data size", ) - + run.complete(RunStatus.COMPLETED, "Test narrative") - + summary = RunSummary.from_run(run) - + assert summary.problem_count == 2 assert len(summary.critical_problems) == 1 assert len(summary.warnings) == 1 assert summary.critical_problems[0] == "API timeout" - assert summary.warnings[0] == "High latency" \ No newline at end of file + assert summary.warnings[0] == "High latency" diff --git a/core/tests/test_runtime.py b/core/tests/test_runtime.py index 32f18fb19c..2ee514d7f3 100644 --- a/core/tests/test_runtime.py +++ b/core/tests/test_runtime.py @@ -1,8 +1,9 @@ """Tests for the Runtime class - the agent's interface to record decisions.""" -import pytest from pathlib import Path +import pytest + from framework import Runtime from framework.schemas.decision import DecisionType diff --git a/core/tests/test_testing_framework.py b/core/tests/test_testing_framework.py index ec1890e461..78a06a940d 100644 --- a/core/tests/test_testing_framework.py +++ b/core/tests/test_testing_framework.py @@ -9,20 +9,19 @@ import pytest +from framework.testing.categorizer import ErrorCategorizer +from framework.testing.debug_tool import DebugTool from framework.testing.test_case import ( + ApprovalStatus, Test, TestType, - ApprovalStatus, ) from framework.testing.test_result import ( + ErrorCategory, TestResult, TestSuiteResult, - ErrorCategory, ) from framework.testing.test_storage import TestStorage -from framework.testing.categorizer import ErrorCategorizer -from framework.testing.debug_tool import DebugTool - # ============================================================================ # Test Schema Tests diff --git a/core/verify_mcp.py b/core/verify_mcp.py index 1704d85ef0..41e8159005 100644 --- a/core/verify_mcp.py +++ b/core/verify_mcp.py @@ -12,11 +12,11 @@ class Colors: - GREEN = '\033[0;32m' - YELLOW = '\033[1;33m' - RED = '\033[0;31m' - BLUE = '\033[0;34m' - NC = '\033[0m' + GREEN = "\033[0;32m" + YELLOW = "\033[1;33m" + RED = "\033[0;31m" + BLUE = "\033[0;34m" + NC = "\033[0m" def check(description: str) -> bool: diff --git a/tools/tests/tools/test_security.py b/tools/tests/tools/test_security.py index 242a6511b0..2ec3a6496d 100644 --- a/tools/tests/tools/test_security.py +++ b/tools/tests/tools/test_security.py @@ -31,7 +31,7 @@ def test_creates_session_directory(self, ids): """Session directory is created if it doesn't exist.""" from aden_tools.tools.file_system_toolkits.security import get_secure_path - result = get_secure_path("file.txt", **ids) + get_secure_path("file.txt", **ids) session_dir = self.workspaces_dir / "test-workspace" / "test-agent" / "test-session" assert session_dir.exists()