Skip to content

Commit

Permalink
Perform shallow copy of state without serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
vinibrsl committed Feb 12, 2025
1 parent 1572bac commit bf2d4b0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 43 deletions.
28 changes: 5 additions & 23 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
import inspect
import logging
from typing import (
Expand Down Expand Up @@ -568,27 +569,8 @@ class StateWithId(state_type, FlowState): # type: ignore
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)

def _dump_state(self) -> Optional[Dict[str, Any]]:
"""
Dumps the current flow state as a dictionary.
This method converts the internal state into a serializable dictionary format,
ensuring compatibility with both dictionary and Pydantic BaseModel states.
Returns:
Optional[Dict[str, Any]]: The serialized state dictionary, or None if state is not available.
"""
if self._state is None:
return None

if isinstance(self._state, dict):
return self._state.copy()

if isinstance(self._state, BaseModel):
return self._state.model_dump()

logger.warning("Unsupported flow state type for dumping.")
return None
def _copy_state(self) -> T:
return copy.deepcopy(self._state)

@property
def state(self) -> T:
Expand Down Expand Up @@ -833,7 +815,7 @@ async def _execute_method(
method_name=method_name,
flow_name=self.__class__.__name__,
params=dumped_params,
state=self._dump_state(),
state=self._copy_state(),
),
)

Expand All @@ -853,7 +835,7 @@ async def _execute_method(
type="method_execution_finished",
method_name=method_name,
flow_name=self.__class__.__name__,
state=self._dump_state(),
state=self._copy_state(),
result=result,
),
)
Expand Down
8 changes: 5 additions & 3 deletions src/crewai/flow/flow_events.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel


@dataclass
Expand All @@ -21,15 +23,15 @@ class FlowStartedEvent(Event):
@dataclass
class MethodExecutionStartedEvent(Event):
method_name: str
state: Union[Dict[str, Any], BaseModel]
params: Optional[Dict[str, Any]] = None
state: Optional[Dict[str, Any]] = None


@dataclass
class MethodExecutionFinishedEvent(Event):
method_name: str
state: Union[Dict[str, Any], BaseModel]
result: Any = None
state: Optional[Dict[str, Any]] = None


@dataclass
Expand Down
32 changes: 15 additions & 17 deletions tests/flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,21 +460,21 @@ def handle_event(_, event):
if event.method_name == "prepare_flower":
if isinstance(event, MethodExecutionStartedEvent):
assert event.params == {}
assert event.state.get("separator") == ", "
assert event.state["separator"] == ", "
elif isinstance(event, MethodExecutionFinishedEvent):
assert event.result == "foo"
assert event.state.get("flower") == "roses"
assert event.state.get("separator") == ", "
assert event.state["flower"] == "roses"
assert event.state["separator"] == ", "
else:
assert False, "Unexpected event type for prepare_flower"
elif event.method_name == "prepare_color":
if isinstance(event, MethodExecutionStartedEvent):
assert event.params == {}
assert event.state.get("separator") == ", "
assert event.state["separator"] == ", "
elif isinstance(event, MethodExecutionFinishedEvent):
assert event.result == "bar"
assert event.state.get("color") == "red"
assert event.state.get("separator") == ", "
assert event.state["color"] == "red"
assert event.state["separator"] == ", "
else:
assert False, "Unexpected event type for prepare_color"
else:
Expand All @@ -484,9 +484,9 @@ def handle_event(_, event):
assert event_log[5].method_name == "write_first_sentence"
assert event_log[5].params == {}
assert isinstance(event_log[5].state, dict)
assert event_log[5].state.get("flower") == "roses"
assert event_log[5].state.get("color") == "red"
assert event_log[5].state.get("separator") == ", "
assert event_log[5].state["flower"] == "roses"
assert event_log[5].state["color"] == "red"
assert event_log[5].state["separator"] == ", "

assert isinstance(event_log[6], MethodExecutionFinishedEvent)
assert event_log[6].method_name == "write_first_sentence"
Expand All @@ -496,8 +496,8 @@ def handle_event(_, event):
assert event_log[7].method_name == "finish_poem"
assert event_log[7].params == {"_0": "roses are red"}
assert isinstance(event_log[7].state, dict)
assert event_log[7].state.get("flower") == "roses"
assert event_log[7].state.get("color") == "red"
assert event_log[7].state["flower"] == "roses"
assert event_log[7].state["color"] == "red"

assert isinstance(event_log[8], MethodExecutionFinishedEvent)
assert event_log[8].method_name == "finish_poem"
Expand All @@ -507,8 +507,8 @@ def handle_event(_, event):
assert event_log[9].method_name == "save_poem_to_database"
assert event_log[9].params == {}
assert isinstance(event_log[9].state, dict)
assert event_log[9].state.get("flower") == "roses"
assert event_log[9].state.get("color") == "red"
assert event_log[9].state["flower"] == "roses"
assert event_log[9].state["color"] == "red"

assert isinstance(event_log[10], MethodExecutionFinishedEvent)
assert event_log[10].method_name == "save_poem_to_database"
Expand Down Expand Up @@ -561,13 +561,11 @@ def handle_event(_, event):
assert isinstance(event_log[3], MethodExecutionStartedEvent)
assert event_log[3].method_name == "send_welcome_message"
assert event_log[3].params == {}
assert isinstance(event_log[3].state, dict)
assert event_log[3].state.get("sent") == False
assert getattr(event_log[3].state, "sent") == False

assert isinstance(event_log[4], MethodExecutionFinishedEvent)
assert event_log[4].method_name == "send_welcome_message"
assert isinstance(event_log[4].state, dict)
assert event_log[4].state.get("sent") == True
assert getattr(event_log[4].state, "sent") == True
assert event_log[4].result == "Welcome, Anakin!"

assert isinstance(event_log[5], FlowFinishedEvent)
Expand Down

0 comments on commit bf2d4b0

Please sign in to comment.