diff --git a/VERSION b/VERSION index 17a79bad4..8725d4cd5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -26.2.0.dev0 +26.2.0.dev2 diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 916b3bd8b..94fdd7172 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -112,6 +112,11 @@ New features Improvements ^^^^^^^^^^^^ +* **Support for native Agent Spec CatchExceptionNode:** + + WayFlow flows using the CatchExceptionStep now convert to the native Agent Spec + CatchExceptionNode when compatible (i.e., when catching all exceptions and no data flow edge reads the exception name output). + * **Added ``sensitive_headers`` in components that perform remote calls:** ``ApiCallStep``, ``RemoteTool``, and MCP ``RemoteBaseTransport`` now have a new attribute ``sensitive_headers``. diff --git a/wayflowcore/constraints/constraints_dev.txt b/wayflowcore/constraints/constraints_dev.txt index 8863fb5cc..2c04c11f0 100644 --- a/wayflowcore/constraints/constraints_dev.txt +++ b/wayflowcore/constraints/constraints_dev.txt @@ -8,6 +8,6 @@ PyYAML==6.0.3 pydantic==2.12.4 httpx==0.28.1 mcp==1.24.0 -pyagentspec==26.2.0.dev0 # Main branch of pyagentspec +pyagentspec==26.2.0.dev1 # Main branch of pyagentspec opentelemetry-api==1.36.0 opentelemetry-sdk==1.36.0 diff --git a/wayflowcore/setup.py b/wayflowcore/setup.py index 75dcfb0f7..4b03eaa94 100644 --- a/wayflowcore/setup.py +++ b/wayflowcore/setup.py @@ -52,7 +52,7 @@ def read(file_name): packages=find_packages("src"), python_requires=">=3.10,<3.15", install_requires=[ - "pyagentspec>=26.1.0", + "pyagentspec>=26.2.0.dev0", "httpx>0.28.0,<1.0.0", # warning but no vulnerabilities "numpy>=1.24.3,<3.0.0", "pandas>=2.0.3,<3.0.0", diff --git a/wayflowcore/src/wayflowcore/datastore/inmemory.py b/wayflowcore/src/wayflowcore/datastore/inmemory.py index 2d0410b43..c6995c57b 100644 --- a/wayflowcore/src/wayflowcore/datastore/inmemory.py +++ b/wayflowcore/src/wayflowcore/datastore/inmemory.py @@ -6,12 
+6,10 @@ import warnings from logging import getLogger -from typing import Any, Dict, List, Optional, Union, cast, overload - -import numpy as np -import pandas as pd +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast, overload from wayflowcore._metadata import MetadataType +from wayflowcore._utils.lazy_loader import LazyLoader from wayflowcore.datastore._datatable import Datatable from wayflowcore.datastore._utils import ( check_collection_name, @@ -35,6 +33,16 @@ from wayflowcore.serialization.context import DeserializationContext, SerializationContext from wayflowcore.serialization.serializer import serialize_to_dict +if TYPE_CHECKING: + # Important: do not move these imports out of the TYPE_CHECKING + # Otherwise, importing the module when they are not installed would lead to an import error. + import numpy as np + import pandas as pd +else: + np = LazyLoader("numpy") + pd = LazyLoader("pandas") + + logger = getLogger(__name__) _INMEMORY_USER_WARNING = "InMemoryDatastore is for DEVELOPMENT and PROOF-OF-CONCEPT ONLY!" 
diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py index c313590e9..4b386db2f 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py @@ -46,6 +46,9 @@ from pyagentspec.flows.nodes.agentnode import AgentNode as AgentSpecAgentNode from pyagentspec.flows.nodes.apinode import ApiNode as AgentSpecApiNode from pyagentspec.flows.nodes.branchingnode import BranchingNode as AgentSpecBranchingNode +from pyagentspec.flows.nodes.catchexceptionnode import ( + CatchExceptionNode as AgentSpecCatchExceptionNode, +) from pyagentspec.flows.nodes.endnode import EndNode as AgentSpecEndNode from pyagentspec.flows.nodes.flownode import FlowNode as AgentSpecFlowNode from pyagentspec.flows.nodes.llmnode import LlmNode as AgentSpecLlmNode @@ -379,6 +382,7 @@ from wayflowcore.outputparser import PythonToolOutputParser as RuntimePythonToolOutputParser from wayflowcore.outputparser import RegexOutputParser as RuntimeRegexOutputParser from wayflowcore.outputparser import RegexPattern as RuntimeRegexPattern +from wayflowcore.property import AnyProperty as RuntimeAnyProperty from wayflowcore.property import JsonSchemaParam from wayflowcore.property import ListProperty as RuntimeListProperty from wayflowcore.property import Property as RuntimeProperty @@ -1105,10 +1109,18 @@ def _find_property(properties: List[AgentSpecProperty], name: str) -> AgentSpecP conversion_context.convert(edge, tool_registry, converted_components) for edge in data_flow_connections or [] ] - control_flow_edges: List[RuntimeControlFlowEdge] = [ - conversion_context.convert(edge, tool_registry, converted_components) - for edge in agentspec_component.control_flow_connections - ] + control_flow_edges: List[RuntimeControlFlowEdge] = [] + for edge in 
agentspec_component.control_flow_connections: + if ( + isinstance(edge.from_node, AgentSpecCatchExceptionNode) + and edge.from_branch == AgentSpecCatchExceptionNode.CAUGHT_EXCEPTION_BRANCH + ): + # we need to rename the branch used in the CatchExceptionNode + edge.from_branch = RuntimeCatchExceptionStep.DEFAULT_EXCEPTION_BRANCH + control_flow_edges.append( + conversion_context.convert(edge, tool_registry, converted_components) + ) + for step in steps.values(): for branch in step.get_branches(): edge_exists = any( @@ -1482,6 +1494,55 @@ def _find_property(properties: List[AgentSpecProperty], name: str) -> AgentSpecP except_on=agentspec_component.except_on, **self._get_rt_nodes_arguments(agentspec_component, metadata_info), ) + elif isinstance(agentspec_component, AgentSpecCatchExceptionNode): + # Standard CatchExceptionNode from Agent Spec does not expose catch_all_exceptions + # and except_on fields + # Also, the output of the catch exception node might be renamed, + # so we have to use output mapping when needed + rt_nodes_arguments = self._get_node_arguments(agentspec_component, metadata_info) + if agentspec_component.outputs: + subflow_outputs_titles = { + p.title for p in agentspec_component.subflow.outputs or [] + } + caught_exception_property = next( + ( + p + for p in agentspec_component.outputs + if p.title not in subflow_outputs_titles + ), + None, + ) + if caught_exception_property is None: + raise ValueError( + f"Internal error: Agent Spec CatchExceptionNode '{agentspec_component.name}' " + "is missing a output for the exception info. Make sure the pyagentspec " + "component is successfully validated." + ) + if ( + caught_exception_property.title + != RuntimeCatchExceptionStep.EXCEPTION_PAYLOAD_OUTPUT_NAME + ): + # there is no output mapping by default. 
We add one to handle renaming + rt_nodes_arguments["output_mapping"] = { + RuntimeCatchExceptionStep.EXCEPTION_PAYLOAD_OUTPUT_NAME: caught_exception_property.title + } + + # we need to add a default output property for the exception name + rt_nodes_arguments["output_descriptors"].append( + RuntimeAnyProperty( + name=RuntimeCatchExceptionStep.EXCEPTION_NAME_OUTPUT_NAME, + default_value="", + ) + ) + + return RuntimeCatchExceptionStep( + flow=conversion_context.convert( + agentspec_component.subflow, tool_registry, converted_components + ), + catch_all_exceptions=True, + except_on=None, + **rt_nodes_arguments, + ) elif isinstance(agentspec_component, AgentSpecPluginRegexNode): regex_pattern = self._regex_pattern_to_runtime(agentspec_component.regex_pattern) if not ( diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py index 31aa8f8c8..1aece6417 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py @@ -5,7 +5,7 @@ # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
import uuid -from dataclasses import MISSING, fields, is_dataclass +from dataclasses import MISSING, fields, is_dataclass, replace from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast from warnings import warn @@ -38,6 +38,7 @@ from pyagentspec.flows.nodes import AgentNode as AgentSpecAgentNode from pyagentspec.flows.nodes import ApiNode as AgentSpecApiNode from pyagentspec.flows.nodes import BranchingNode as AgentSpecBranchingNode +from pyagentspec.flows.nodes import CatchExceptionNode as AgentSpecCatchExceptionNode from pyagentspec.flows.nodes import EndNode as AgentSpecEndNode from pyagentspec.flows.nodes import FlowNode as AgentSpecFlowNode from pyagentspec.flows.nodes import InputMessageNode as AgentSpecInputMessageNode @@ -2142,14 +2143,25 @@ def _flow_convert_to_agentspec( runtime_flow: RuntimeFlow, referenced_objects: Optional[Dict[str, Any]] = None, ) -> AgentSpecFlow: + runtime_control_data_edges = runtime_flow.control_flow_edges + runtime_flow_data_edges = runtime_flow.data_flow_edges + runtime_flow_output_descriptors = runtime_flow.output_descriptors + steps_to_process_separately: Dict[str, AgentSpecNode] = {} + steps_to_process_together: Dict[str, AgentSpecNode] = {} + for node_name, runtime_node in runtime_flow.steps.items(): + if isinstance(runtime_node, (RuntimeCompleteStep, RuntimeCatchExceptionStep)): + # Those nodes need to be converted with some specific logic + steps_to_process_separately[node_name] = runtime_node + else: + steps_to_process_together[node_name] = runtime_node + # Convert the nodes that can be converted without additional logic agentspec_nodes: Dict[str, AgentSpecNode] = { runtime_node.id: cast( AgentSpecNode, conversion_context.convert(runtime_node, referenced_objects), ) - for node_name, runtime_node in runtime_flow.steps.items() - if not issubclass(type(runtime_node), RuntimeCompleteStep) + for runtime_node in steps_to_process_together.values() } start_node = next( @@ -2160,18 +2172,105 @@ def 
# End nodes are created separately because we need to infer the outputs for them.
# We assume that all the steps are going to expose all the outputs of the flow
# NOTE(review): the native CatchExceptionStep branch below rewrites the flow output
# descriptors, so an end node built before that branch runs still sees the un-renamed
# names -- confirm the iteration order cannot produce such a mismatch.
for node_name, runtime_node in steps_to_process_separately.items():
    step_args = _get_step_args(runtime_node)
    if issubclass(type(runtime_node), RuntimeCompleteStep):
        agentspec_nodes[runtime_node.id] = AgentSpecEndNode(
            name=node_name,
            branch_name=runtime_node.branch_name or node_name,  # type: ignore
            outputs=[
                _runtime_property_to_pyagentspec_property(property_)
                for property_ in runtime_flow_output_descriptors
            ],
            metadata=_create_agentspec_metadata_from_runtime_component(runtime_node),
            id=runtime_node.id,
        )
    elif isinstance(
        runtime_node, RuntimeCatchExceptionStep
    ) and _flow_with_catchexceptionstep_requires_plugin(runtime_node, runtime_flow):
        # Not expressible with the native node: keep every WayFlow-specific field.
        agentspec_nodes[runtime_node.id] = AgentSpecPluginCatchExceptionNode(
            **step_args,
            flow=cast(
                AgentSpecFlow,
                conversion_context.convert(runtime_node.flow, referenced_objects),
            ),
            catch_all_exceptions=runtime_node.catch_all_exceptions,
            except_on=runtime_node.except_on,
            input_mapping=runtime_node.input_mapping,
            output_mapping=runtime_node.output_mapping,
        )
    elif isinstance(runtime_node, RuntimeCatchExceptionStep):
        subflow = cast(
            AgentSpecFlow,
            conversion_context.convert(runtime_node.flow, referenced_objects),
        )
        # need to add _type_default_value when needed
        # The flow is compatible with the native CatchExceptionNode. Still, we need
        # to drop the extra fields from the WayFlow CatchExceptionStep
        # and handle any renaming so that this can be used in Agent Spec.
        del step_args["inputs"]
        del step_args["outputs"]

        mapped_payload_output_name = runtime_node.output_mapping.get(
            RuntimeCatchExceptionStep.EXCEPTION_PAYLOAD_OUTPUT_NAME,
            RuntimeCatchExceptionStep.EXCEPTION_PAYLOAD_OUTPUT_NAME,
        )
        mapped_exception_output_name = runtime_node.output_mapping.get(
            RuntimeCatchExceptionStep.EXCEPTION_NAME_OUTPUT_NAME,
            RuntimeCatchExceptionStep.EXCEPTION_NAME_OUTPUT_NAME,
        )

        # The three collections below were bound directly from ``runtime_flow`` attributes.
        # They are rebuilt and rebound instead of being updated by index so that, if they
        # are the flow's own lists, exporting never alters the runtime flow.

        # update control flow edges for branch renaming
        updated_control_flow_edges = []
        for control_flow_edge in runtime_control_data_edges:
            if (
                control_flow_edge.source_step.name == runtime_node.name
                and control_flow_edge.source_branch
                == RuntimeCatchExceptionStep.DEFAULT_EXCEPTION_BRANCH
            ):
                control_flow_edge = replace(control_flow_edge)
                # controlled setattr to inject the branch name that will be accepted by Agent Spec
                object.__setattr__(
                    control_flow_edge,
                    "source_branch",
                    AgentSpecCatchExceptionNode.CAUGHT_EXCEPTION_BRANCH,
                )
            updated_control_flow_edges.append(control_flow_edge)
        runtime_control_data_edges = updated_control_flow_edges

        # update data flow edges to revert mapping and rename if needed
        # (done here to keep code logic local)
        updated_data_flow_edges = []
        for data_flow_edge in runtime_flow_data_edges:
            # Only edges leaving *this* step are renamed: another step may expose an
            # output with the same name, and that edge must keep its source output.
            if (
                data_flow_edge.source_step is runtime_node
                and data_flow_edge.source_output == mapped_payload_output_name
            ):
                data_flow_edge = replace(data_flow_edge)
                # controlled setattr to inject the output name that will be accepted by Agent Spec
                object.__setattr__(
                    data_flow_edge,
                    "source_output",
                    AgentSpecCatchExceptionNode.DEFAULT_EXCEPTION_INFO_VALUE,
                )
            updated_data_flow_edges.append(data_flow_edge)
        runtime_flow_data_edges = updated_data_flow_edges

        # update flow output descriptors (remove the exception name output, rename the
        # payload output). Filtering while rebuilding also drops the exception name when
        # it is the first descriptor; a truthiness test on its index would skip index 0.
        updated_output_descriptors = []
        for property_ in runtime_flow_output_descriptors:
            if property_.name == mapped_exception_output_name:
                continue
            if property_.name == mapped_payload_output_name:
                property_ = replace(property_)
                # controlled setattr to inject the output name that will be accepted by Agent Spec
                object.__setattr__(
                    property_,
                    "name",
                    AgentSpecCatchExceptionNode.DEFAULT_EXCEPTION_INFO_VALUE,
                )
            updated_output_descriptors.append(property_)
        runtime_flow_output_descriptors = updated_output_descriptors

        agentspec_nodes[runtime_node.id] = AgentSpecCatchExceptionNode(
            **step_args,
            subflow=subflow,
        )
control_flow_edge.destination_step is None and not isinstance( control_flow_edge.source_step, RuntimeCompleteStep ): @@ -2451,7 +2550,7 @@ def _add_collected(data_flow_edge: AgentSpecDataFlowEdge) -> None: # Or we again check the renamed names when converting from flow outputs # Below is the latter option flow_outputs: List[AgentSpecProperty] = [] - for flow_output in runtime_flow.output_descriptors or []: + for flow_output in runtime_flow_output_descriptors or []: if flow_output.name in renamed_outputs: original_property = _runtime_property_to_pyagentspec_property(flow_output) json_schema = original_property.json_schema @@ -2585,29 +2684,8 @@ def _step_convert_to_agentspec( runtime_step: RuntimeStep, referenced_objects: Optional[Dict[str, Any]] = None, ) -> AgentSpecNode: - # The runtime steps do not contain the name, but it is mandatory to buildAgent Spec Nodes - # We give a temp name, and we assume that who knows the node's name will overwrite it - node_name = runtime_step.name or runtime_step.__metadata_info__.get("name", "_temp_name_") - node_description = runtime_step.__metadata_info__.get("description", "") + step_args = _get_step_args(runtime_step) runtime_step_type = type(runtime_step) - metadata = _create_agentspec_metadata_from_runtime_component(runtime_step) - inputs = [ - _runtime_property_to_pyagentspec_property(output) - for output in runtime_step.input_descriptors or [] - ] - outputs = [ - _runtime_property_to_pyagentspec_property(output) - for output in runtime_step.output_descriptors or [] - ] - node_id = runtime_step.id - step_args = dict( - name=node_name, - description=node_description, - inputs=inputs, - outputs=outputs, - metadata=metadata, - id=node_id, - ) # We compare the type directly instead of using isinstance in order to avoid # undesired, multiple node type matches due to class inheritance if runtime_step_type is RuntimePromptExecutionStep: @@ -3040,17 +3118,9 @@ def _step_convert_to_agentspec( 
def _flow_with_catchexceptionstep_requires_plugin(
    step: RuntimeCatchExceptionStep, flow: RuntimeFlow
) -> bool:
    """Tell whether ``step`` has to be exported as the plugin CatchExceptionNode.

    Parameters
    ----------
    step:
        The CatchExceptionStep being exported.
    flow:
        The flow that contains ``step``; only its data flow edges are inspected.

    Returns
    -------
    bool
        ``True`` when the step catches specific exceptions only, or when a data flow
        edge of ``flow`` reads the exception name output of ``step``. ``False`` otherwise.
    """
    # 1. If the step catches specific exceptions we need to use the plugin
    if not step.catch_all_exceptions:
        return True
    # 2. If the main flow uses the EXCEPTION_NAME_OUTPUT_NAME data
    # we need to use the plugin (missing in the Agent Spec node)
    non_mapped_output_name = step.EXCEPTION_NAME_OUTPUT_NAME
    mapped_output_name = step.output_mapping.get(non_mapped_output_name, non_mapped_output_name)
    return any(
        edge.source_step is step and edge.source_output == mapped_output_name
        for edge in flow.data_flow_edges
    )


def _get_step_args(runtime_step: RuntimeStep) -> dict[str, Any]:
    """Build the keyword arguments common to the Agent Spec nodes created from a step.

    Returns a dictionary with the keys ``name``, ``description``, ``inputs``, ``outputs``,
    ``metadata`` and ``id``.
    """
    # The runtime steps do not contain the name, but it is mandatory to build Agent Spec Nodes
    # We give a temp name, and we assume that who knows the node's name will overwrite it
    node_name = runtime_step.name or runtime_step.__metadata_info__.get("name", "_temp_name_")
    return {
        "name": node_name,
        "description": runtime_step.__metadata_info__.get("description", ""),
        "inputs": [
            _runtime_property_to_pyagentspec_property(property_)
            for property_ in runtime_step.input_descriptors or []
        ],
        "outputs": [
            _runtime_property_to_pyagentspec_property(property_)
            for property_ in runtime_step.output_descriptors or []
        ],
        "metadata": _create_agentspec_metadata_from_runtime_component(runtime_step),
        "id": runtime_step.id,
    }
def _validate_output_descriptors_in_subflow(self, flow: "Flow") -> None:
    """Validate the outputs of this step against the outputs of the wrapped ``flow``.

    Three rules are checked, in this order:

    1. the subflow must not expose an output named like one of the two exception outputs;
    2. the step outputs must be exactly the subflow outputs plus the two exception
       outputs (under their mapped names when ``output_mapping`` renames them);
    3. every subflow output must define a default value.

    Parameters
    ----------
    flow:
        The subflow wrapped by this step.

    Raises
    ------
    ValueError
        If one of the rules above is violated.
    """
    subflow_outputs = flow.output_descriptors or []
    current_step_outputs = self.output_descriptors or []  # may be renamed

    step_output_titles = {p.name for p in current_step_outputs}
    subflow_output_titles = {p.name for p in subflow_outputs}

    # 1. Subflow outputs must not conflict with the CatchExceptionStep outputs
    if (
        self.EXCEPTION_NAME_OUTPUT_NAME in subflow_output_titles
        or self.EXCEPTION_PAYLOAD_OUTPUT_NAME in subflow_output_titles
    ):
        raise ValueError(
            f"Found reserved descriptor names in subflow output descriptors '{subflow_output_titles}'. "
            f"Names {self.EXCEPTION_NAME_OUTPUT_NAME} and {self.EXCEPTION_PAYLOAD_OUTPUT_NAME} are "
            "reserved names of the CatchExceptionStep and should not be used as outputs of the subflow."
        )

    # 2. when provided by the user, step outputs should match subflow outputs
    expected_titles = {
        *subflow_output_titles,
        self.output_mapping.get(
            self.EXCEPTION_NAME_OUTPUT_NAME, self.EXCEPTION_NAME_OUTPUT_NAME
        ),
        self.output_mapping.get(
            self.EXCEPTION_PAYLOAD_OUTPUT_NAME, self.EXCEPTION_PAYLOAD_OUTPUT_NAME
        ),
    }
    if step_output_titles != expected_titles:
        # The comparison is made against the subflow outputs *plus* the exception outputs,
        # so the expected set is reported too: "Provided" and "Subflow" alone can be
        # identical lists and would not explain the failure.
        raise ValueError(
            f"CatchExceptionStep '{self.name}': provided outputs must have the same names as subflow outputs. "
            f"Provided: {sorted(step_output_titles)}, Subflow: {sorted(subflow_output_titles)}. "
            f"Expected (subflow outputs plus the exception outputs): {sorted(expected_titles)}"
        )

    # 3. Subflow output descriptors must have a default value
    # (presumably used when an exception interrupts the subflow -- TODO confirm)
    for property_ in subflow_outputs:
        if property_.default_value is _empty_default:
            raise ValueError(
                f"CatchExceptionStep '{self.name}': subflow output '{property_.name}' "
                "must have a default value when the subflow is used in a CatchExceptionStep."
            )
RuntimeStringProperty +from wayflowcore.steps import CatchExceptionStep, OutputMessageStep, ToolExecutionStep +from wayflowcore.steps.step import Step +from wayflowcore.tools import tool + + +@pytest.fixture +def spec_flow_with_catchexception() -> SpecFlow: + inp = IntegerProperty(title="x") + subflow_output = StringProperty(title="tool_output", default="") + flaky_tool = SpecServerTool( + name="flaky_tool", + description="Raises for negative inputs", + inputs=[inp], + outputs=[subflow_output], + ) + sub_start = StartNode(name="sub_start", inputs=[inp]) + tool_node = ToolNode(name="flaky_node", tool=flaky_tool) + sub_end = EndNode(name="sub_end", outputs=[subflow_output]) + + subflow = SpecFlow( + name="flaky_subflow", + start_node=sub_start, + nodes=[sub_start, tool_node, sub_end], + control_flow_connections=[ + SpecControlFlowEdge(name="s2t", from_node=sub_start, to_node=tool_node), + SpecControlFlowEdge(name="t2e", from_node=tool_node, to_node=sub_end), + ], + data_flow_connections=[ + SpecDataFlowEdge( + name="in", + source_node=sub_start, + source_output=inp.title, + destination_node=tool_node, + destination_input=inp.title, + ), + SpecDataFlowEdge( + name="out", + source_node=tool_node, + source_output=subflow_output.title, + destination_node=sub_end, + destination_input=subflow_output.title, + ), + ], + inputs=[inp], + outputs=[subflow_output], + ) + + catch = CatchExceptionNode(name="catch_step", subflow=subflow) + + start = StartNode(name="start", inputs=[inp]) + error_info = UnionProperty( + title="error_info", + any_of=[StringProperty(title="error_info"), NullProperty(title="error_info")], + default=None, + ) + success_end = EndNode(name="success_end", outputs=[subflow_output, error_info]) + error_end = EndNode(name="error_end", outputs=[subflow_output, error_info], branch_name="ERROR") + + return SpecFlow( + name="outer", + start_node=start, + nodes=[start, catch, success_end, error_end], + control_flow_connections=[ + SpecControlFlowEdge(name="s2c", 
@pytest.fixture
def runtime_flow_catching_all_exceptions(flaky_wayflow_subflow: RuntimeFlow) -> RuntimeFlow:
    """Flow whose first step wraps the flaky subflow in a catch-all ``CatchExceptionStep``.

    The next branch and the default exception branch each lead to their own output
    message step; the exception payload is wired into the failure message as
    ``tool_error``.
    """
    # ``flaky_wayflow_subflow`` is the WayFlow (runtime) flow fixture, not an Agent Spec flow
    tool_node_with_catch = CatchExceptionStep(
        name="catch_step", flow=flaky_wayflow_subflow, catch_all_exceptions=True
    )
    tool_success_output_step = OutputMessageStep(
        name="success_output_step",
        message_template="Tool succeeded without exceptions.",
    )
    tool_failure_output_step = OutputMessageStep(
        name="failure_output_step",
        message_template="Tool failed with ValueError: {{tool_error}}",
    )
    return RuntimeFlow(
        name="flow_catch_all_exceptions",
        begin_step=tool_node_with_catch,
        control_flow_edges=[
            RuntimeControlFlowEdge(
                source_step=tool_node_with_catch,
                destination_step=tool_success_output_step,
                source_branch=CatchExceptionStep.BRANCH_NEXT,
            ),
            RuntimeControlFlowEdge(
                source_step=tool_node_with_catch,
                destination_step=tool_failure_output_step,
                source_branch=CatchExceptionStep.DEFAULT_EXCEPTION_BRANCH,
                # ^ This is the branch taken when any exception is raised
            ),
            RuntimeControlFlowEdge(source_step=tool_success_output_step, destination_step=None),
            RuntimeControlFlowEdge(source_step=tool_failure_output_step, destination_step=None),
        ],
        data_flow_edges=[
            RuntimeDataFlowEdge(
                tool_node_with_catch,
                CatchExceptionStep.EXCEPTION_PAYLOAD_OUTPUT_NAME,
                tool_failure_output_step,
                "tool_error",
            ),
        ],
    )
there should be two output message nodes + node_titles = {n.name for n in spec_flow.nodes} + assert "success_output_step" in node_titles + assert "failure_output_step" in node_titles + + # 2. the flow should use the native CatchExceptionNode + catch_node = next((n for n in spec_flow.nodes if n.name == "catch_step"), None) + assert (catch_node is not None) and isinstance(catch_node, CatchExceptionNode) + + # 3. check for correct control flow edge connection + control_flow_edges = { + (c.from_node.name, c.to_node.name, c.from_branch) + for c in spec_flow.control_flow_connections + } + assert ( + "catch_step", + "success_output_step", + Node.DEFAULT_NEXT_BRANCH, + ) in control_flow_edges or ("catch_step", "success_output_step", None) in control_flow_edges + assert ( + "catch_step", + "failure_output_step", + CatchExceptionNode.CAUGHT_EXCEPTION_BRANCH, + ) in control_flow_edges + + # 4. check for correct data flow edge connection + data_flow_edges = { + (d.source_node.name, d.source_output, d.destination_node.name, d.destination_input) + for d in spec_flow.data_flow_connections + } + assert ( + "catch_step", + CatchExceptionNode.DEFAULT_EXCEPTION_INFO_VALUE, + "failure_output_step", + "tool_error", + ) in data_flow_edges + + # 5. 
@pytest.fixture
def runtime_flow_catching_specific_exceptions(flaky_wayflow_subflow: RuntimeFlow) -> RuntimeFlow:
    """Flow that routes a ``ValueError`` raised by the flaky subflow to a dedicated branch.

    The exception name output of the catch step is wired into the failure message as
    ``tool_error``.
    """
    catch_step = CatchExceptionStep(
        name="catch_step",
        flow=flaky_wayflow_subflow,
        except_on={ValueError.__name__: "value_error_branch"},
    )
    on_success = OutputMessageStep(
        name="success_output_step",
        message_template="Tool succeeded without exceptions.",
    )
    on_failure = OutputMessageStep(
        name="failure_output_step",
        message_template="Tool failed with ValueError: {{tool_error}}",
    )

    # Two branches leave the catch step; both message steps then end the flow.
    control_edges = [
        RuntimeControlFlowEdge(
            source_step=catch_step,
            destination_step=on_success,
            source_branch=CatchExceptionStep.BRANCH_NEXT,
        ),
        RuntimeControlFlowEdge(
            source_step=catch_step,
            destination_step=on_failure,
            source_branch="value_error_branch",
        ),
        RuntimeControlFlowEdge(source_step=on_success, destination_step=None),
        RuntimeControlFlowEdge(source_step=on_failure, destination_step=None),
    ]
    data_edges = [
        RuntimeDataFlowEdge(
            catch_step,
            CatchExceptionStep.EXCEPTION_NAME_OUTPUT_NAME,
            on_failure,
            "tool_error",
        ),
    ]
    return RuntimeFlow(
        name="flow_catch_value_error",
        begin_step=catch_step,
        control_flow_edges=control_edges,
        data_flow_edges=data_edges,
    )
runtime_flow_catching_specific_exceptions: RuntimeFlow, +) -> None: + flow = runtime_flow_catching_specific_exceptions + + # flaky case + conv = flow.start_conversation({"x": -1}) + status = conv.execute() + assert isinstance(status, FinishedStatus) + assert status.output_values == { + "exception_name": "ValueError", + "tool_output": "no_output", + "exception_payload_name": "x must be non-negative", + "output_message": "Tool failed with ValueError: ValueError", + } + + # non-flaky case + conv2 = flow.start_conversation({"x": 10}) + status = conv2.execute() + assert isinstance(status, FinishedStatus) + assert status.output_values == { + "exception_payload_name": "", + "tool_output": "ok", + "exception_name": "", + "output_message": "Tool succeeded without exceptions.", + } + + +def test_wayflow_flow_catching_specific_exceptions_properly_converts_to_agentspec( + runtime_flow_catching_specific_exceptions: RuntimeFlow, +) -> None: + spec_flow: SpecFlow = AgentSpecExporter().to_component( + runtime_flow_catching_specific_exceptions + ) + + # 1. there should be two output message nodes + node_titles = {n.name for n in spec_flow.nodes} + assert "success_output_step" in node_titles + assert "failure_output_step" in node_titles + + # 2. the flow should use the native CatchExceptionNode + catch_node = next((n for n in spec_flow.nodes if n.name == "catch_step"), None) + assert (catch_node is not None) and isinstance(catch_node, PluginCatchExceptionNode) + + # 3. check for correct control flow edge connection + control_flow_edges = { + (c.from_node.name, c.to_node.name, c.from_branch) + for c in spec_flow.control_flow_connections + } + assert ( + "catch_step", + "success_output_step", + Node.DEFAULT_NEXT_BRANCH, + ) in control_flow_edges or ("catch_step", "success_output_step", None) in control_flow_edges + assert ("catch_step", "failure_output_step", "value_error_branch") in control_flow_edges + + # 4. 
check for correct data flow edge connection + data_flow_edges = { + (d.source_node.name, d.source_output, d.destination_node.name, d.destination_input) + for d in spec_flow.data_flow_connections + } + assert ( + "catch_step", + PluginCatchExceptionNode.EXCEPTION_NAME_OUTPUT_NAME, + "failure_output_step", + "tool_error", + ) in data_flow_edges + + # 5. check for input/output + spec_flow_input_titles = {p.title for p in spec_flow.inputs} + spec_flow_output_titles = {p.title for p in spec_flow.outputs} + + assert spec_flow_input_titles == {"x"} + + assert PluginCatchExceptionNode.EXCEPTION_NAME_OUTPUT_NAME in spec_flow_output_titles + assert PluginCatchExceptionNode.EXCEPTION_PAYLOAD_OUTPUT_NAME in spec_flow_output_titles + + +def test_agentspec_flow_properly_converts_to_wayflow( + spec_flow_with_catchexception: RuntimeFlow, +) -> None: + runtime_flow: RuntimeFlow = AgentSpecLoader({"flaky_tool": flaky_tool}).load_component( + spec_flow_with_catchexception + ) + + # check for correct control flow edge connection + control_flow_edges = { + (c.source_step.name, c.destination_step.name, c.source_branch) + for c in runtime_flow.control_flow_edges + } + assert ("catch_step", "success_end", Step.BRANCH_NEXT) in control_flow_edges or ( + "catch_step", + "success_end", + None, + ) in control_flow_edges + assert ( + "catch_step", + "error_end", + CatchExceptionStep.DEFAULT_EXCEPTION_BRANCH, + ) in control_flow_edges + + # check for correct data flow edge connection + data_flow_edges = { + (d.source_step.name, d.source_output, d.destination_step.name, d.destination_input) + for d in runtime_flow.data_flow_edges + } + assert ( + "catch_step", + CatchExceptionNode.DEFAULT_EXCEPTION_INFO_VALUE, + "error_end", + "error_info", + ) in data_flow_edges + + # 5. 
check for input/output + spec_flow_input_titles = {p.name for p in runtime_flow.input_descriptors} + spec_flow_output_titles = {p.name for p in runtime_flow.output_descriptors} + + assert spec_flow_input_titles == {"x"} + assert "error_info" in spec_flow_output_titles