33 changes: 21 additions & 12 deletions cover_agent/AICaller.py
@@ -13,6 +13,12 @@
wait_fixed,
)

from cover_agent import (
NO_SUPPORT_TEMPERATURE_MODELS,
USER_MESSAGE_ONLY_MODELS,
NO_SUPPORT_STREAMING_MODELS,
)

MODEL_RETRIES = 3


@@ -47,6 +53,10 @@ def __init__(
self.enable_retry = enable_retry
self.max_tokens = max_tokens

self.user_message_only_models = USER_MESSAGE_ONLY_MODELS
self.no_support_temperature_models = NO_SUPPORT_TEMPERATURE_MODELS
self.no_support_streaming_models = NO_SUPPORT_STREAMING_MODELS

@conditional_retry # You can access self.enable_retry here
def call_model(self, prompt: dict, stream=True):
"""
@@ -66,14 +76,10 @@ def call_model(self, prompt: dict, stream=True):
if prompt["system"] == "":
messages = [{"role": "user", "content": prompt["user"]}]
else:
if self.model in ["o1-preview", "o1-mini"]:
# o1 doesn't accept a system message so we add it to the prompt
messages = [
{
"role": "user",
"content": prompt["system"] + "\n" + prompt["user"],
},
]
if self.model in self.user_message_only_models:
# Combine system and user messages for models that only support user messages
combined_content = (prompt["system"] + "\n" + prompt["user"]).strip()
messages = [{"role": "user", "content": combined_content}]
else:
messages = [
{"role": "system", "content": prompt["system"]},
@@ -89,11 +95,14 @@
"max_tokens": self.max_tokens,
}

# Remove temperature for models that don't support it
if self.model in self.no_support_temperature_models:
completion_params.pop("temperature", None)

# Model-specific adjustments
if self.model in ["o1-preview", "o1-mini", "o1", "o3-mini"]:
stream = False # o1 doesn't support streaming
completion_params["temperature"] = 1
completion_params["stream"] = False # o1 doesn't support streaming
if self.model in self.no_support_streaming_models:
stream = False
completion_params["stream"] = False
completion_params["max_completion_tokens"] = 2 * self.max_tokens
# completion_params["reasoning_effort"] = "high"
completion_params.pop("max_tokens", None) # Remove 'max_tokens' if present
23 changes: 23 additions & 0 deletions cover_agent/__init__.py
@@ -0,0 +1,23 @@
USER_MESSAGE_ONLY_MODELS = [
"deepseek/deepseek-reasoner",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview"
]

NO_SUPPORT_TEMPERATURE_MODELS = [
"deepseek/deepseek-reasoner",
"o1-mini",
"o1-mini-2024-09-12",
"o1",
"o1-2024-12-17",
"o3-mini",
"o3-mini-2025-01-31",
"o1-preview"
]

NO_SUPPORT_STREAMING_MODELS = [
"deepseek/deepseek-reasoner",
"o1",
"o1-2024-12-17",
]
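
For readers scanning the diff, the sketch below condenses how these three lists are consumed by the new branches in AICaller.call_model. It is a hypothetical, self-contained approximation, not the actual method: the temperature default shown is assumed (it is not visible in the hunks above), and the real code keeps these checks inline rather than in a helper.

# Hypothetical sketch only; mirrors the branches added in cover_agent/AICaller.py.
from cover_agent import (
    USER_MESSAGE_ONLY_MODELS,
    NO_SUPPORT_TEMPERATURE_MODELS,
    NO_SUPPORT_STREAMING_MODELS,
)


def build_request(model: str, prompt: dict, max_tokens: int, stream: bool) -> dict:
    # Models that reject system messages get a single combined user message.
    if prompt["system"] and model in USER_MESSAGE_ONLY_MODELS:
        content = (prompt["system"] + "\n" + prompt["user"]).strip()
        messages = [{"role": "user", "content": content}]
    elif prompt["system"]:
        messages = [
            {"role": "system", "content": prompt["system"]},
            {"role": "user", "content": prompt["user"]},
        ]
    else:
        messages = [{"role": "user", "content": prompt["user"]}]

    params = {
        "model": model,
        "messages": messages,
        "stream": stream,
        "temperature": 0.2,  # assumed default; not part of this diff
        "max_tokens": max_tokens,
    }

    # Remove temperature for models that do not support it.
    if model in NO_SUPPORT_TEMPERATURE_MODELS:
        params.pop("temperature", None)

    # Force non-streaming and swap max_tokens for max_completion_tokens.
    if model in NO_SUPPORT_STREAMING_MODELS:
        params["stream"] = False
        params["max_completion_tokens"] = 2 * max_tokens
        params.pop("max_tokens", None)

    return params


# Example: per the lists above, "o1" loses temperature and streaming, while
# "o1-preview" only loses temperature and has its system prompt folded into
# the user message.
print(build_request("o1", {"system": "Be terse.", "user": "Hi"}, 4096, True))

Supporting another restricted model should then presumably only require appending its litellm identifier to the relevant list(s) in this module, with the tests below exercising each capability branch.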
115 changes: 112 additions & 3 deletions tests/test_AICaller.py
@@ -143,30 +143,108 @@ def test_call_model_missing_keys(self, ai_caller):
)

@patch("cover_agent.AICaller.litellm.completion")
def test_call_model_o1_preview(self, mock_completion, ai_caller):
def test_call_model_user_message_only(self, mock_completion, ai_caller):
"""
Test the call_model method with the 'o1-preview' model.
Test the call_model method with a model that only supports user messages.
"""
ai_caller.model = "o1-preview"
# Set the model to one that only supports user messages
ai_caller.model = "o1-preview" # This is in USER_MESSAGE_ONLY_MODELS
prompt = {"system": "System instruction", "user": "User query"}

# Mock the response
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="response"))]
mock_response.usage = Mock(prompt_tokens=2, completion_tokens=10)
mock_completion.return_value = mock_response

# Call the method with stream=False for simplicity
response, prompt_tokens, response_tokens = ai_caller.call_model(
prompt, stream=False
)

# Verify the response
assert response == "response"
assert prompt_tokens == 2
assert response_tokens == 10

# Check that litellm.completion was called with the correct arguments
# Most importantly, verify that system and user messages were combined into a single user message
mock_completion.assert_called_once()
call_args = mock_completion.call_args[1]
assert len(call_args["messages"]) == 1
assert call_args["messages"][0]["role"] == "user"
assert call_args["messages"][0]["content"] == "System instruction\nUser query"

@patch("cover_agent.AICaller.litellm.completion")
def test_call_model_no_temperature_support(self, mock_completion, ai_caller):
"""
Test the call_model method with a model that doesn't support temperature.
"""
# Set the model to one that doesn't support temperature
ai_caller.model = "o1-preview" # This is in NO_SUPPORT_TEMPERATURE_MODELS
prompt = {"system": "System message", "user": "Hello, world!"}

# Mock the response
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="response"))]
mock_response.usage = Mock(prompt_tokens=2, completion_tokens=10)
mock_completion.return_value = mock_response

# Call the method
response, prompt_tokens, response_tokens = ai_caller.call_model(
prompt, stream=False
)

# Verify the response
assert response == "response"
assert prompt_tokens == 2
assert response_tokens == 10

# Check that litellm.completion was called without the temperature parameter
mock_completion.assert_called_once()
call_args = mock_completion.call_args[1]
assert "temperature" not in call_args

@patch("cover_agent.AICaller.litellm.completion")
def test_call_model_no_streaming_support(self, mock_completion, ai_caller):
"""
Test the call_model method with a model that doesn't support streaming.
"""
# Set the model to one that doesn't support streaming
ai_caller.model = "o1" # This is in NO_SUPPORT_STREAMING_MODELS
prompt = {"system": "System message", "user": "Hello, world!"}

# Mock the response
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="response"))]
mock_response.usage = Mock(prompt_tokens=2, completion_tokens=10)
mock_completion.return_value = mock_response

# Call the method explicitly requesting streaming, which should be ignored
response, prompt_tokens, response_tokens = ai_caller.call_model(
prompt, stream=True
)

# Verify the response
assert response == "response"
assert prompt_tokens == 2
assert response_tokens == 10

# Check that litellm.completion was called with stream=False
mock_completion.assert_called_once()
call_args = mock_completion.call_args[1]
assert call_args["stream"] == False
# Check if max_tokens was removed and max_completion_tokens was added
assert "max_tokens" not in call_args
assert call_args["max_completion_tokens"] == 2 * ai_caller.max_tokens

@patch("cover_agent.AICaller.litellm.completion")
def test_call_model_streaming_response(self, mock_completion, ai_caller):
"""
Test the call_model method with a streaming response.
"""
# Make sure we're using a model that supports streaming
ai_caller.model = "gpt-4" # Not in NO_SUPPORT_STREAMING_MODELS
prompt = {"system": "", "user": "Hello, world!"}
# Mock the response to be an iterable of chunks
mock_chunk = Mock()
@@ -183,6 +261,37 @@ def test_call_model_streaming_response(self, mock_completion, ai_caller):
assert response == "response"
assert prompt_tokens == 2

@patch("cover_agent.AICaller.litellm.completion")
def test_call_model_empty_system_prompt(self, mock_completion, ai_caller):
"""
Test the call_model method with an empty system prompt.
"""
# Should work the same for any model type
prompt = {"system": "", "user": "Hello, world!"}

# Mock the response
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="response"))]
mock_response.usage = Mock(prompt_tokens=2, completion_tokens=10)
mock_completion.return_value = mock_response

# Call the method
response, prompt_tokens, response_tokens = ai_caller.call_model(
prompt, stream=False
)

# Verify the response
assert response == "response"
assert prompt_tokens == 2
assert response_tokens == 10

# Check that litellm.completion was called with only a user message
mock_completion.assert_called_once()
call_args = mock_completion.call_args[1]
assert len(call_args["messages"]) == 1
assert call_args["messages"][0]["role"] == "user"
assert call_args["messages"][0]["content"] == "Hello, world!"

@patch("cover_agent.AICaller.litellm.completion")
@patch.dict(os.environ, {"WANDB_API_KEY": "test_key"})
@patch("cover_agent.AICaller.Trace.log")