Fix Gemini Model Integration Issues (#2803) #2804

Open · wants to merge 2 commits into main
72 changes: 68 additions & 4 deletions src/crewai/llm.py
@@ -246,6 +246,9 @@ class AccumulatedToolArgs(BaseModel):


class LLM(BaseLLM):
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
GEMINI_IDENTIFIERS = ("gemini", "gemma-")

def __init__(
self,
model: str,
@@ -319,8 +322,55 @@ def _is_anthropic_model(self, model: str) -> bool:
Returns:
bool: True if the model is from Anthropic, False otherwise.
"""
if not isinstance(model, str):
    return False
return any(prefix in model.lower() for prefix in self.ANTHROPIC_PREFIXES)

def _is_gemini_model(self, model: str) -> bool:
"""Determine if the model is from Google Gemini provider.

Args:
model: The model identifier string.

Returns:
bool: True if the model is from Gemini, False otherwise.
"""
if not isinstance(model, str):
return False
return any(identifier in model.lower() for identifier in self.GEMINI_IDENTIFIERS)

def _normalize_gemini_model(self, model: str) -> str:
"""Normalize Gemini model name to the format expected by LiteLLM.

Handles formats like "models/gemini-pro" or "gemini-pro" and converts
them to "gemini/gemini-pro" format.

Args:
model: The model identifier string.

Returns:
str: Normalized model name.

Raises:
ValueError: If model is not a string or is empty.
"""
if not isinstance(model, str):
raise ValueError(f"Model must be a string, got {type(model)}")

if not model.strip():
raise ValueError("Model name cannot be empty")

if model.startswith("gemini/"):
return model

if model.startswith("models/"):
model_name = model.split("/", 1)[1]
return f"gemini/{model_name}"

if self._is_gemini_model(model) and "/" not in model:
return f"gemini/{model}"

return model
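
For reference, the branches above condense to a small set of rewrite rules. A standalone sketch (illustration only, not the crewai class itself) that mirrors them:

```python
def normalize(model: str) -> str:
    if model.startswith("gemini/"):
        return model  # already in LiteLLM's provider/model form
    if model.startswith("models/"):
        # e.g. "models/gemini-pro", the format reported in issue #2803
        return f"gemini/{model.split('/', 1)[1]}"
    if ("gemini" in model.lower() or "gemma-" in model.lower()) and "/" not in model:
        return f"gemini/{model}"
    return model  # non-Gemini names pass through untouched

assert normalize("models/gemini-pro") == "gemini/gemini-pro"
assert normalize("gemini-1.5-flash") == "gemini/gemini-1.5-flash"
assert normalize("gpt-4") == "gpt-4"
```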

def _prepare_completion_params(
self,
@@ -343,9 +393,23 @@ def _prepare_completion_params(
messages = [{"role": "user", "content": messages}]
formatted_messages = self._format_messages_for_provider(messages)

# --- 2) Normalize the Gemini model name and map its API key if needed
model = self.model
if self._is_gemini_model(model):
try:
model = self._normalize_gemini_model(model)
logging.info(f"Normalized Gemini model name from '{self.model}' to '{model}'")

# --- 2.1) Map GOOGLE_API_KEY to GEMINI_API_KEY if needed
if not os.environ.get("GEMINI_API_KEY") and os.environ.get("GOOGLE_API_KEY"):
os.environ["GEMINI_API_KEY"] = os.environ["GOOGLE_API_KEY"]
logging.info("Mapped GOOGLE_API_KEY to GEMINI_API_KEY for Gemini model")
except ValueError as e:
logging.error(f"Error normalizing Gemini model: {str(e)}")
model = self.model

# --- 3) Prepare the parameters for the completion call
params = {
"model": self.model,
"model": model,
"messages": formatted_messages,
"timeout": self.timeout,
"temperature": self.temperature,
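
Taken together, the two changes mean a Gemini call can be driven entirely by GOOGLE_API_KEY. A minimal sketch of the intended end-to-end effect, mocking LiteLLM the same way the tests below do (the crewai.llm import path is assumed from the file header above):

```python
import os
from unittest.mock import MagicMock, patch

from crewai.llm import LLM  # import path assumed from src/crewai/llm.py

os.environ["GOOGLE_API_KEY"] = "fake-key-for-illustration"
os.environ.pop("GEMINI_API_KEY", None)

llm = LLM(model="models/gemini-pro")  # the format reported in issue #2803

with patch("litellm.completion") as mock_completion:
    mock_completion.return_value = MagicMock(
        choices=[MagicMock(message=MagicMock(content="Paris"))]
    )
    llm.call("What is the capital of France?")

# The call above should have normalized the model name and bridged the key:
_, kwargs = mock_completion.call_args
assert kwargs["model"] == "gemini/gemini-pro"
assert os.environ["GEMINI_API_KEY"] == "fake-key-for-illustration"
```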
107 changes: 107 additions & 0 deletions tests/llm_test.py
@@ -220,6 +220,37 @@ def test_get_custom_llm_provider_gemini():
assert llm._get_custom_llm_provider() == "gemini"


def test_is_gemini_model():
"""Test the _is_gemini_model method with various model names."""
llm = LLM(model="gpt-4") # Model doesn't matter for this test

assert llm._is_gemini_model("gemini-pro") == True
assert llm._is_gemini_model("gemini/gemini-1.5-pro") == True
assert llm._is_gemini_model("models/gemini-pro") == True
assert llm._is_gemini_model("gemma-7b") == True

# Should not identify as Gemini models
assert llm._is_gemini_model("gpt-4") == False
assert llm._is_gemini_model("claude-3") == False
assert llm._is_gemini_model("mistral-7b") == False


def test_normalize_gemini_model():
"""Test the _normalize_gemini_model method with various model formats."""
llm = LLM(model="gpt-4") # Model doesn't matter for this test

assert llm._normalize_gemini_model("gemini/gemini-1.5-pro") == "gemini/gemini-1.5-pro"

assert llm._normalize_gemini_model("models/gemini-pro") == "gemini/gemini-pro"
assert llm._normalize_gemini_model("models/gemini-1.5-flash") == "gemini/gemini-1.5-flash"

assert llm._normalize_gemini_model("gemini-pro") == "gemini/gemini-pro"
assert llm._normalize_gemini_model("gemini-1.5-flash") == "gemini/gemini-1.5-flash"

assert llm._normalize_gemini_model("gpt-4") == "gpt-4"
assert llm._normalize_gemini_model("claude-3") == "claude-3"


def test_get_custom_llm_provider_openai():
llm = LLM(model="gpt-4")
assert llm._get_custom_llm_provider() is None
@@ -274,6 +305,82 @@ def test_gemini_models(model):
assert "Paris" in result


@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
@pytest.mark.parametrize(
"model",
[
"models/gemini-pro", # Format from issue #2803
"gemini-pro", # Format without provider prefix
],
)
def test_gemini_model_normalization(model):
"""Test that different Gemini model formats are normalized correctly."""
llm = LLM(model=model)

with patch("litellm.completion") as mock_completion:
# Create mocks for response structure
mock_message = MagicMock()
mock_message.content = "Paris"
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_response = MagicMock()
mock_response.choices = [mock_choice]

# Set up the mocked completion to return the mock response
mock_completion.return_value = mock_response

llm.call("What is the capital of France?")

# Check that the model was normalized correctly in the call to litellm
args, kwargs = mock_completion.call_args
assert kwargs["model"].startswith("gemini/")
assert "gemini-pro" in kwargs["model"]


@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
def test_gemini_api_key_mapping():
"""Test that GOOGLE_API_KEY is mapped to GEMINI_API_KEY for Gemini models."""
original_google_api_key = os.environ.get("GOOGLE_API_KEY")
original_gemini_api_key = os.environ.get("GEMINI_API_KEY")

try:
# Set up test environment
test_api_key = "test_google_api_key"
os.environ["GOOGLE_API_KEY"] = test_api_key
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]

llm = LLM(model="gemini-pro")

with patch("litellm.completion") as mock_completion:
# Create mocks for response structure
mock_message = MagicMock()
mock_message.content = "Paris"
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_response = MagicMock()
mock_response.choices = [mock_choice]

# Set up the mocked completion to return the mock response
mock_completion.return_value = mock_response

llm.call("What is the capital of France?")

# Check that GEMINI_API_KEY was set from GOOGLE_API_KEY
assert os.environ.get("GEMINI_API_KEY") == test_api_key

finally:
if original_google_api_key is not None:
os.environ["GOOGLE_API_KEY"] = original_google_api_key
else:
os.environ.pop("GOOGLE_API_KEY", None)

if original_gemini_api_key is not None:
os.environ["GEMINI_API_KEY"] = original_gemini_api_key
else:
os.environ.pop("GEMINI_API_KEY", None)


@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
@pytest.mark.parametrize(
"model",