From bf2d4b078a931f8fe8a2ffd068bd3240821725ce Mon Sep 17 00:00:00 2001 From: Vinicius Brasil Date: Wed, 12 Feb 2025 14:16:56 -0600 Subject: [PATCH] Perform shallow copy of state without serialization --- src/crewai/flow/flow.py | 28 +++++----------------------- src/crewai/flow/flow_events.py | 8 +++++--- tests/flow_test.py | 32 +++++++++++++++----------------- 3 files changed, 25 insertions(+), 43 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index ebe5fcf996..f1242a2bf8 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,4 +1,5 @@ import asyncio +import copy import inspect import logging from typing import ( @@ -568,27 +569,8 @@ class StateWithId(state_type, FlowState): # type: ignore f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) - def _dump_state(self) -> Optional[Dict[str, Any]]: - """ - Dumps the current flow state as a dictionary. - - This method converts the internal state into a serializable dictionary format, - ensuring compatibility with both dictionary and Pydantic BaseModel states. - - Returns: - Optional[Dict[str, Any]]: The serialized state dictionary, or None if state is not available. - """ - if self._state is None: - return None - - if isinstance(self._state, dict): - return self._state.copy() - - if isinstance(self._state, BaseModel): - return self._state.model_dump() - - logger.warning("Unsupported flow state type for dumping.") - return None + def _copy_state(self) -> T: + return copy.deepcopy(self._state) @property def state(self) -> T: @@ -833,7 +815,7 @@ async def _execute_method( method_name=method_name, flow_name=self.__class__.__name__, params=dumped_params, - state=self._dump_state(), + state=self._copy_state(), ), ) @@ -853,7 +835,7 @@ async def _execute_method( type="method_execution_finished", method_name=method_name, flow_name=self.__class__.__name__, - state=self._dump_state(), + state=self._copy_state(), result=result, ), ) diff --git a/src/crewai/flow/flow_events.py b/src/crewai/flow/flow_events.py index f66c66d23b..c8f9e96948 100644 --- a/src/crewai/flow/flow_events.py +++ b/src/crewai/flow/flow_events.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union + +from pydantic import BaseModel @dataclass @@ -21,15 +23,15 @@ class FlowStartedEvent(Event): @dataclass class MethodExecutionStartedEvent(Event): method_name: str + state: Union[Dict[str, Any], BaseModel] params: Optional[Dict[str, Any]] = None - state: Optional[Dict[str, Any]] = None @dataclass class MethodExecutionFinishedEvent(Event): method_name: str + state: Union[Dict[str, Any], BaseModel] result: Any = None - state: Optional[Dict[str, Any]] = None @dataclass diff --git a/tests/flow_test.py b/tests/flow_test.py index 2a89230955..c416d4a7d6 100644 --- a/tests/flow_test.py +++ b/tests/flow_test.py @@ -460,21 +460,21 @@ def handle_event(_, event): if event.method_name == "prepare_flower": if isinstance(event, MethodExecutionStartedEvent): assert event.params == {} - assert event.state.get("separator") == ", " + assert event.state["separator"] == ", " elif isinstance(event, MethodExecutionFinishedEvent): assert event.result == "foo" - assert event.state.get("flower") == "roses" - assert event.state.get("separator") == ", " + assert event.state["flower"] == "roses" + assert event.state["separator"] == ", " else: assert False, "Unexpected event type for prepare_flower" elif event.method_name == "prepare_color": if isinstance(event, MethodExecutionStartedEvent): assert event.params == {} - assert event.state.get("separator") == ", " + assert event.state["separator"] == ", " elif isinstance(event, MethodExecutionFinishedEvent): assert event.result == "bar" - assert event.state.get("color") == "red" - assert event.state.get("separator") == ", " + assert event.state["color"] == "red" + assert event.state["separator"] == ", " else: assert False, "Unexpected event type for prepare_color" else: @@ -484,9 +484,9 @@ def handle_event(_, event): assert event_log[5].method_name == "write_first_sentence" assert event_log[5].params == {} assert isinstance(event_log[5].state, dict) - assert event_log[5].state.get("flower") == "roses" - assert event_log[5].state.get("color") == "red" - assert event_log[5].state.get("separator") == ", " + assert event_log[5].state["flower"] == "roses" + assert event_log[5].state["color"] == "red" + assert event_log[5].state["separator"] == ", " assert isinstance(event_log[6], MethodExecutionFinishedEvent) assert event_log[6].method_name == "write_first_sentence" @@ -496,8 +496,8 @@ def handle_event(_, event): assert event_log[7].method_name == "finish_poem" assert event_log[7].params == {"_0": "roses are red"} assert isinstance(event_log[7].state, dict) - assert event_log[7].state.get("flower") == "roses" - assert event_log[7].state.get("color") == "red" + assert event_log[7].state["flower"] == "roses" + assert event_log[7].state["color"] == "red" assert isinstance(event_log[8], MethodExecutionFinishedEvent) assert event_log[8].method_name == "finish_poem" @@ -507,8 +507,8 @@ def handle_event(_, event): assert event_log[9].method_name == "save_poem_to_database" assert event_log[9].params == {} assert isinstance(event_log[9].state, dict) - assert event_log[9].state.get("flower") == "roses" - assert event_log[9].state.get("color") == "red" + assert event_log[9].state["flower"] == "roses" + assert event_log[9].state["color"] == "red" assert isinstance(event_log[10], MethodExecutionFinishedEvent) assert event_log[10].method_name == "save_poem_to_database" @@ -561,13 +561,11 @@ def handle_event(_, event): assert isinstance(event_log[3], MethodExecutionStartedEvent) assert event_log[3].method_name == "send_welcome_message" assert event_log[3].params == {} - assert isinstance(event_log[3].state, dict) - assert event_log[3].state.get("sent") == False + assert getattr(event_log[3].state, "sent") == False assert isinstance(event_log[4], MethodExecutionFinishedEvent) assert event_log[4].method_name == "send_welcome_message" - assert isinstance(event_log[4].state, dict) - assert event_log[4].state.get("sent") == True + assert getattr(event_log[4].state, "sent") == True assert event_log[4].result == "Welcome, Anakin!" assert isinstance(event_log[5], FlowFinishedEvent)