Skip to content

Commit 93755ba

Browse files
itomekTomasz Iniewicz
andauthored
fix(llm): correct LiteLLM embed() model override + finish amd#1593 review items (amd#1626)
Closes amd#1625 ## Why this matters After amd#1593 added the LiteLLM provider, `LiteLLMProvider.embed()` crashed with `TypeError: embedding() got multiple values for keyword argument 'model'` the moment a caller overrode the model — `model` was passed both explicitly and inside the spread `**call_kwargs`, and that path had zero test coverage. Now `embed()` works with and without a model override, and the rest of the amd#1593 review is closed out: the provider is documented, the `create_client` docstring lists it, and the redundant `drop_params` handling is collapsed to a single per-call default callers can override. ## Test plan - [x] `pytest tests/unit/test_litellm_provider.py` → 12 passed (incl. 2 new `embed()` tests; the override test reproduced the `TypeError` before the fix) - [x] `util/lint.py --black --isort --flake8` → all PASS - [ ] CI green after maintainer approves the workflow run --------- Co-authored-by: Tomasz Iniewicz <heaters-nays0p@icloud.com>
1 parent fb76ccd commit 93755ba

4 files changed

Lines changed: 47 additions & 14 deletions

File tree

docs/sdk/sdks/llm.mdx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ llm_claude = create_client(use_claude=True)
6666
# OpenAI API
6767
llm_openai = create_client(use_openai=True)
6868

69+
# LiteLLM gateway — 100+ providers (Bedrock, Vertex AI, Groq, Azure, ...)
70+
# Requires: pip install 'gaia[litellm]'
71+
llm_litellm = create_client("litellm", model="anthropic/claude-sonnet-4-6")
72+
6973
# Use same interface
7074
response = llm_claude.generate("Explain Python decorators")
7175
```

src/gaia/llm/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def create_client(
2424
Create an LLM client, auto-detecting provider from parameters.
2525
2626
Args:
27-
provider: Explicit provider name ("lemonade", "openai", or "claude").
27+
provider: Explicit provider name ("lemonade", "openai", "claude", or "litellm").
2828
If not specified, auto-detected from use_claude/use_openai flags.
2929
use_claude: If True, use Claude provider (ignored if provider is specified)
3030
use_openai: If True, use OpenAI provider (ignored if provider is specified)

src/gaia/llm/providers/litellm.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,19 @@ def __init__(
1717
system_prompt: Optional[str] = None,
1818
**kwargs,
1919
):
20-
import litellm
21-
20+
try:
21+
import litellm
22+
except ImportError as e:
23+
raise ImportError(
24+
"litellm is not installed. Install it with: pip install 'gaia[litellm]'"
25+
) from e
26+
27+
self._litellm = litellm
2228
self._model = model
2329
self._system_prompt = system_prompt
2430
self._api_key = api_key
2531
self._extra_kwargs = kwargs
2632

27-
litellm.drop_params = True
28-
2933
@property
3034
def provider_name(self) -> str:
3135
return "LiteLLM"
@@ -51,39 +55,36 @@ def chat(
5155
stream: bool = False,
5256
**kwargs,
5357
) -> Union[str, Iterator[str]]:
54-
import litellm
55-
5658
if self._system_prompt:
5759
messages = [{"role": "system", "content": self._system_prompt}] + list(
5860
messages
5961
)
6062

6163
call_kwargs = {**self._extra_kwargs, **kwargs}
64+
call_kwargs.setdefault("drop_params", True)
6265
if self._api_key:
6366
call_kwargs["api_key"] = self._api_key
6467

65-
response = litellm.completion(
68+
response = self._litellm.completion(
6669
model=model or self._model,
6770
messages=messages,
6871
stream=stream,
69-
drop_params=True,
7072
**call_kwargs,
7173
)
7274
if stream:
7375
return self._handle_stream(response)
7476
return response.choices[0].message.content
7577

7678
def embed(self, texts: list[str], **kwargs) -> list[list[float]]:
77-
import litellm
78-
79+
model = kwargs.pop("model", self._model)
7980
call_kwargs = {**self._extra_kwargs, **kwargs}
81+
call_kwargs.setdefault("drop_params", True)
8082
if self._api_key:
8183
call_kwargs["api_key"] = self._api_key
8284

83-
response = litellm.embedding(
84-
model=kwargs.pop("model", self._model),
85+
response = self._litellm.embedding(
86+
model=model,
8587
input=texts,
86-
drop_params=True,
8788
**call_kwargs,
8889
)
8990
return [item["embedding"] for item in response.data]

tests/unit/test_litellm_provider.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,34 @@ def test_generate_delegates_to_chat(self):
132132
del sys.modules["litellm"]
133133

134134

135+
class TestLiteLLMEmbed:
136+
def test_embed_default_model(self):
137+
fake = _stub_litellm()
138+
fake.embedding.return_value = MagicMock(data=[{"embedding": [0.1, 0.2, 0.3]}])
139+
from gaia.llm.providers.litellm import LiteLLMProvider
140+
141+
provider = LiteLLMProvider(api_key="sk-test", model="text-embedding-3-small")
142+
result = provider.embed(["hello"])
143+
144+
assert result == [[0.1, 0.2, 0.3]]
145+
assert fake.embedding.call_args.kwargs["model"] == "text-embedding-3-small"
146+
assert fake.embedding.call_args.kwargs["api_key"] == "sk-test"
147+
del sys.modules["litellm"]
148+
149+
def test_embed_uses_override_model(self):
150+
fake = _stub_litellm()
151+
fake.embedding.return_value = MagicMock(data=[{"embedding": [0.0]}])
152+
from gaia.llm.providers.litellm import LiteLLMProvider
153+
154+
provider = LiteLLMProvider(model="text-embedding-3-small")
155+
# Overriding the model must NOT raise TypeError (model previously landed
156+
# in both the explicit arg and **call_kwargs).
157+
provider.embed(["hello"], model="text-embedding-3-large")
158+
159+
assert fake.embedding.call_args.kwargs["model"] == "text-embedding-3-large"
160+
del sys.modules["litellm"]
161+
162+
135163
class TestLiteLLMNotSupported:
136164
def test_vision_raises_not_supported(self):
137165
_stub_litellm()

0 commit comments

Comments
 (0)