-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathtest_llm.py
More file actions
385 lines (314 loc) · 12.4 KB
/
Copy pathtest_llm.py
File metadata and controls
385 lines (314 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
"""LightningLLM main tests."""
import os
import re
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from litai import LLM
def test_initialization_with_config_file(monkeypatch):
    """Test LightningLLM config.

    Constructing ``LLM`` with an explicit API key and user id must expose them
    through the ``LIGHTNING_API_KEY`` and ``LIGHTNING_USER_ID`` environment
    variables.
    """
    mock_llm_instance = MagicMock()
    monkeypatch.setattr("litai.client.SDKLLM", mock_llm_instance)
    # Snapshot os.environ: whatever the constructor writes is rolled back when
    # the block exits, so the credentials cannot leak into other tests.
    with patch.dict(os.environ):
        LLM(model="openai/gpt-4", lightning_api_key="my-key", lightning_user_id="my-user-id")
        assert os.getenv("LIGHTNING_API_KEY") == "my-key"
        assert os.getenv("LIGHTNING_USER_ID") == "my-user-id"
@patch("litai.client.SDKLLM")
def test_invalid_model(mock_llm_class):
    """Check that a model-loading failure surfaces from ``_wait_for_model``.

    The patched ``SDKLLM`` raises a ``ValueError`` whenever it is called; the
    same error type (matching "not found") must come out of
    ``LLM._wait_for_model()``.
    """
    dummy_model_name = "dummy-model"
    load_failure = ValueError(
        f"Failed to load model '{dummy_model_name}': Model '{dummy_model_name}' not found. "
    )
    mock_llm_class.side_effect = load_failure

    client = LLM(model=dummy_model_name)

    with pytest.raises(ValueError, match="not found"):
        client._wait_for_model()
def test_default_model(monkeypatch):
    """Check the defaults applied when ``LLM`` is built without a model.

    A ``UserWarning`` announcing the default must be emitted, the model must be
    ``openai/gpt-4o`` and no fallback models may be configured.
    """
    monkeypatch.setattr("litai.client.SDKLLM", MagicMock())
    # The message contains regex metacharacters ("."), hence the escape.
    expected_warning = re.escape("No model was provided, defaulting to openai/gpt-4o")

    with pytest.warns(UserWarning, match=expected_warning):
        client = LLM()
        assert len(client.fallback_models) == 0
        assert client.model == "openai/gpt-4o"
@patch("litai.client.SDKLLM")
def test_cloudy_models_preload(mock_sdkllm):
    """Check that creating an ``LLM`` instantiates ``SDKLLM`` for every cloudy model.

    Each listed model is expected to be constructed twice, once with
    ``enable_async=True`` and once with ``enable_async=False``.
    """
    from litai.client import LLM as LLMCLIENT

    # Only the number of entries matters for the expected call count.
    preloaded_models = (
        "openai/gpt-4o",
        "openai/gpt-4",
        "openai/o3-mini",
        "anthropic/claude-3-5-sonnet-20240620",
        "google/gemini-2.5-pro",
        "google/gemini-2.5-flash",
    )

    # Start from an empty cache so every model has to be constructed again.
    LLMCLIENT._sdkllm_cache.clear()

    client = LLM()
    client._wait_for_model()

    expected_calls = 2 * len(preloaded_models)  # one sync + one async client per model
    assert mock_sdkllm.call_count == expected_calls, (
        f"Expected {expected_calls} calls to SDKLLM, but got {mock_sdkllm.call_count}"
    )

    async_flags = {recorded.kwargs["enable_async"] for recorded in mock_sdkllm.call_args_list}
    assert async_flags == {True, False}
@patch("litai.client.SDKLLM")
def test_llm_chat(mock_llm_class):
    """Check that ``LLM.chat`` forwards its arguments to the SDK client.

    Also verifies that extra keyword arguments are passed through untouched
    and that ``reset_conversation`` is delegated to the SDK client.
    """
    from litai.client import LLM as LLMCLIENT

    LLMCLIENT._sdkllm_cache.clear()

    sdk_stub = MagicMock()
    sdk_stub.chat.return_value = "Hello! I am a helpful assistant."
    mock_llm_class.return_value = sdk_stub

    client = LLM(model="openai/gpt-4")
    reply = client.chat(
        "Hello, who are you?",
        system_prompt="You are a helpful assistant.",
        metadata={"user_api": "123456"},
        my_kwarg="test-kwarg",
    )

    assert isinstance(reply, str)
    assert "helpful" in reply.lower()

    # Full keyword set the SDK client is expected to receive, defaults included.
    expected_kwargs = {
        "prompt": "Hello, who are you?",
        "system_prompt": "You are a helpful assistant.",
        "max_completion_tokens": 500,
        "images": None,
        "conversation": None,
        "metadata": {"user_api": "123456"},
        "stream": False,
        "full_response": False,
        "my_kwarg": "test-kwarg",
    }
    sdk_stub.chat.assert_called_once_with(**expected_kwargs)

    forwarded = sdk_stub.chat.call_args.kwargs
    assert forwarded.get("my_kwarg") == "test-kwarg"

    client.reset_conversation("test")
    sdk_stub.reset_conversation.assert_called_once()
def test_model_override(monkeypatch):
    """Check that a per-call ``model=`` override is the only model used.

    With a default and a fallback model configured, ``chat(model=...)`` must
    send the request to the override model and leave the other two untouched.
    """

    def make_sync_model(model_name):
        # ``name`` has to be assigned after construction: MagicMock(name=...)
        # would name the mock itself rather than set the attribute.
        stub = MagicMock()
        stub.name = model_name
        stub.enable_async = False
        return stub

    mock_llm = make_sync_model("default-model")
    mock_fallback_model = make_sync_model("fallback-model")
    mock_override = make_sync_model("override-model")
    mock_override.chat.return_value = "Override response"

    models_by_name = {
        "default-model": mock_llm,
        "fallback-model": mock_fallback_model,
        "override-model": mock_override,
    }

    def mock_llm_constructor(name, teamspace="default-teamspace", **kwargs):
        if name not in models_by_name:
            raise ValueError(f"Unknown model: {name}")
        return models_by_name[name]

    monkeypatch.setattr("litai.client.SDKLLM", mock_llm_constructor)

    client = LLM(
        model="default-model",
        fallback_models=["fallback-model"],
        max_retries=3,
        full_response=True,
    )
    client.chat(prompt="Hello", model="override-model")

    assert mock_override.chat.call_count == 1
    assert mock_fallback_model.chat.call_count == 0
    assert mock_llm.chat.call_count == 0

    expected_kwargs = {
        "prompt": "Hello",
        "system_prompt": None,
        "max_completion_tokens": 500,
        "images": None,
        "conversation": None,
        "metadata": None,
        "stream": False,
        "full_response": True,
    }
    mock_override.chat.assert_called_once_with(**expected_kwargs)
def test_fallback_models(monkeypatch):
    """Check the retry and fallback sequence when the main model keeps failing.

    The main model raises on every call; the fallback raises twice and then
    answers. With ``max_retries=3`` each model is expected to be called three
    times.
    """
    from litai.client import LLM as LLMCLIENT

    LLMCLIENT._sdkllm_cache.clear()

    mock_main_model = MagicMock()
    mock_main_model.name = "main-model"
    mock_main_model.chat.side_effect = Exception("Primary model error")

    mock_fallback_model = MagicMock()
    mock_fallback_model.name = "fallback-model"
    # A list side_effect is consumed one item per call: two failures, then success.
    mock_fallback_model.chat.side_effect = [
        Exception("Fallback error 1"),
        Exception("Fallback error 2"),
        "Fallback response",
    ]

    models_by_name = {
        "main-model": mock_main_model,
        "fallback-model": mock_fallback_model,
    }

    def mock_llm_constructor(name, teamspace="default-teamspace", **kwargs):
        if name not in models_by_name:
            raise ValueError(f"Unknown model: {name}")
        return models_by_name[name]

    monkeypatch.setattr("litai.client.SDKLLM", mock_llm_constructor)

    client = LLM(
        model="main-model",
        fallback_models=["fallback-model"],
        max_retries=3,
    )
    client.chat(prompt="Hello")

    assert mock_main_model.chat.call_count == 3
    assert mock_fallback_model.chat.call_count == 3

    expected_kwargs = {
        "prompt": "Hello",
        "system_prompt": None,
        "max_completion_tokens": 500,
        "images": None,
        "conversation": None,
        "metadata": None,
        "stream": False,
        "full_response": False,
    }
    mock_fallback_model.chat.assert_called_with(**expected_kwargs)
@pytest.mark.asyncio
async def test_llm_async_chat(monkeypatch):
    """Check that ``chat`` can be awaited when ``enable_async=True``."""
    sdk_stub = MagicMock()
    sdk_stub.name = "mock-model"
    # AsyncMock makes ``chat`` return an awaitable, like an async SDK client.
    sdk_stub.chat = AsyncMock(return_value="Hello, async world!")

    def fake_sdkllm(*args, **kwargs):
        return sdk_stub

    monkeypatch.setattr("litai.client.SDKLLM", fake_sdkllm)

    client = LLM(model="mock-model", enable_async=True)
    reply = await client.chat("Hi there", conversation="async-test")

    assert reply == "Hello, async world!"
    sdk_stub.chat.assert_called_once()
def test_get_history(monkeypatch, capsys):
    """Check both output modes of ``LLM.get_history``.

    By default the conversation is printed to stdout and ``None`` is returned;
    with ``raw=True`` the message list is returned instead.
    """
    sdk_stub = MagicMock()
    sdk_stub.name = "mock-model"
    sdk_stub.get_history = MagicMock(
        return_value=[
            {"role": "user", "content": "Hello, world!", "model": "mock-model"},
            {"role": "assistant", "content": "I am a mock model!", "model": "mock-model"},
        ]
    )

    def fake_sdkllm(*args, **kwargs):
        return sdk_stub

    monkeypatch.setattr("litai.client.SDKLLM", fake_sdkllm)
    client = LLM(model="mock-model")

    # Default mode: pretty-printed to stdout, nothing returned.
    printed_result = client.get_history("async-test")
    assert printed_result is None

    stdout_text = capsys.readouterr().out
    expected_fragments = (
        "🧠 Conversation: 'async-test'",
        "🟦 You",
        "🟨 mock-model",
        "Hello, world!",
        "I am a mock model!",
        "--- End of conversation ---",
    )
    for fragment in expected_fragments:
        assert fragment in stdout_text

    # Raw mode: the messages come back as data instead of being printed.
    raw_result = client.get_history("async-test", raw=True)
    assert raw_result == [
        {"role": "user", "content": "Hello, world!", "model": "mock-model"},
        {"role": "assistant", "content": "I am a mock model!", "model": "mock-model"},
    ]
def test_authenticate_method(monkeypatch):
    """Test when ``LLM`` falls back to ``login.Auth().authenticate()``.

    Case 1: both credentials are passed explicitly, so authentication must not
    be triggered and the credentials must appear in the environment.
    Case 2: no credentials are passed and none are in the environment, so
    authentication must be triggered exactly once.
    """
    # Mock the login.Auth class
    mock_auth = MagicMock()
    mock_auth.api_key = "test-api-key"
    mock_auth.user_id = "test-user-id"

    def mock_auth_constructor():
        return mock_auth

    monkeypatch.setattr("litai.client.login.Auth", mock_auth_constructor)

    # Snapshot os.environ: this test both sets and removes the credential
    # variables, and the caller's real environment must be restored afterwards.
    with patch.dict(os.environ):
        # Test case 1: Both api_key and user_id provided
        LLM(model="openai/gpt-4", lightning_api_key="my-key", lightning_user_id="my-user-id")

        # Verify that the authentication was not called
        mock_auth.authenticate.assert_not_called()

        # Verify that environment variables were set
        assert os.getenv("LIGHTNING_API_KEY") == "my-key"
        assert os.getenv("LIGHTNING_USER_ID") == "my-user-id"

        # Test case 2: Neither api_key nor user_id provided
        mock_auth.reset_mock()
        os.environ.pop("LIGHTNING_API_KEY", None)
        os.environ.pop("LIGHTNING_USER_ID", None)

        LLM(model="openai/gpt-4")

        # Verify that authentication was called
        mock_auth.authenticate.assert_called_once()
@patch("litai.client.SDKLLM")
def test_llm_if_method(mock_llm_class):
    """Check that ``LLM.if_`` turns a yes/no model reply into a boolean.

    Also verifies the prompt sent to the SDK client and that the reply is
    accepted regardless of capitalization and surrounding whitespace.
    """
    from litai.client import LLM as LLMCLIENT

    LLMCLIENT._sdkllm_cache.clear()

    sdk_stub = MagicMock()
    sdk_stub.chat.return_value = "yes"
    mock_llm_class.return_value = sdk_stub

    # Keyword arguments expected on every chat call, apart from the prompt.
    shared_kwargs = {
        "system_prompt": None,
        "max_completion_tokens": 500,
        "images": None,
        "conversation": None,
        "metadata": None,
        "stream": False,
        "full_response": False,
    }

    client = LLM(model="openai/gpt-4")

    # Model answers "yes" -> True.
    assert client.if_("is it true?") is True
    sdk_stub.chat.assert_called_with(
        prompt="is it true?\n\nreply 'yes' if the answer is yes, otherwise reply 'no'.",
        **shared_kwargs,
    )

    # Model answers "no" -> False.
    sdk_stub.chat.return_value = "no"
    assert client.if_("is it false?") is False
    sdk_stub.chat.assert_called_with(
        prompt="is it false?\n\nreply 'yes' if the answer is yes, otherwise reply 'no'.",
        **shared_kwargs,
    )

    # Capitalization and padding in the reply must not matter.
    sdk_stub.chat.return_value = " Yes "
    assert client.if_("is it a positive response?") is True
@patch("litai.client.SDKLLM")
def test_llm_classify_method(mock_llm_class):
    """Check that ``LLM.classify`` returns the model's class and builds the right prompt.

    Covers two-class inputs with each possible answer and a three-class input.
    """
    from litai.client import LLM as LLMCLIENT

    LLMCLIENT._sdkllm_cache.clear()

    sdk_stub = MagicMock()
    mock_llm_class.return_value = sdk_stub
    client = LLM(model="openai/gpt-4")

    # Keyword arguments expected on every chat call, apart from the prompt.
    shared_kwargs = {
        "system_prompt": None,
        "max_completion_tokens": 500,
        "images": None,
        "conversation": None,
        "metadata": None,
        "stream": False,
        "full_response": False,
    }

    # (model reply, positional arguments for classify, prompt expected at the SDK)
    cases = [
        (
            "positive",
            ("this movie was great!", "positive", "negative"),
            "this movie was great!\n\nclassify the input as one of these: positive, negative. reply with only the class.",
        ),
        (
            "negative",
            ("this movie was awful.", "positive", "negative"),
            "this movie was awful.\n\nclassify the input as one of these: positive, negative. reply with only the class.",
        ),
        (
            "neutral",
            ("it was okay.", "positive", "negative", "neutral"),
            "it was okay.\n\nclassify the input as one of these: positive, negative, neutral. reply with only the class.",
        ),
    ]

    for model_reply, classify_args, expected_prompt in cases:
        sdk_stub.chat.return_value = model_reply
        result = client.classify(*classify_args)
        assert result == model_reply
        sdk_stub.chat.assert_called_with(prompt=expected_prompt, **shared_kwargs)