forked from waybarrios/vllm-mlx
-
Notifications
You must be signed in to change notification settings - Fork 65
Expand file tree
/
Copy pathtest_embeddings.py
More file actions
289 lines (224 loc) · 9.71 KB
/
test_embeddings.py
File metadata and controls
289 lines (224 loc) · 9.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# SPDX-License-Identifier: Apache-2.0
"""Tests for the OpenAI-compatible Embeddings API."""
import platform
import sys
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import pytest
# Skip all tests if not on Apple Silicon: the mlx / mlx-embeddings stack this
# suite exercises only runs on macOS (darwin) with an arm64 CPU.
pytestmark = pytest.mark.skipif(
    sys.platform != "darwin" or platform.machine() != "arm64",
    reason="Requires Apple Silicon",
)
# =============================================================================
# Unit Tests - Pydantic Models
# =============================================================================
class TestEmbeddingModels:
    """Unit tests for the embedding request/response Pydantic models."""

    def test_embedding_request_single_string(self):
        """A bare string input is accepted and the defaults are applied."""
        from vllm_mlx.api.models import EmbeddingRequest

        request = EmbeddingRequest(model="test-model", input="Hello world")

        assert request.model == "test-model"
        assert request.input == "Hello world"
        # encoding_format defaults to "float" per the OpenAI API.
        assert request.encoding_format == "float"

    def test_embedding_response_serialization(self):
        """EmbeddingResponse dumps to OpenAI-compatible JSON structure."""
        from vllm_mlx.api.models import (
            EmbeddingData,
            EmbeddingResponse,
            EmbeddingUsage,
        )

        vector = EmbeddingData(index=0, embedding=[1.0, 2.0, 3.0])
        usage = EmbeddingUsage(prompt_tokens=5, total_tokens=5)
        response = EmbeddingResponse(
            data=[vector],
            model="text-embedding-3-large",
            usage=usage,
        )

        payload = response.model_dump()

        assert payload["object"] == "list"
        entry = payload["data"][0]
        assert entry["object"] == "embedding"
        assert entry["index"] == 0
        assert entry["embedding"] == [1.0, 2.0, 3.0]
        assert payload["model"] == "text-embedding-3-large"
        assert payload["usage"]["prompt_tokens"] == 5
        assert payload["usage"]["total_tokens"] == 5
# =============================================================================
# Unit Tests - Embedding Engine
# =============================================================================
class TestEmbeddingEngine:
    """Test the EmbeddingEngine wrapper.

    These tests inject mock models/tokenizers directly onto the engine so no
    real weights are loaded; `load`/`is_loaded` are patched where needed.
    """

    @patch("vllm_mlx.embedding.EmbeddingEngine.load")
    @patch(
        "vllm_mlx.embedding.EmbeddingEngine.is_loaded",
        new_callable=lambda: property(lambda self: True),
    )
    def test_embed_calls_model_directly(self, _mock_loaded, mock_load):
        """Test embed tokenizes and calls model directly (bypasses generate)."""
        import numpy as np

        from vllm_mlx.embedding import EmbeddingEngine

        engine = EmbeddingEngine("test-model")
        mock_output = MagicMock()
        mock_output.text_embeds.tolist.return_value = [[0.1, 0.2], [0.3, 0.4]]
        mock_model = MagicMock(return_value=mock_output)
        mock_inner_tokenizer = MagicMock()
        mock_inner_tokenizer.return_value = {
            "input_ids": np.array([[1, 2], [3, 4]]),
            "attention_mask": np.array([[1, 1], [1, 1]]),
        }
        mock_tokenizer = MagicMock()
        mock_tokenizer._tokenizer = mock_inner_tokenizer
        engine._model = mock_model
        engine._tokenizer = mock_tokenizer

        result = engine.embed(["hello", "world"])

        mock_model.assert_called_once()
        assert len(result) == 2
        # Pin BOTH vectors so a regression that reorders or corrupts the
        # second embedding cannot slip past (the original test only checked
        # result[0]).
        assert result[0] == [0.1, 0.2]
        assert result[1] == [0.3, 0.4]

    def test_embed_normalises_single_string(self):
        """Test that a single string input is wrapped into a list."""
        import numpy as np

        from vllm_mlx.embedding import EmbeddingEngine

        engine = EmbeddingEngine("test-model")
        mock_output = MagicMock()
        mock_output.text_embeds.tolist.return_value = [[0.5, 0.6]]
        mock_model = MagicMock(return_value=mock_output)
        mock_inner_tokenizer = MagicMock()
        mock_inner_tokenizer.return_value = {
            "input_ids": np.array([[1, 2]]),
            "attention_mask": np.array([[1, 1]]),
        }
        mock_tokenizer = MagicMock()
        mock_tokenizer._tokenizer = mock_inner_tokenizer
        with patch.object(engine, "_ensure_loaded"):
            engine._model = mock_model
            engine._tokenizer = mock_tokenizer
            result = engine.embed("single text")
        assert len(result) == 1
        # Checking only the length was too weak: also assert the vector
        # content and that the model ran exactly once.
        assert result[0] == [0.5, 0.6]
        mock_model.assert_called_once()

    def test_count_tokens(self):
        """Test token counting for usage reporting."""
        from vllm_mlx.embedding import EmbeddingEngine

        engine = EmbeddingEngine("test-model")
        mock_tokenizer = MagicMock()
        mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        engine._tokenizer = mock_tokenizer
        engine._model = MagicMock()  # mark as loaded
        count = engine.count_tokens(["hello", "world"])
        assert count == 10  # 5 tokens * 2 texts
        # The total implies one encode() call per input text.
        assert mock_tokenizer.encode.call_count == 2
# =============================================================================
# Integration Tests - FastAPI Endpoint
# =============================================================================
class TestEmbeddingsEndpoint:
    """Test the /v1/embeddings endpoint via TestClient.

    Each test temporarily installs a mock engine on the server module; the
    shared `_swap_engine` context manager guarantees the original module
    state is restored even when a test fails, replacing the try/finally
    boilerplate that was previously copy-pasted in every test.
    """

    @pytest.fixture()
    def client(self):
        """Create a FastAPI test client with mocked embedding engine."""
        from fastapi.testclient import TestClient

        from vllm_mlx.server import app

        return TestClient(app)

    @staticmethod
    @contextmanager
    def _swap_engine(srv, engine, locked=None):
        """Temporarily install *engine* (and optionally a model lock) on the
        server module *srv*, restoring previous state on exit.

        Args:
            srv: the imported ``vllm_mlx.server`` module.
            engine: mock engine to install as ``_embedding_engine``.
            locked: if given, value for ``_embedding_model_locked``.
        """
        original_engine = srv._embedding_engine
        original_locked = srv._embedding_model_locked
        srv._embedding_engine = engine
        if locked is not None:
            srv._embedding_model_locked = locked
        try:
            yield
        finally:
            srv._embedding_engine = original_engine
            srv._embedding_model_locked = original_locked

    def test_batch_input_preserves_order(self, client):
        """Test batch embedding returns vectors with correct indices."""
        import vllm_mlx.server as srv

        texts = ["first", "second", "third"]
        mock_engine = MagicMock()
        mock_engine.model_name = "test-embed"
        mock_engine.embed.return_value = [
            [1.0, 0.0],
            [0.0, 1.0],
            [0.5, 0.5],
        ]
        mock_engine.count_tokens.return_value = 9

        with self._swap_engine(srv, mock_engine):
            resp = client.post(
                "/v1/embeddings",
                json={"model": "test-embed", "input": texts},
            )

        assert resp.status_code == 200
        body = resp.json()
        assert len(body["data"]) == 3
        for i in range(3):
            assert body["data"][i]["index"] == i
        # Verify order matches for ALL vectors (the middle one was
        # previously unchecked, so a swap of items 1/2 could pass).
        assert body["data"][0]["embedding"] == [1.0, 0.0]
        assert body["data"][1]["embedding"] == [0.0, 1.0]
        assert body["data"][2]["embedding"] == [0.5, 0.5]

    def test_empty_input_returns_400(self, client):
        """Test that empty input list returns 400 error."""
        import vllm_mlx.server as srv

        mock_engine = MagicMock()
        mock_engine.model_name = "test-embed"

        with self._swap_engine(srv, mock_engine):
            resp = client.post(
                "/v1/embeddings",
                json={"model": "test-embed", "input": []},
            )

        assert resp.status_code == 400

    def test_model_hot_swap(self, client):
        """Test that requesting a different model triggers reload."""
        import vllm_mlx.server as srv

        mock_engine = MagicMock()
        mock_engine.model_name = "old-model"
        mock_engine.embed.return_value = [[0.1]]
        mock_engine.count_tokens.return_value = 1

        with self._swap_engine(srv, mock_engine):
            with patch("vllm_mlx.embedding.EmbeddingEngine") as mock_cls:
                new_engine = MagicMock()
                new_engine.model_name = "new-model"
                new_engine.embed.return_value = [[0.9]]
                new_engine.count_tokens.return_value = 1
                mock_cls.return_value = new_engine
                resp = client.post(
                    "/v1/embeddings",
                    json={"model": "new-model", "input": "test"},
                )
                assert resp.status_code == 200
                mock_cls.assert_called_once_with("new-model")
                new_engine.load.assert_called_once()
                # The response must come from the NEW engine, not the old
                # one that returned [[0.1]].
                assert resp.json()["data"][0]["embedding"] == [0.9]

    def test_model_locked_rejects_different_model(self, client):
        """Test that a locked embedding model rejects requests for different models."""
        import vllm_mlx.server as srv

        mock_engine = MagicMock()
        mock_engine.model_name = "locked-model"

        with self._swap_engine(srv, mock_engine, locked="locked-model"):
            resp = client.post(
                "/v1/embeddings",
                json={"model": "other-model", "input": "test"},
            )
            assert resp.status_code == 400
            body = resp.json()
            assert "locked-model" in body["detail"]
            assert "other-model" in body["detail"]
# =============================================================================
# Slow Integration Test - Real Model
# =============================================================================
@pytest.mark.slow
class TestEmbeddingsRealModel:
    """End-to-end tests against a real mlx-embeddings model (slow)."""

    @pytest.fixture(scope="class")
    def engine(self):
        """Load a small quantized embedding model once for the whole class."""
        pytest.importorskip("mlx_embeddings")
        from vllm_mlx.embedding import EmbeddingEngine

        embedding_engine = EmbeddingEngine("mlx-community/all-MiniLM-L6-v2-4bit")
        embedding_engine.load()
        return embedding_engine

    def test_single_embedding_shape(self, engine):
        """A single text yields exactly one non-empty vector of floats."""
        vectors = engine.embed("Hello world")

        assert len(vectors) == 1
        vector = vectors[0]
        assert len(vector) > 0  # non-empty embedding
        assert all(isinstance(component, float) for component in vector)