Skip to content

Commit 063e6ce

Browse files
authored
(feat) VoyageAI integration improvements (#4109)
Supports VoyageAI's contextual model, token counting, and efficient batch creation. Documentation change: Unstructured-IO/docs#790
1 parent 15253a5 commit 063e6ce

File tree

4 files changed

+388
-50
lines changed

4 files changed

+388
-50
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
## 0.18.20

### Enhancement

- Improve the VoyageAI integration
- Add voyage-context-3 support

## 0.18.19-dev0

### Enhancement

test_unstructured/embed/test_voyageai.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def test_embed_documents_does_not_break_element_to_dict(mocker):
1010
embed_response.embeddings = [[1], [2]]
1111
mock_client = mocker.MagicMock()
1212
mock_client.embed.return_value = embed_response
13+
mock_client.tokenize.return_value = [[1], [1]] # Mock token counts
1314

1415
# Mock get_client to return our mock_client
1516
mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=mock_client)
@@ -23,3 +24,219 @@ def test_embed_documents_does_not_break_element_to_dict(mocker):
2324
assert len(elements) == 2
2425
assert elements[0].to_dict()["text"] == "This is sentence 1"
2526
assert elements[1].to_dict()["text"] == "This is sentence 2"
27+
28+
29+
def test_embed_documents_voyage_3_5(mocker):
    """Embedding two documents with the voyage-3.5 model yields 1024-dim vectors."""
    # Stub out the API client: fixed embeddings plus per-text token counts
    # (tokenize is used by the encoder's batching logic).
    fake_response = Mock()
    fake_response.embeddings = [[1.0] * 1024, [2.0] * 1024]
    fake_client = mocker.MagicMock()
    fake_client.embed.return_value = fake_response
    fake_client.tokenize.return_value = [[1, 2, 3], [1, 2]]
    mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=fake_client)

    config = VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3.5")
    encoder = VoyageAIEmbeddingEncoder(config=config)
    elements = encoder.embed_documents(
        elements=[Text("Test document 1"), Text("Test document 2")],
    )

    assert len(elements) == 2
    assert all(len(element.embeddings) == 1024 for element in elements)
48+
49+
50+
def test_embed_documents_voyage_3_5_lite(mocker):
    """Embedding three documents with voyage-3.5-lite yields 512-dim vectors."""
    # Stubbed client: fixed 512-dim embeddings and one token per text.
    response = Mock()
    response.embeddings = [[1.0] * 512, [2.0] * 512, [3.0] * 512]
    client = mocker.MagicMock()
    client.embed.return_value = response
    client.tokenize.return_value = [[1], [1], [1]]
    mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=client)

    encoder = VoyageAIEmbeddingEncoder(
        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3.5-lite")
    )
    elements = encoder.embed_documents(
        elements=[Text("Test 1"), Text("Test 2"), Text("Test 3")],
    )

    assert len(elements) == 3
    for element in elements:
        assert len(element.embeddings) == 512
68+
69+
70+
def test_embed_documents_contextual_model(mocker):
    """voyage-context-3 must go through the contextualized_embed API."""
    # The contextualized API wraps embeddings inside a list of result objects.
    batch_result = Mock()
    batch_result.embeddings = [[1.0] * 1024, [2.0] * 1024]
    response = Mock()
    response.results = [batch_result]

    client = mocker.MagicMock()
    client.contextualized_embed.return_value = response
    client.tokenize.return_value = [[1, 2], [1, 2, 3]]  # per-text token counts
    mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=client)

    encoder = VoyageAIEmbeddingEncoder(
        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-context-3")
    )
    elements = encoder.embed_documents(
        elements=[Text("Context document 1"), Text("Context document 2")],
    )

    assert len(elements) == 2
    assert [len(element.embeddings) for element in elements] == [1024, 1024]
    # The contextual code path, not plain embed(), must have been taken.
    client.contextualized_embed.assert_called_once()
95+
96+
97+
def test_count_tokens(mocker):
    """count_tokens returns one token count per input text, in order."""
    client = mocker.MagicMock()
    # Two texts tokenize to 2 and 5 tokens respectively.
    client.tokenize.return_value = [[1, 2], [1, 2, 3, 4, 5]]
    mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=client)

    encoder = VoyageAIEmbeddingEncoder(
        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3.5")
    )
    counts = encoder.count_tokens(
        ["short text", "this is a longer text with more tokens"]
    )

    assert counts == [2, 5]
113+
114+
115+
def test_count_tokens_empty_list(mocker):
    """An empty input list yields an empty list of token counts."""
    mocker.patch.object(
        VoyageAIEmbeddingConfig, "get_client", return_value=mocker.MagicMock()
    )
    encoder = VoyageAIEmbeddingEncoder(
        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3.5")
    )
    assert encoder.count_tokens([]) == []
124+
125+
126+
def test_get_token_limit(mocker):
    """get_token_limit maps each model family to its per-batch token budget."""
    mocker.patch.object(
        VoyageAIEmbeddingConfig, "get_client", return_value=mocker.MagicMock()
    )

    # (model name, expected limit); unknown models fall back to the default.
    expectations = [
        ("voyage-3.5", 320_000),
        ("voyage-3.5-lite", 1_000_000),
        ("voyage-context-3", 32_000),
        ("voyage-2", 320_000),
        ("unknown-model", 120_000),
    ]
    for model_name, expected_limit in expectations:
        config = VoyageAIEmbeddingConfig(api_key="api_key", model_name=model_name)
        assert config.get_token_limit() == expected_limit
149+
150+
151+
def test_is_context_model(mocker):
    """_is_context_model is True only for the contextual model family."""
    mocker.patch.object(
        VoyageAIEmbeddingConfig, "get_client", return_value=mocker.MagicMock()
    )

    def make_encoder(model_name):
        # Small local factory to avoid repeating the config boilerplate.
        config = VoyageAIEmbeddingConfig(api_key="api_key", model_name=model_name)
        return VoyageAIEmbeddingEncoder(config=config)

    assert make_encoder("voyage-context-3")._is_context_model() is True
    assert make_encoder("voyage-3.5")._is_context_model() is False
166+
167+
168+
def test_build_batches_with_token_limits(mocker):
    """_build_batches keeps every text while honoring the token budget."""
    client = mocker.MagicMock()
    # Four texts with 10, 20, 15 and 25 tokens respectively.
    client.tokenize.return_value = [[1] * 10, [1] * 20, [1] * 15, [1] * 25]
    mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=client)

    encoder = VoyageAIEmbeddingEncoder(
        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-2")
    )
    texts = ["text1", "text2", "text3", "text4"]
    batches = list(encoder._build_batches(texts, client))

    # At least one batch is produced...
    assert len(batches) >= 1
    # ...and batching neither drops nor duplicates texts.
    assert sum(len(batch) for batch in batches) == len(texts)
187+
188+
189+
def test_embed_query(mocker):
    """embed_query returns a single vector and tags the call as a query."""
    response = Mock()
    response.embeddings = [[1.0] * 1024]
    client = mocker.MagicMock()
    client.embed.return_value = response
    mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=client)

    encoder = VoyageAIEmbeddingEncoder(
        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3.5")
    )
    vector = encoder.embed_query("test query")

    assert len(vector) == 1024
    # The single embed() call must carry input_type="query".
    client.embed.assert_called_once()
    assert client.embed.call_args[1]["input_type"] == "query"
208+
209+
210+
def test_embed_documents_with_output_dimension(mocker):
    """A configured output_dimension is forwarded to the embed call."""
    response = Mock()
    response.embeddings = [[1.0] * 512, [2.0] * 512]
    client = mocker.MagicMock()
    client.embed.return_value = response
    client.tokenize.return_value = [[1], [1]]
    mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=client)

    config = VoyageAIEmbeddingConfig(
        api_key="api_key", model_name="voyage-3.5", output_dimension=512
    )
    encoder = VoyageAIEmbeddingEncoder(config=config)
    elements = encoder.embed_documents(
        elements=[Text("Test 1"), Text("Test 2")],
    )

    assert len(elements) == 2
    # The custom dimension must reach the underlying client call.
    assert client.embed.call_args[1]["output_dimension"] == 512
232+
233+
234+
def test_embed_documents_empty_list(mocker):
    """Embedding no documents is a no-op that returns an empty list."""
    mocker.patch.object(
        VoyageAIEmbeddingConfig, "get_client", return_value=mocker.MagicMock()
    )
    encoder = VoyageAIEmbeddingEncoder(
        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3.5")
    )
    assert encoder.embed_documents(elements=[]) == []

unstructured/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.18.19-dev0" # pragma: no cover
1+
__version__ = "0.18.20" # pragma: no cover

0 commit comments

Comments
 (0)