Skip to content

Commit

Permalink
Ensure @start methods emit MethodExecutionStartedEvent
Browse files Browse the repository at this point in the history
Previously, `@start` methods triggered a `FlowStartedEvent` but did not
emit a `MethodExecutionStartedEvent`. This was fine for a single entry
point but caused ambiguity when multiple `@start` methods existed.

This commit (1) emits events for starting points, (2) adds tests
ensuring ordering, (3) adds more fields to events.
  • Loading branch information
vinibrsl committed Feb 12, 2025
1 parent 2fd7506 commit 1572bac
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 22 deletions.
67 changes: 47 additions & 20 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,6 @@ def __new__(mcs, name, bases, dct):
or hasattr(attr_value, "__trigger_methods__")
or hasattr(attr_value, "__is_router__")
):

# Register start methods
if hasattr(attr_value, "__is_start_method__"):
start_methods.append(attr_name)
Expand Down Expand Up @@ -569,6 +568,28 @@ class StateWithId(state_type, FlowState): # type: ignore
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)

def _dump_state(self) -> Optional[Dict[str, Any]]:
"""
Dumps the current flow state as a dictionary.
This method converts the internal state into a serializable dictionary format,
ensuring compatibility with both dictionary and Pydantic BaseModel states.
Returns:
Optional[Dict[str, Any]]: The serialized state dictionary, or None if state is not available.
"""
if self._state is None:
return None

if isinstance(self._state, dict):
return self._state.copy()

if isinstance(self._state, BaseModel):
return self._state.model_dump()

logger.warning("Unsupported flow state type for dumping.")
return None

@property
def state(self) -> T:
return self._state
Expand Down Expand Up @@ -740,6 +761,7 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
event=FlowStartedEvent(
type="flow_started",
flow_name=self.__class__.__name__,
inputs=inputs,
),
)
self._log_flow_event(
Expand Down Expand Up @@ -803,6 +825,18 @@ async def _execute_start_method(self, start_method_name: str) -> None:
async def _execute_method(
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
) -> Any:
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
self.event_emitter.send(
self,
event=MethodExecutionStartedEvent(
type="method_execution_started",
method_name=method_name,
flow_name=self.__class__.__name__,
params=dumped_params,
state=self._dump_state(),
),
)

result = (
await method(*args, **kwargs)
if asyncio.iscoroutinefunction(method)
Expand All @@ -812,6 +846,18 @@ async def _execute_method(
self._method_execution_counts[method_name] = (
self._method_execution_counts.get(method_name, 0) + 1
)

self.event_emitter.send(
self,
event=MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self.__class__.__name__,
state=self._dump_state(),
result=result,
),
)

return result

async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
Expand Down Expand Up @@ -950,16 +996,6 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non
"""
try:
method = self._methods[listener_name]

self.event_emitter.send(
self,
event=MethodExecutionStartedEvent(
type="method_execution_started",
method_name=listener_name,
flow_name=self.__class__.__name__,
),
)

sig = inspect.signature(method)
params = list(sig.parameters.values())
method_params = [p for p in params if p.name != "self"]
Expand All @@ -971,15 +1007,6 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non
else:
listener_result = await self._execute_method(listener_name, method)

self.event_emitter.send(
self,
event=MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=listener_name,
flow_name=self.__class__.__name__,
),
)

# Execute listeners (and possibly routers) of this listener
await self._execute_listeners(listener_name, listener_result)

Expand Down
8 changes: 6 additions & 2 deletions src/crewai/flow/flow_events.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Optional
from typing import Any, Dict, Optional


@dataclass
Expand All @@ -15,17 +15,21 @@ def __post_init__(self):

@dataclass
class FlowStartedEvent(Event):
pass
inputs: Optional[Dict[str, Any]] = None


@dataclass
class MethodExecutionStartedEvent(Event):
method_name: str
params: Optional[Dict[str, Any]] = None
state: Optional[Dict[str, Any]] = None


@dataclass
class MethodExecutionFinishedEvent(Event):
method_name: str
result: Any = None
state: Optional[Dict[str, Any]] = None


@dataclass
Expand Down
224 changes: 224 additions & 0 deletions tests/flow_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
"""Test Flow creation and execution basic functionality."""

import asyncio
from datetime import datetime

import pytest
from pydantic import BaseModel

from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.flow_events import (
FlowFinishedEvent,
FlowStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)


def test_simple_sequential_flow():
Expand Down Expand Up @@ -398,3 +405,220 @@ def log_final_step(self):

# final_step should run after router_and
assert execution_order.index("log_final_step") > execution_order.index("router_and")


def test_unstructured_flow_event_emission():
"""Test that the correct events are emitted during unstructured flow
execution with all fields validated."""

class PoemFlow(Flow):
@start()
def prepare_flower(self):
self.state["flower"] = "roses"
return "foo"

@start()
def prepare_color(self):
self.state["color"] = "red"
return "bar"

@listen(prepare_color)
def write_first_sentence(self):
return f"{self.state["flower"]} are {self.state["color"]}"

@listen(write_first_sentence)
def finish_poem(self, first_sentence):
separator = self.state.get("separator", "\n")
return separator.join([first_sentence, "violets are blue"])

@listen(finish_poem)
def save_poem_to_database(self):
# A method without args/kwargs to ensure events are sent correctly
pass

event_log = []

def handle_event(_, event):
event_log.append(event)

flow = PoemFlow()
flow.event_emitter.connect(handle_event)
flow.kickoff(inputs={"separator": ", "})

assert isinstance(event_log[0], FlowStartedEvent)
assert event_log[0].flow_name == "PoemFlow"
assert event_log[0].inputs == {"separator": ", "}
assert isinstance(event_log[0].timestamp, datetime)

# Asserting for concurrent start method executions in a for loop as you
# can't guarantee ordering in asynchronous executions
for i in range(1, 5):
event = event_log[i]
assert isinstance(event.state, dict)
assert isinstance(event.state["id"], str)

if event.method_name == "prepare_flower":
if isinstance(event, MethodExecutionStartedEvent):
assert event.params == {}
assert event.state.get("separator") == ", "
elif isinstance(event, MethodExecutionFinishedEvent):
assert event.result == "foo"
assert event.state.get("flower") == "roses"
assert event.state.get("separator") == ", "
else:
assert False, "Unexpected event type for prepare_flower"
elif event.method_name == "prepare_color":
if isinstance(event, MethodExecutionStartedEvent):
assert event.params == {}
assert event.state.get("separator") == ", "
elif isinstance(event, MethodExecutionFinishedEvent):
assert event.result == "bar"
assert event.state.get("color") == "red"
assert event.state.get("separator") == ", "
else:
assert False, "Unexpected event type for prepare_color"
else:
assert False, f"Unexpected method {event.method_name} in prepare events"

assert isinstance(event_log[5], MethodExecutionStartedEvent)
assert event_log[5].method_name == "write_first_sentence"
assert event_log[5].params == {}
assert isinstance(event_log[5].state, dict)
assert event_log[5].state.get("flower") == "roses"
assert event_log[5].state.get("color") == "red"
assert event_log[5].state.get("separator") == ", "

assert isinstance(event_log[6], MethodExecutionFinishedEvent)
assert event_log[6].method_name == "write_first_sentence"
assert event_log[6].result == "roses are red"

assert isinstance(event_log[7], MethodExecutionStartedEvent)
assert event_log[7].method_name == "finish_poem"
assert event_log[7].params == {"_0": "roses are red"}
assert isinstance(event_log[7].state, dict)
assert event_log[7].state.get("flower") == "roses"
assert event_log[7].state.get("color") == "red"

assert isinstance(event_log[8], MethodExecutionFinishedEvent)
assert event_log[8].method_name == "finish_poem"
assert event_log[8].result == "roses are red, violets are blue"

assert isinstance(event_log[9], MethodExecutionStartedEvent)
assert event_log[9].method_name == "save_poem_to_database"
assert event_log[9].params == {}
assert isinstance(event_log[9].state, dict)
assert event_log[9].state.get("flower") == "roses"
assert event_log[9].state.get("color") == "red"

assert isinstance(event_log[10], MethodExecutionFinishedEvent)
assert event_log[10].method_name == "save_poem_to_database"
assert event_log[10].result is None

assert isinstance(event_log[11], FlowFinishedEvent)
assert event_log[11].flow_name == "PoemFlow"
assert event_log[11].result is None
assert isinstance(event_log[11].timestamp, datetime)


def test_structured_flow_event_emission():
"""Test that the correct events are emitted during structured flow
execution with all fields validated."""

class OnboardingState(BaseModel):
name: str = ""
sent: bool = False

class OnboardingFlow(Flow[OnboardingState]):
@start()
def user_signs_up(self):
self.state.sent = False

@listen(user_signs_up)
def send_welcome_message(self):
self.state.sent = True
return f"Welcome, {self.state.name}!"

event_log = []

def handle_event(_, event):
event_log.append(event)

flow = OnboardingFlow()
flow.event_emitter.connect(handle_event)
flow.kickoff(inputs={"name": "Anakin"})

assert isinstance(event_log[0], FlowStartedEvent)
assert event_log[0].flow_name == "OnboardingFlow"
assert event_log[0].inputs == {"name": "Anakin"}
assert isinstance(event_log[0].timestamp, datetime)

assert isinstance(event_log[1], MethodExecutionStartedEvent)
assert event_log[1].method_name == "user_signs_up"

assert isinstance(event_log[2], MethodExecutionFinishedEvent)
assert event_log[2].method_name == "user_signs_up"

assert isinstance(event_log[3], MethodExecutionStartedEvent)
assert event_log[3].method_name == "send_welcome_message"
assert event_log[3].params == {}
assert isinstance(event_log[3].state, dict)
assert event_log[3].state.get("sent") == False

assert isinstance(event_log[4], MethodExecutionFinishedEvent)
assert event_log[4].method_name == "send_welcome_message"
assert isinstance(event_log[4].state, dict)
assert event_log[4].state.get("sent") == True
assert event_log[4].result == "Welcome, Anakin!"

assert isinstance(event_log[5], FlowFinishedEvent)
assert event_log[5].flow_name == "OnboardingFlow"
assert event_log[5].result == "Welcome, Anakin!"
assert isinstance(event_log[5].timestamp, datetime)


def test_stateless_flow_event_emission():
"""Test that the correct events are emitted stateless during flow execution
with all fields validated."""

class StatelessFlow(Flow):
@start()
def init(self):
pass

@listen(init)
def process(self):
return "Deeds will not be less valiant because they are unpraised."

event_log = []

def handle_event(_, event):
event_log.append(event)

flow = StatelessFlow()
flow.event_emitter.connect(handle_event)
flow.kickoff()

assert isinstance(event_log[0], FlowStartedEvent)
assert event_log[0].flow_name == "StatelessFlow"
assert event_log[0].inputs is None
assert isinstance(event_log[0].timestamp, datetime)

assert isinstance(event_log[1], MethodExecutionStartedEvent)
assert event_log[1].method_name == "init"

assert isinstance(event_log[2], MethodExecutionFinishedEvent)
assert event_log[2].method_name == "init"

assert isinstance(event_log[3], MethodExecutionStartedEvent)
assert event_log[3].method_name == "process"

assert isinstance(event_log[4], MethodExecutionFinishedEvent)
assert event_log[4].method_name == "process"

assert isinstance(event_log[5], FlowFinishedEvent)
assert event_log[5].flow_name == "StatelessFlow"
assert (
event_log[5].result
== "Deeds will not be less valiant because they are unpraised."
)
assert isinstance(event_log[5].timestamp, datetime)

0 comments on commit 1572bac

Please sign in to comment.