diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 5d6a0ccf55..283f422b48 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -92,6 +92,8 @@ def suppress_warnings(): class LLM: + MODELS_WITHOUT_STOP_SUPPORT = ["o3", "o3-mini", "o4-mini"] + def __init__( self, model: str, @@ -155,7 +157,7 @@ def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str "temperature": self.temperature, "top_p": self.top_p, "n": self.n, - "stop": self.stop, + "stop": self.stop if self.supports_stop_words() else None, "max_tokens": self.max_tokens or self.max_completion_tokens, "presence_penalty": self.presence_penalty, "frequency_penalty": self.frequency_penalty, @@ -193,6 +195,19 @@ def supports_function_calling(self) -> bool: return False def supports_stop_words(self) -> bool: + """ + Determines whether the current model supports the 'stop' parameter. + + This method checks if the model is in the list of models known not to support + stop words, and if not, it queries the litellm library to determine if the + model supports the 'stop' parameter. + + Returns: + bool: True if the model supports stop words, False otherwise. 
if any(self.model.split("/")[-1].startswith(model) for model in self.MODELS_WITHOUT_STOP_SUPPORT):  # strip provider prefix, e.g. "openai/o3"
response == "Hello, World!"