diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a8461d0..48e5b237 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Non-decorated direct functions continue to work with `cancel_on_interruption=True` (the default behavior). +### Changed + +- When transitioning from one node to the next, any functions in the previous + node that are absent in the new node now get "carried over", in that they're + still passed to the LLM, but in a "deactivated" state. Deactivation involves: + - Adding a special prefix to the function description + - Adding special context messages telling the LLM which functions to avoid + calling + + Providing LLMs with deactivated functions helps them understand historical + context that might contain references to previously-active functions. + ## [0.0.22] - 2025-11-18 ### Added diff --git a/examples/food_ordering.py b/examples/food_ordering.py index 21f908d6..b61cf4ee 100644 --- a/examples/food_ordering.py +++ b/examples/food_ordering.py @@ -139,7 +139,7 @@ async def choose_sushi(args: FlowArgs, flow_manager: FlowManager) -> tuple[None, role_messages=[ { "role": "system", - "content": "You are an order-taking assistant. You must ALWAYS use the available functions to progress the conversation. This is a phone conversation and your responses will be converted to audio. Keep the conversation friendly, casual, and polite. Avoid outputting special characters and emojis.", + "content": "You are an order-taking assistant. You must ALWAYS use the available functions to progress the conversation. When you've decided to call a function to progress the conversation, do not also respond; the function call by itself is enough. This is a phone conversation and your responses will be converted to audio. Keep the conversation friendly, casual, and polite. Avoid outputting special characters and emojis.", } ], task_messages=[ diff --git a/examples/restaurant_reservation.py b/examples/restaurant_reservation.py index 39c2c848..8c945928 100644 --- a/examples/restaurant_reservation.py +++ b/examples/restaurant_reservation.py @@ -235,7 +235,7 @@ def create_confirmation_node() -> NodeConfig: "task_messages": [ { "role": "system", - "content": "Confirm the reservation details and ask if they need anything else.", + "content": "Confirm the reservation details and ask if they need anything else. If they don't, go ahead and end the conversation by calling the appropriate function.", } ], "functions": [end_conversation_schema], diff --git a/examples/restaurant_reservation_direct_functions.py b/examples/restaurant_reservation_direct_functions.py index 064c3362..2321b88b 100644 --- a/examples/restaurant_reservation_direct_functions.py +++ b/examples/restaurant_reservation_direct_functions.py @@ -150,7 +150,7 @@ async def check_availability( Check availability for requested time. Args: - time (str): Requested reservation time in "HH:MM AM/PM" format. Must be between 5 PM and 10 PM. + time (str): Requested reservation time in "HH:MM AM/PM" format, with NO leading 0s (e.g. "6:00 PM"). Must be between 5 PM and 10 PM. party_size (int): Number of people in the party. """ # Check availability with mock API @@ -163,8 +163,10 @@ async def check_availability( # Next node: confirmation or no availability if is_available: + logger.debug("Time is available, transitioning to confirmation node") next_node = create_confirmation_node() else: + logger.debug(f"Time not available, storing alternatives: {alternative_times}") next_node = create_no_availability_node(alternative_times) return result, next_node @@ -219,7 +221,7 @@ def create_confirmation_node() -> NodeConfig: "task_messages": [ { "role": "system", - "content": "Confirm the reservation details and ask if they need anything else.", + "content": "Confirm the reservation details and ask if they need anything else. If they don't, go ahead and end the conversation by calling the appropriate function.", } ], "functions": [end_conversation], diff --git a/examples/warm_transfer.py b/examples/warm_transfer.py index 5cb307af..37ff46d4 100644 --- a/examples/warm_transfer.py +++ b/examples/warm_transfer.py @@ -732,8 +732,13 @@ async def on_participant_left( # Prepare hold music args flow_manager.state["hold_music_args"] = { - "script_path": Path(__file__).parent.parent / "assets" / "hold_music" / "hold_music.py", + "script_path": Path(__file__).parent.parent + / "examples" + / "assets" + / "hold_music" + / "hold_music.py", "wav_file_path": Path(__file__).parent.parent + / "examples" / "assets" / "hold_music" / "hold_music.wav", diff --git a/pyproject.toml b/pyproject.toml index 51d849dc..fff86f20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ "Topic :: Multimedia :: Sound/Audio", ] dependencies = [ - "pipecat-ai>=0.0.85", + "pipecat-ai>=0.0.102", "loguru~=0.7.2", "docstring_parser~=0.16" ] diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index a8c61856..0fd9f63f 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -35,13 +35,14 @@ from pipecat.frames.frames import ( FunctionCallResultProperties, LLMMessagesAppendFrame, + LLMMessagesTransformFrame, LLMMessagesUpdateFrame, LLMRunFrame, LLMSetToolsFrame, ) from pipecat.pipeline.llm_switcher import LLMSwitcher from pipecat.pipeline.task import PipelineTask -from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.llm_service import FunctionCallParams from pipecat.transports.base_transport import BaseTransport @@ -154,6 +155,18 @@ def __init__( self._transport = transport self._global_functions = global_functions or [] + # Cache global function names as an optimization when detecting deactivated functions later + self._global_function_names: Set[str] = set() + for func in self._global_functions: + if callable(func): + wrapper = FlowsDirectFunctionWrapper(function=func) + self._global_function_names.add(wrapper.name) + elif isinstance(func, FlowsFunctionSchema): + self._global_function_names.add(func.name) + else: + name = create_adapter(llm, context_aggregator).get_function_name(func) + self._global_function_names.add(name) + # Set up static or dynamic mode if flow_config: warnings.warn( @@ -170,7 +183,7 @@ def __init__( logger.debug("Initialized in dynamic mode") self._state: Dict[str, Any] = {} # Internal state storage - self._current_functions: Set[str] = set() # Track registered functions + self._current_functions: Dict[str, FlowsFunctionSchema] = {} # Track registered functions self._current_node: Optional[str] = None self._showed_deprecation_warning_for_transition_fields = False @@ -645,17 +658,93 @@ def _lookup_function(self, func_name: str) -> Callable: raise FlowError(error_message) + # Prefix used to mark deactivated function descriptions + _DEACTIVATED_FUNCTION_PREFIX = "[DEACTIVATED] " + + # Prefix used to mark "deactivated functions" messages in context + _DEACTIVATED_FUNCTIONS_MESSAGE_PREFIX = "[DEACTIVATED_FUNCTIONS_MESSAGE] " + + def _add_deactivated_prefix(self, description: str) -> str: + """Add the [DEACTIVATED] prefix to a description if not already present.""" + if not description.startswith(self._DEACTIVATED_FUNCTION_PREFIX): + return f"{self._DEACTIVATED_FUNCTION_PREFIX}{description}" + return description + + def _remove_deactivated_functions_messages( + self, messages: List[LLMContextMessage] + ) -> List[LLMContextMessage]: + """Transform function that removes deactivated function warning messages. + + Suitable for use with LLMMessagesTransformFrame. Filters out any messages + with the deactivated functions message prefix, ensuring only the most + recent warning is present in the context. + """ + return [ + msg + for msg in messages + if not ( + isinstance(msg.get("content"), str) + and msg["content"].startswith(self._DEACTIVATED_FUNCTIONS_MESSAGE_PREFIX) + ) + ] + + def _create_deactivated_function_schema( + self, original_schema: FlowsFunctionSchema + ) -> FlowsFunctionSchema: + """Create a function schema for deactivated functions with a dummy handler. + + Creates a copy of the original function schema but replaces the handler + with a dummy that returns an error message indicating the function is + deactivated. This is used when carrying over functions from a previous + node during APPEND context strategy transitions. + + Args: + original_schema: The original FlowsFunctionSchema to create a + deactivated version of. + + Returns: + A new FlowsFunctionSchema with the same name/description/properties + but with a dummy handler that returns a deactivation message. + """ + + async def deactivated_handler(args: FlowArgs, flow_manager: "FlowManager") -> FlowResult: + """Dummy handler for deactivated functions.""" + return { + "status": "error", + "error": ( + f"Function '{original_schema.name}' is deactivated and should not be called. " + f"Please use the currently active functions instead." + ), + } + + # Only add prefix if not already deactivated (idempotent) + description = self._add_deactivated_prefix(original_schema.description) + + return FlowsFunctionSchema( + name=original_schema.name, + description=description, + properties=original_schema.properties, + required=original_schema.required, + handler=deactivated_handler, + cancel_on_interruption=True, + transition_to=None, + transition_callback=None, + ) + async def _register_function( self, name: str, - new_functions: Set[str], handler: Optional[Callable | FlowsDirectFunctionWrapper], transition_to: Optional[str] = None, transition_callback: Optional[Callable] = None, *, cancel_on_interruption: bool = True, ) -> None: - """Register a function with the LLM if not already registered. + """Register a function with the LLM. + + Always registers the function, replacing any existing registration. + This ensures that handlers are properly updated when a function is + deactivated or reactivated upon transitioning nodes. Args: name: Name of the function to register @@ -663,37 +752,34 @@ async def _register_function( If string starts with '__function__:', extracts the function name after the prefix. transition_to: Optional node to transition to (static flows) transition_callback: Optional transition callback (dynamic flows) - new_functions: Set to track newly registered functions for this node cancel_on_interruption: Whether to cancel this function call when an interruption occurs. Defaults to True. Raises: FlowError: If function registration fails """ - if name not in self._current_functions: - try: - # Handle special token format (e.g. "__function__:function_name") - if isinstance(handler, str) and handler.startswith("__function__:"): - func_name = handler.split(":")[1] - handler = self._lookup_function(func_name) - - # Create transition function - transition_func = await self._create_transition_func( - name, handler, transition_to, transition_callback - ) + try: + # Handle special token format (e.g. "__function__:function_name") + if isinstance(handler, str) and handler.startswith("__function__:"): + func_name = handler.split(":")[1] + handler = self._lookup_function(func_name) + + # Create transition function + transition_func = await self._create_transition_func( + name, handler, transition_to, transition_callback + ) - # Register function with LLM (or LLMSwitcher) - self._llm.register_function( - name, - transition_func, - cancel_on_interruption=cancel_on_interruption, - ) + # Register function with LLM (or LLMSwitcher) + self._llm.register_function( + name, + transition_func, + cancel_on_interruption=cancel_on_interruption, + ) - new_functions.add(name) - logger.debug(f"Registered function: {name}") - except Exception as e: - logger.error(f"Failed to register function {name}: {str(e)}") - raise FlowError(f"Function registration failed: {str(e)}") from e + logger.debug(f"Registered function: {name}") + except Exception as e: + logger.error(f"Failed to register function {name}: {str(e)}") + raise FlowError(f"Function registration failed: {str(e)}") from e async def set_node_from_config(self, node_config: NodeConfig) -> None: """Set up a new conversation node and transition to it. @@ -790,7 +876,7 @@ async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: # Register functions and prepare tools tools: List[FlowsFunctionSchema | FlowsDirectFunctionWrapper] = [] - new_functions: Set[str] = set() + new_functions: Dict[str, FlowsFunctionSchema] = {} # Get functions list with default empty list if not provided functions_list = node_config.get("functions", []) @@ -803,12 +889,15 @@ async def register_function_schema(schema: FlowsFunctionSchema): tools.append(schema) await self._register_function( name=schema.name, - new_functions=new_functions, handler=schema.handler, transition_to=schema.transition_to, transition_callback=schema.transition_callback, cancel_on_interruption=schema.cancel_on_interruption, ) + # Track that we registered this function in the current node. + # This will be useful if we need to carry it over as a + # "deactivated" function into the next node. + new_functions[schema.name] = schema async def register_direct_function(func): """Helper to register a single direct function.""" @@ -816,12 +905,24 @@ async def register_direct_function(func): tools.append(direct_function) await self._register_function( name=direct_function.name, - new_functions=new_functions, handler=direct_function, transition_to=None, transition_callback=None, cancel_on_interruption=direct_function.cancel_on_interruption, ) + # Track that we registered this function in the current node. + # This will be useful if we need to carry it over as a + # "deactivated" function into the next node. Track it as a + # FlowsFunctionSchema for ease of editing to be "deactivated". + schema = FlowsFunctionSchema( + name=direct_function.name, + description=direct_function.description, + properties=direct_function.properties, + required=direct_function.required, + handler=None, # Will be replaced if deactivated + cancel_on_interruption=direct_function.cancel_on_interruption, + ) + new_functions[direct_function.name] = schema for func_config in functions_list: # Handle direct functions @@ -847,6 +948,38 @@ async def register_direct_function(func): ) await register_function_schema(schema) + # Determine effective context strategy for this transition + effective_strategy = node_config.get("context_strategy") or self._context_strategy + + # For APPEND strategy on non-first nodes, carry over functions from + # the previous node that aren't in the current node (but marking + # them as "deactivated"), since there may be historical context + # messages that call or reference them. This helps LLMs better + # understand their context. + deactivated_function_names: List[str] = [] + if ( + self._current_node is not None + and effective_strategy.strategy == ContextStrategy.APPEND + ): + # Find functions from previous node that aren't in the new node + # (excluding global functions, which are always present) + new_function_names = set(new_functions.keys()) + for prev_name, prev_schema in self._current_functions.items(): + if ( + prev_name not in new_function_names + and prev_name not in self._global_function_names + ): + # Create deactivated stub with same schema but dummy handler + deactivated_schema = self._create_deactivated_function_schema(prev_schema) + await register_function_schema(deactivated_schema) + deactivated_function_names.append(prev_name) + + if deactivated_function_names: + logger.debug( + f"Carrying over {len(deactivated_function_names)} deactivated functions: " + f"{', '.join(deactivated_function_names)}" + ) + # Create ToolsSchema with standard function schemas standard_functions = [] for tool in tools: @@ -858,12 +991,42 @@ async def register_direct_function(func): standard_functions, original_configs=functions_list ) + # Remove any prior "deactivated functions" messages from context + # to prevent buildup (only one up-to-date one is needed). Only + # needed for APPEND strategy on non-first nodes, since + # RESET/RESET_WITH_SUMMARY replace the entire context. + if ( + self._current_node is not None + and effective_strategy.strategy == ContextStrategy.APPEND + ): + await self._task.queue_frames( + [ + LLMMessagesTransformFrame( + transform=self._remove_deactivated_functions_messages + ) + ] + ) + + task_messages = list(node_config["task_messages"]) # Copy to avoid mutating original + if deactivated_function_names: + deactivated_functions_message = { + "role": "system", + "content": ( + f"{self._DEACTIVATED_FUNCTIONS_MESSAGE_PREFIX}" + f"IMPORTANT: The following functions are currently DEACTIVATED and should NOT be " + f"called: {', '.join(deactivated_function_names)}. If you call them, they " + f"will return an error. ALL OTHER AVAILABLE FUNCTIONS should be considered " + "currently ACTIVE and callable." + ), + } + task_messages.insert(0, deactivated_functions_message) + # Update LLM context await self._update_llm_context( role_messages=node_config.get("role_messages"), - task_messages=node_config["task_messages"], + task_messages=task_messages, functions=formatted_tools, - strategy=node_config.get("context_strategy"), + strategy=effective_strategy, ) logger.debug("Updated LLM context") @@ -904,7 +1067,7 @@ async def _update_llm_context( role_messages: Optional[List[dict]], task_messages: List[dict], functions: List[dict], - strategy: Optional[ContextStrategyConfig] = None, + strategy: ContextStrategyConfig, ) -> None: """Update LLM context with new messages and functions. @@ -912,7 +1075,7 @@ async def _update_llm_context( role_messages: Optional role messages to add to context. task_messages: Task messages to add to context. functions: New functions to make available. - strategy: Optional context update configuration. + strategy: Context strategy configuration for this transition. Raises: FlowError: If context update fails. @@ -927,7 +1090,7 @@ async def _update_llm_context( if role_messages: messages.extend(role_messages) - update_config = strategy or self._context_strategy + update_config = strategy if ( update_config.strategy == ContextStrategy.RESET_WITH_SUMMARY diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index be9b7a66..6dce145d 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -313,15 +313,11 @@ def flows_direct_function(*, cancel_on_interruption: bool = True) -> Callable[[C Returns: A decorator that attaches the metadata to the function. - Example: + Example:: + @flows_direct_function(cancel_on_interruption=False) async def long_running_task(flow_manager: FlowManager, query: str): - '''Perform a long-running task that should not be cancelled on interruption. - - Args: - query: The query to process. - ''' - # ... implementation + # ... docstring and implementation return {"status": "complete"}, None """ diff --git a/tests/test_context_strategies.py b/tests/test_context_strategies.py index a9b807d0..bcf0957b 100644 --- a/tests/test_context_strategies.py +++ b/tests/test_context_strategies.py @@ -80,6 +80,14 @@ async def test_context_strategy_config_validation(self): with self.assertRaises(ValueError): ContextStrategyConfig(strategy=ContextStrategy.RESET_WITH_SUMMARY) + def _get_all_queued_frames(self): + """Helper to collect all frames from all queue_frames calls.""" + all_frames = [] + for call in self.mock_task.queue_frames.call_args_list: + frames = call[0][0] + all_frames.extend(frames) + return all_frames + async def test_default_strategy(self): """Test default context strategy (APPEND).""" flow_manager = FlowManager( @@ -91,8 +99,7 @@ async def test_default_strategy(self): # First node should use UpdateFrame regardless of strategy await flow_manager._set_node("first", self.sample_node) - first_call = self.mock_task.queue_frames.call_args_list[0] - first_frames = first_call[0][0] + first_frames = self._get_all_queued_frames() self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in first_frames)) # Reset mock @@ -100,8 +107,7 @@ async def test_default_strategy(self): # Subsequent node should use AppendFrame with default strategy await flow_manager._set_node("second", self.sample_node) - second_call = self.mock_task.queue_frames.call_args_list[0] - second_frames = second_call[0][0] + second_frames = self._get_all_queued_frames() self.assertTrue(any(isinstance(f, LLMMessagesAppendFrame) for f in second_frames)) async def test_reset_strategy(self): @@ -120,8 +126,7 @@ async def test_reset_strategy(self): # Second node should use UpdateFrame with RESET strategy await flow_manager._set_node("second", self.sample_node) - second_call = self.mock_task.queue_frames.call_args_list[0] - second_frames = second_call[0][0] + second_frames = self._get_all_queued_frames() self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in second_frames)) async def test_reset_with_summary_success(self): @@ -148,8 +153,7 @@ async def test_reset_with_summary_success(self): await flow_manager._set_node("second", self.sample_node) # Verify summary was included in context update - second_call = self.mock_task.queue_frames.call_args_list[0] - second_frames = second_call[0][0] + second_frames = self._get_all_queued_frames() update_frame = next(f for f in second_frames if isinstance(f, LLMMessagesUpdateFrame)) self.assertTrue(any(mock_summary in str(m) for m in update_frame.messages)) @@ -175,9 +179,8 @@ async def test_reset_with_summary_timeout(self): await flow_manager._set_node("second", self.sample_node) - # Verify UpdateFrame was used (APPEND behavior) - second_call = self.mock_task.queue_frames.call_args_list[0] - second_frames = second_call[0][0] + # Verify AppendFrame was used (fallback to APPEND behavior on timeout) + second_frames = self._get_all_queued_frames() self.assertTrue(any(isinstance(f, LLMMessagesAppendFrame) for f in second_frames)) async def test_provider_specific_summary_formatting(self): @@ -234,8 +237,7 @@ async def test_node_level_strategy_override(self): await flow_manager._set_node("second", node_with_strategy) # Verify UpdateFrame was used (RESET behavior) despite global APPEND - second_call = self.mock_task.queue_frames.call_args_list[0] - second_frames = second_call[0][0] + second_frames = self._get_all_queued_frames() self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in second_frames)) async def test_summary_generation_content(self): @@ -299,12 +301,1007 @@ async def test_context_structure_after_summary(self): await flow_manager._set_node("second", new_node) # Verify context structure - update_call = self.mock_task.queue_frames.call_args_list[0] - update_frames = update_call[0][0] - messages_frame = next(f for f in update_frames if isinstance(f, LLMMessagesUpdateFrame)) + all_frames = self._get_all_queued_frames() + messages_frame = next( + (f for f in all_frames if isinstance(f, LLMMessagesUpdateFrame)), None + ) + self.assertIsNotNone(messages_frame, "LLMMessagesUpdateFrame should be queued") # Verify order: summary message, then new task messages self.assertTrue(mock_summary in str(messages_frame.messages[0])) self.assertEqual( messages_frame.messages[1]["content"], new_node["task_messages"][0]["content"] ) + + +class TestDeactivatedFunctions(unittest.IsolatedAsyncioTestCase): + """Test suite for deactivated function carry-over during APPEND transitions. + + Tests functionality including: + - Carrying over deactivated functions with dummy handlers + - Injecting warning task messages + - Skipping deactivation for RESET strategies + - Global functions are not deactivated + """ + + async def asyncSetUp(self): + """Set up test fixtures before each test.""" + self.mock_task = AsyncMock() + self.mock_task.event_handler = Mock() + self.mock_task.set_reached_downstream_filter = Mock() + + self.mock_llm = OpenAILLMService(api_key="") + self.mock_llm.register_function = MagicMock() + + self.mock_context_aggregator = MagicMock() + self.mock_context_aggregator.user = MagicMock() + self.mock_context_aggregator.user.return_value = MagicMock() + self.mock_context_aggregator.user.return_value._context = MagicMock() + + def _get_all_queued_frames(self): + """Helper to collect all frames from all queue_frames calls.""" + all_frames = [] + for call in self.mock_task.queue_frames.call_args_list: + frames = call[0][0] + all_frames.extend(frames) + return all_frames + + async def test_deactivated_functions_carried_over_on_append(self): + """Test that functions from previous node are carried over as deactivated on APPEND.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + # First node with function_a and function_b + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Function A", + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "function_b", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Function B", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + + # Verify both functions are in current_functions + self.assertIn("function_a", flow_manager._current_functions) + self.assertIn("function_b", flow_manager._current_functions) + + self.mock_task.queue_frames.reset_mock() + self.mock_llm.register_function.reset_mock() + + # Second node with only function_c (function_a and function_b should be deactivated) + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_c", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Function C", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("second", second_node) + + # Verify that function_a and function_b were registered as deactivated + registered_function_names = [ + call[0][0] for call in self.mock_llm.register_function.call_args_list + ] + self.assertIn("function_a", registered_function_names) + self.assertIn("function_b", registered_function_names) + self.assertIn("function_c", registered_function_names) + + async def test_deactivated_warning_message_injected(self): + """Test that warning message is injected when functions are deactivated.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + # First node with a function + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "old_function", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Old function", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + self.mock_task.queue_frames.reset_mock() + + # Second node without the old function + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [], + } + + await flow_manager._set_node("second", second_node) + + # Verify warning message was injected + all_frames = self._get_all_queued_frames() + messages_frame = next( + (f for f in all_frames if isinstance(f, LLMMessagesAppendFrame)), None + ) + self.assertIsNotNone(messages_frame, "LLMMessagesAppendFrame should be queued") + + # Check that warning message is present + warning_found = any( + "old_function" in str(m.get("content", "")) + and "deactivated" in str(m.get("content", "")).lower() + for m in messages_frame.messages + ) + self.assertTrue(warning_found, "Warning message about deactivated functions not found") + + async def test_no_deactivation_on_reset_strategy(self): + """Test that functions are not carried over as deactivated on RESET strategy.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.RESET), + ) + await flow_manager.initialize() + + # First node with a function + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "old_function", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Old function", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + self.mock_llm.register_function.reset_mock() + + # Second node without the old function + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "new_function", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "New function", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("second", second_node) + + # Verify that old_function was NOT registered in second transition + registered_function_names = [ + call[0][0] for call in self.mock_llm.register_function.call_args_list + ] + self.assertNotIn("old_function", registered_function_names) + self.assertIn("new_function", registered_function_names) + + async def test_global_functions_not_deactivated(self): + """Test that global functions are never carried over as deactivated.""" + from pipecat_flows.types import FlowsFunctionSchema + + global_func = FlowsFunctionSchema( + name="global_function", + description="A global function", + properties={}, + required=[], + handler=AsyncMock(return_value={"status": "success"}), + ) + + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + global_functions=[global_func], + ) + await flow_manager.initialize() + + # First node with an additional function + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "node_function", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Node-specific function", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + self.mock_task.queue_frames.reset_mock() + + # Second node without node_function (but global_function should still be active) + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [], + } + + await flow_manager._set_node("second", second_node) + + # Verify warning message mentions node_function but NOT global_function + all_frames = self._get_all_queued_frames() + messages_frame = next( + (f for f in all_frames if isinstance(f, LLMMessagesAppendFrame)), None + ) + self.assertIsNotNone(messages_frame, "LLMMessagesAppendFrame should be queued") + + for msg in messages_frame.messages: + content = str(msg.get("content", "")) + if "deactivated" in content.lower(): + self.assertIn("node_function", content) + self.assertNotIn("global_function", content) + + async def test_deactivated_function_returns_error(self): + """Test that calling a deactivated function returns an error result.""" + from pipecat_flows.types import FlowsFunctionSchema + + # Create the schema directly to test the helper + original_schema = FlowsFunctionSchema( + name="test_function", + description="Test function", + properties={"param": {"type": "string"}}, + required=["param"], + handler=AsyncMock(return_value={"status": "success"}), + ) + + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) + + deactivated_schema = flow_manager._create_deactivated_function_schema(original_schema) + + # Verify schema properties are preserved (with updated description) + self.assertEqual(deactivated_schema.name, original_schema.name) + self.assertTrue(deactivated_schema.description.startswith("[DEACTIVATED]")) + self.assertIn(original_schema.description, deactivated_schema.description) + self.assertEqual(deactivated_schema.properties, original_schema.properties) + self.assertEqual(deactivated_schema.required, original_schema.required) + + # Verify deactivated handler returns error (modern signature with flow_manager) + result = await deactivated_schema.handler({}, flow_manager) + self.assertEqual(result["status"], "error") + self.assertIn("deactivated", result["error"]) + self.assertIn("test_function", result["error"]) + + async def test_node_level_reset_prevents_deactivation(self): + """Test that node-level RESET strategy prevents function deactivation.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + # First node with a function + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "old_function", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Old function", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + self.mock_llm.register_function.reset_mock() + + # Second node with RESET strategy override + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [], + "context_strategy": ContextStrategyConfig(strategy=ContextStrategy.RESET), + } + + await flow_manager._set_node("second", second_node) + + # Verify that old_function was NOT registered (no deactivation on RESET) + registered_function_names = [ + call[0][0] for call in self.mock_llm.register_function.call_args_list + ] + self.assertNotIn("old_function", registered_function_names) + + async def test_deactivated_functions_persist_across_multiple_transitions(self): + """Test that deactivated functions persist through multiple node transitions.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + # First node with function_a + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Function A", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + self.mock_task.queue_frames.reset_mock() + self.mock_llm.register_function.reset_mock() + + # Second node with function_b only (function_a becomes deactivated) + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_b", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Function B", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("second", second_node) + + # Verify function_a was deactivated + all_frames = self._get_all_queued_frames() + messages_frame = next( + (f for f in all_frames if isinstance(f, LLMMessagesAppendFrame)), None + ) + self.assertIsNotNone(messages_frame, "LLMMessagesAppendFrame should be queued") + warning_found = any( + "function_a" in str(msg.get("content", "")) + and "deactivated" in str(msg.get("content", "")).lower() + for msg in messages_frame.messages + ) + self.assertTrue(warning_found, "function_a should be marked as deactivated") + + self.mock_task.queue_frames.reset_mock() + self.mock_llm.register_function.reset_mock() + + # Third node with function_c only (function_a AND function_b should be deactivated) + third_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Third task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_c", + "handler": AsyncMock(return_value={"status": "success"}), + "description": "Function C", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("third", third_node) + + # Verify both function_a and function_b are mentioned as deactivated + all_frames = self._get_all_queued_frames() + messages_frame = next( + (f for f in all_frames if isinstance(f, LLMMessagesAppendFrame)), None + ) + self.assertIsNotNone(messages_frame, "LLMMessagesAppendFrame should be queued") + + warning_content = "" + for msg in messages_frame.messages: + content = str(msg.get("content", "")) + if "deactivated" in content.lower(): + warning_content = content + break + + self.assertIn("function_a", warning_content, "function_a should persist as deactivated") + self.assertIn("function_b", warning_content, "function_b should now be deactivated") + + async def test_deactivated_prefix_is_idempotent(self): + """Test that [DEACTIVATED] prefix is not added multiple times.""" + from pipecat_flows.types import FlowsFunctionSchema + + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) + + # Create an already-deactivated schema + already_deactivated = FlowsFunctionSchema( + name="test_function", + description="[DEACTIVATED] Original description", + properties={}, + required=[], + handler=AsyncMock(return_value={"status": "error", "error": "deactivated"}), + ) + + # Deactivate it again + doubly_deactivated = flow_manager._create_deactivated_function_schema(already_deactivated) + + # Verify description only has one [DEACTIVATED] prefix + self.assertEqual( + doubly_deactivated.description, + "[DEACTIVATED] Original description", + "Should not add multiple [DEACTIVATED] prefixes", + ) + self.assertEqual( + doubly_deactivated.description.count("[DEACTIVATED]"), + 1, + "Should have exactly one [DEACTIVATED] prefix", + ) + + async def test_deactivated_function_reactivated_when_present_in_new_node(self): + """Test that a previously deactivated function becomes active again when included in new node.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + # Create a handler that we can track + function_a_handler = AsyncMock(return_value={"status": "success"}) + + # First node with function_a + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": function_a_handler, + "description": "Function A", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + self.mock_task.queue_frames.reset_mock() + + # Second node without function_a (it becomes deactivated) + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [], + } + + await flow_manager._set_node("second", second_node) + + # Verify function_a is in current_functions as deactivated + self.assertIn("function_a", flow_manager._current_functions) + self.assertTrue( + flow_manager._current_functions["function_a"].description.startswith("[DEACTIVATED]") + ) + + self.mock_task.queue_frames.reset_mock() + + # Third node brings back function_a + third_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Third task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": function_a_handler, + "description": "Function A restored", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("third", third_node) + + # Verify function_a is now active (not deactivated) + self.assertIn("function_a", flow_manager._current_functions) + self.assertFalse( + flow_manager._current_functions["function_a"].description.startswith("[DEACTIVATED]"), + "function_a should be active again, not deactivated", + ) + + # Verify no deactivation warning was issued for function_a + all_frames = self._get_all_queued_frames() + messages_frame = next( + (f for f in all_frames if isinstance(f, LLMMessagesAppendFrame)), None + ) + self.assertIsNotNone(messages_frame, "LLMMessagesAppendFrame should be queued") + + for msg in messages_frame.messages: + content = str(msg.get("content", "")) + if "deactivated" in content.lower(): + self.assertNotIn( + "function_a", + content, + "function_a should not be in deactivation warning since it's now active", + ) + + async def test_reactivated_function_is_re_registered_with_llm(self): + """Test that reactivated functions are re-registered with the LLM.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + function_a_handler = AsyncMock(return_value={"status": "success"}) + + # First node with function_a + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": function_a_handler, + "description": "Function A", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + + # Count initial registrations for function_a + initial_registrations = sum( + 1 + for call in self.mock_llm.register_function.call_args_list + if call[0][0] == "function_a" + ) + self.assertEqual(initial_registrations, 1, "function_a should be registered once initially") + + # Second node without function_a (it becomes deactivated) + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [], + } + + await flow_manager._set_node("second", second_node) + + # Count registrations after deactivation + deactivation_registrations = sum( + 1 + for call in self.mock_llm.register_function.call_args_list + if call[0][0] == "function_a" + ) + self.assertEqual( + deactivation_registrations, 2, "function_a should be re-registered when deactivated" + ) + + # Third node brings back function_a + third_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Third task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": function_a_handler, + "description": "Function A restored", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("third", third_node) + + # Count registrations after reactivation + final_registrations = sum( + 1 + for call in self.mock_llm.register_function.call_args_list + if call[0][0] == "function_a" + ) + self.assertEqual( + final_registrations, + 3, + "function_a should be re-registered when reactivated (total 3 registrations)", + ) + + async def test_reactivated_function_handler_works(self): + """Test that a reactivated function's handler actually works (not the deactivated dummy).""" + from pipecat.services.llm_service import FunctionCallParams + + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + # Track actual handler calls + real_handler_called = [] + + async def real_handler(args, flow_manager): + real_handler_called.append(args) + return {"status": "real_success", "data": args.get("input", "none")} + + # First node with function_a + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": real_handler, + "description": "Function A", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + + # Second node without function_a (deactivates it) + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [], + } + + await flow_manager._set_node("second", second_node) + + # Third node brings back function_a + third_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Third task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": real_handler, + "description": "Function A restored", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("third", third_node) + + # Get the registered function handler from the most recent registration + # The last call for function_a should be the reactivated one + function_a_calls = [ + call + for call in self.mock_llm.register_function.call_args_list + if call[0][0] == "function_a" + ] + last_registered_handler = function_a_calls[-1][0][1] + + # Create mock params to call the handler + mock_result_callback = AsyncMock() + mock_params = MagicMock(spec=FunctionCallParams) + mock_params.arguments = {"input": "test_value"} + mock_params.result_callback = mock_result_callback + + # Call the registered handler + await last_registered_handler(mock_params) + + # Verify the real handler was called, not the deactivated dummy + self.assertEqual(len(real_handler_called), 1, "Real handler should have been called") + self.assertEqual(real_handler_called[0], {"input": "test_value"}) + + # Verify result callback was called with success (not error) + mock_result_callback.assert_called_once() + result = mock_result_callback.call_args[0][0] + self.assertEqual(result.get("status"), "real_success") + self.assertNotIn("error", result) + + async def test_reactivated_function_description_no_deactivated_prefix(self): + """Test that reactivated function's stored schema has no [DEACTIVATED] prefix.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + handler = AsyncMock(return_value={"status": "success"}) + + # First node with function_a + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": handler, + "description": "Original description", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + self.assertEqual( + flow_manager._current_functions["function_a"].description, "Original description" + ) + + # Second node without function_a (deactivates it) + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [], + } + + await flow_manager._set_node("second", second_node) + self.assertTrue( + flow_manager._current_functions["function_a"].description.startswith("[DEACTIVATED]") + ) + + # Third node reactivates function_a + third_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Third task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": handler, + "description": "Restored description", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("third", third_node) + + # Verify stored schema has clean description + self.assertEqual( + flow_manager._current_functions["function_a"].description, + "Restored description", + "Reactivated function should have clean description without [DEACTIVATED] prefix", + ) + self.assertFalse( + flow_manager._current_functions["function_a"].description.startswith("[DEACTIVATED]") + ) + + async def test_functions_always_registered_on_each_transition(self): + """Test that functions are always re-registered on each node transition.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + handler = AsyncMock(return_value={"status": "success"}) + + # Node with function_a and function_b + node_with_both: NodeConfig = { + "task_messages": [{"role": "system", "content": "Task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": handler, + "description": "Function A", + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "function_b", + "handler": handler, + "description": "Function B", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + # First transition + await flow_manager._set_node("first", node_with_both) + + first_transition_registrations = { + call[0][0] for call in self.mock_llm.register_function.call_args_list + } + self.assertIn("function_a", first_transition_registrations) + self.assertIn("function_b", first_transition_registrations) + + initial_call_count = len(self.mock_llm.register_function.call_args_list) + + # Second transition with same functions + await flow_manager._set_node("second", node_with_both) + + # Verify functions were registered again + new_call_count = len(self.mock_llm.register_function.call_args_list) + new_registrations = new_call_count - initial_call_count + + self.assertGreaterEqual( + new_registrations, 2, "Both functions should be re-registered on second transition" + ) + + async def test_deactivated_warnings_transform_frame_queued(self): + """Test that an LLMMessagesTransformFrame is queued to remove old warnings.""" + from pipecat.frames.frames import LLMMessagesTransformFrame + + from pipecat_flows.manager import FlowManager + + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + handler = AsyncMock(return_value={"status": "success"}) + + # First node with function_a + first_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "First task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_a", + "handler": handler, + "description": "Function A", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("first", first_node) + self.mock_task.queue_frames.reset_mock() + + # Second node with only function_b (function_a becomes deactivated) + second_node: NodeConfig = { + "task_messages": [{"role": "system", "content": "Second task."}], + "functions": [ + { + "type": "function", + "function": { + "name": "function_b", + "handler": handler, + "description": "Function B", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + } + + await flow_manager._set_node("second", second_node) + + # Verify LLMMessagesTransformFrame was queued + all_frames = [] + for call in self.mock_task.queue_frames.call_args_list: + frames = call[0][0] + all_frames.extend(frames) + + transform_frames = [f for f in all_frames if isinstance(f, LLMMessagesTransformFrame)] + self.assertGreater( + len(transform_frames), + 0, + "LLMMessagesTransformFrame should be queued to remove old warnings", + ) + + async def test_deactivated_warnings_transform_function_removes_warnings(self): + """Test that the transform function correctly removes deactivated warning messages.""" + from pipecat_flows.manager import FlowManager + + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + context_strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), + ) + await flow_manager.initialize() + + warning_marker = FlowManager._DEACTIVATED_FUNCTIONS_MESSAGE_PREFIX + + # Test messages with warnings mixed in + input_messages = [ + {"role": "system", "content": f"{warning_marker}Old warning 1"}, + {"role": "user", "content": "Hello"}, + {"role": "system", "content": f"{warning_marker}Old warning 2"}, + {"role": "assistant", "content": "Hi there"}, + {"role": "system", "content": f"{warning_marker}Old warning 3"}, + {"role": "system", "content": "Regular system message"}, + ] + + # Apply the transform method directly + result = flow_manager._remove_deactivated_functions_messages(input_messages) + + # Verify all warnings were removed + warning_count = sum( + 1 + for msg in result + if isinstance(msg.get("content", ""), str) and msg["content"].startswith(warning_marker) + ) + self.assertEqual(warning_count, 0, "All warnings should be removed") + + # Verify non-warning messages are preserved + self.assertEqual(len(result), 3, "Should have 3 messages remaining") + user_messages = [msg for msg in result if msg.get("role") == "user"] + assistant_messages = [msg for msg in result if msg.get("role") == "assistant"] + system_messages = [msg for msg in result if msg.get("role") == "system"] + self.assertEqual(len(user_messages), 1, "User message should be preserved") + self.assertEqual(len(assistant_messages), 1, "Assistant message should be preserved") + self.assertEqual(len(system_messages), 1, "Non-warning system message should be preserved") + self.assertEqual( + system_messages[0]["content"], + "Regular system message", + "Regular system message content should be preserved", + ) diff --git a/tests/test_manager.py b/tests/test_manager.py index 70b5969a..84b62d70 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -97,6 +97,14 @@ async def asyncSetUp(self): }, } + def _get_all_queued_frames(self): + """Helper to collect all frames from all queue_frames calls.""" + all_frames = [] + for call in self.mock_task.queue_frames.call_args_list: + frames = call[0][0] + all_frames.extend(frames) + return all_frames + async def test_static_flow_initialization(self): """Test initialization of a static flow configuration.""" flow_manager = FlowManager( @@ -175,12 +183,11 @@ async def test_static_flow_transitions(self, mock_llm_run_frame): # The first call should be for the context update self.assertTrue(self.mock_task.queue_frames.called) - # Get the first call (context update) - first_call = self.mock_task.queue_frames.call_args_list[0] - first_frames = first_call[0][0] + # Collect all queued frames + all_frames = self._get_all_queued_frames() # For subsequent nodes, should use AppendFrame by default - append_frames = [f for f in first_frames if isinstance(f, LLMMessagesAppendFrame)] + append_frames = [f for f in all_frames if isinstance(f, LLMMessagesAppendFrame)] self.assertTrue(len(append_frames) > 0, "Should have at least one AppendFrame") # Verify that LLM completion was triggered by checking LLMRunFrame instantiation @@ -290,7 +297,13 @@ async def result_callback(result, properties=None): # Test new style callback await flow_manager.set_node_from_config(new_style_node) - func = flow_manager._llm.register_function.call_args[0][1] + # Find the registered function by name (not using call_args since deactivated functions may be registered after) + func = None + for call in flow_manager._llm.register_function.call_args_list: + if call[0][0] == "new_style_function": + func = call[0][1] + break + self.assertIsNotNone(func, "new_style_function was not registered") # Reset context_updated callback context_updated_callback = None @@ -338,7 +351,7 @@ async def test_node_validation(self): await flow_manager.set_node_from_config(valid_config) self.assertEqual(flow_manager._current_node, "test") - self.assertEqual(flow_manager._current_functions, set()) + self.assertEqual(flow_manager._current_functions, {}) async def test_function_registration(self): """Test function registration with LLM.""" @@ -779,7 +792,7 @@ async def test_register_function_error_handling(self): # Mock LLM to raise error on register_function flow_manager._llm.register_function.side_effect = Exception("Registration error") - new_functions = set() + new_functions = {} with self.assertRaises(FlowError): await flow_manager._register_function("test", new_functions, None) @@ -823,11 +836,14 @@ async def test_update_llm_context_error_handling(self): # Mock task to raise error on queue_frames flow_manager._task.queue_frames.side_effect = Exception("Queue error") + from pipecat_flows.types import ContextStrategy, ContextStrategyConfig + with self.assertRaises(FlowError): await flow_manager._update_llm_context( role_messages=[], task_messages=[{"role": "system", "content": "Test"}], functions=[], + strategy=ContextStrategyConfig(strategy=ContextStrategy.APPEND), ) async def test_function_declarations_processing(self): @@ -1031,8 +1047,7 @@ async def test_role_message_inheritance(self): # Set first node and verify UpdateFrame await flow_manager.set_node_from_config(first_node) - first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call - first_frames = first_call[0][0] + first_frames = self._get_all_queued_frames() update_frames = [f for f in first_frames if isinstance(f, LLMMessagesUpdateFrame)] self.assertEqual(len(update_frames), 1) @@ -1045,8 +1060,7 @@ async def test_role_message_inheritance(self): await flow_manager.set_node_from_config(second_node) # Verify AppendFrame for second node - first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call - second_frames = first_call[0][0] + second_frames = self._get_all_queued_frames() append_frames = [f for f in second_frames if isinstance(f, LLMMessagesAppendFrame)] self.assertEqual(len(append_frames), 1) @@ -1069,8 +1083,7 @@ async def test_frame_type_selection(self): # First node should use UpdateFrame await flow_manager.set_node_from_config(test_node) - first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call - first_frames = first_call[0][0] + first_frames = self._get_all_queued_frames() self.assertTrue( any(isinstance(f, LLMMessagesUpdateFrame) for f in first_frames), "First node should use UpdateFrame", @@ -1085,8 +1098,7 @@ async def test_frame_type_selection(self): # Second node should use AppendFrame await flow_manager.set_node_from_config(test_node) - first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call - second_frames = first_call[0][0] + second_frames = self._get_all_queued_frames() self.assertTrue( any(isinstance(f, LLMMessagesAppendFrame) for f in second_frames), "Subsequent nodes should use AppendFrame", @@ -1383,8 +1395,8 @@ async def test_node_without_functions(self): # Set node and verify it works without error await flow_manager.set_node_from_config(node_config) - # Verify current_functions is empty set - self.assertEqual(flow_manager._current_functions, set()) + # Verify current_functions is empty dict + self.assertEqual(flow_manager._current_functions, {}) # Verify LLM tools were still set (with empty or placeholder functions) tools_frames_call = [ @@ -1412,8 +1424,8 @@ async def test_node_with_empty_functions(self): # Set node and verify it works without error await flow_manager.set_node_from_config(node_config) - # Verify current_functions is empty set - self.assertEqual(flow_manager._current_functions, set()) + # Verify current_functions is empty dict + self.assertEqual(flow_manager._current_functions, {}) # Verify LLM tools were still set (with empty or placeholder functions) tools_frames_call = [