Skip to content

Commit 0fbdd4e

Browse files
authored
Refactoring VoyageAI integration (#3878)
Using VoyageAI's Python package directly, allowing more features than going through LangChain
1 parent 238f985 commit 0fbdd4e

File tree

5 files changed

+72
-18
lines changed

5 files changed

+72
-18
lines changed

Diff for: CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.16.17-dev1
22

33
### Enhancements
4+
- **Refactoring the VoyageAI integration** to use the `voyageai` package directly, allowing extra features.
45

56
### Features
67

Diff for: test_unstructured/embed/test_voyageai.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
1+
from unittest.mock import Mock
2+
13
from unstructured.documents.elements import Text
24
from unstructured.embed.voyageai import VoyageAIEmbeddingConfig, VoyageAIEmbeddingEncoder
35

46

57
def test_embed_documents_does_not_break_element_to_dict(mocker):
68
# Mocked client with the desired behavior for embed_documents
9+
embed_response = Mock()
10+
embed_response.embeddings = [[1], [2]]
711
mock_client = mocker.MagicMock()
8-
mock_client.embed_documents.return_value = [1, 2]
12+
mock_client.embed.return_value = embed_response
913

1014
# Mock get_client to return our mock_client
1115
mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=mock_client)
1216

1317
encoder = VoyageAIEmbeddingEncoder(
14-
config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-law-2")
18+
config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3-large")
1519
)
1620
elements = encoder.embed_documents(
1721
elements=[Text("This is sentence 1"), Text("This is sentence 2")],

Diff for: test_unstructured_ingest/src/local-embed-voyageai.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ PYTHONPATH=${PYTHONPATH:-.} "$RUN_SCRIPT" \
3737
--work-dir "$WORK_DIR" \
3838
--embedding-provider "voyageai" \
3939
--embedding-api-key "$VOYAGE_API_KEY" \
40-
--embedding-model-name "voyage-large-2"
40+
--embedding-model-name "voyage-3-large"
4141

4242
set +e
4343

Diff for: unstructured/embed/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"langchain-huggingface": HuggingFaceEmbeddingEncoder,
1414
"langchain-aws-bedrock": BedrockEmbeddingEncoder,
1515
"langchain-vertexai": VertexAIEmbeddingEncoder,
16-
"langchain-voyageai": VoyageAIEmbeddingEncoder,
16+
"voyageai": VoyageAIEmbeddingEncoder,
1717
"mixedbread-ai": MixedbreadAIEmbeddingEncoder,
1818
"octoai": OctoAIEmbeddingEncoder,
1919
}

Diff for: unstructured/embed/voyageai.py

+63-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import TYPE_CHECKING, List, Optional
2+
from typing import TYPE_CHECKING, Iterable, List, Optional, cast
33

44
import numpy as np
55
from pydantic import Field, SecretStr
@@ -9,30 +9,46 @@
99
from unstructured.utils import requires_dependencies
1010

1111
if TYPE_CHECKING:
12-
from langchain_voyageai import VoyageAIEmbeddings
12+
from voyageai import Client
13+
14+
DEFAULT_VOYAGE_2_BATCH_SIZE = 72
15+
DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30
16+
DEFAULT_VOYAGE_3_BATCH_SIZE = 10
17+
DEFAULT_BATCH_SIZE = 7
1318

1419

1520
class VoyageAIEmbeddingConfig(EmbeddingConfig):
1621
api_key: SecretStr
1722
model_name: str
23+
show_progress_bar: bool = False
1824
batch_size: Optional[int] = Field(default=None)
1925
truncation: Optional[bool] = Field(default=None)
26+
output_dimension: Optional[int] = Field(default=None)
2027

2128
@requires_dependencies(
22-
["langchain", "langchain_voyageai"],
29+
["voyageai"],
2330
extras="embed-voyageai",
2431
)
25-
def get_client(self) -> "VoyageAIEmbeddings":
26-
"""Creates a Langchain VoyageAI python client to embed elements."""
27-
from langchain_voyageai import VoyageAIEmbeddings
28-
29-
return VoyageAIEmbeddings(
30-
voyage_api_key=self.api_key,
31-
model=self.model_name,
32-
batch_size=self.batch_size,
33-
truncation=self.truncation,
32+
def get_client(self) -> "Client":
33+
"""Creates a VoyageAI python client to embed elements."""
34+
from voyageai import Client
35+
36+
return Client(
37+
api_key=self.api_key.get_secret_value(),
3438
)
3539

40+
def get_batch_size(self):
41+
if self.batch_size is None:
42+
if self.model_name in ["voyage-2", "voyage-02"]:
43+
self.batch_size = DEFAULT_VOYAGE_2_BATCH_SIZE
44+
elif self.model_name == "voyage-3-lite":
45+
self.batch_size = DEFAULT_VOYAGE_3_LITE_BATCH_SIZE
46+
elif self.model_name == "voyage-3":
47+
self.batch_size = DEFAULT_VOYAGE_3_BATCH_SIZE
48+
else:
49+
self.batch_size = DEFAULT_BATCH_SIZE
50+
return self.batch_size
51+
3652

3753
@dataclass
3854
class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
@@ -56,12 +72,29 @@ def is_unit_vector(self) -> bool:
5672

5773
def embed_documents(self, elements: List[Element]) -> List[Element]:
5874
client = self.config.get_client()
59-
embeddings = client.embed_documents([str(e) for e in elements])
75+
embeddings: List[List[float]] = []
76+
77+
_iter = self._get_batch_iterator(elements)
78+
for i in _iter:
79+
r = client.embed(
80+
texts=[str(e) for e in elements[i : i + self.config.get_batch_size()]],
81+
model=self.config.model_name,
82+
input_type="document",
83+
truncation=self.config.truncation,
84+
output_dimension=self.config.output_dimension,
85+
).embeddings
86+
embeddings.extend(cast(Iterable[List[float]], r))
6087
return self._add_embeddings_to_elements(elements, embeddings)
6188

6289
def embed_query(self, query: str) -> List[float]:
6390
client = self.config.get_client()
64-
return client.embed_query(query)
91+
return client.embed(
92+
texts=[query],
93+
model=self.config.model_name,
94+
input_type="query",
95+
truncation=self.config.truncation,
96+
output_dimension=self.config.output_dimension,
97+
).embeddings[0]
6598

6699
@staticmethod
67100
def _add_embeddings_to_elements(elements, embeddings) -> List[Element]:
@@ -71,3 +104,19 @@ def _add_embeddings_to_elements(elements, embeddings) -> List[Element]:
71104
element.embeddings = embeddings[i]
72105
elements_w_embedding.append(element)
73106
return elements
107+
108+
def _get_batch_iterator(self, elements: List[Element]) -> Iterable:
109+
if self.config.show_progress_bar:
110+
try:
111+
from tqdm.auto import tqdm # type: ignore
112+
except ImportError as e:
113+
raise ImportError(
114+
"Must have tqdm installed if `show_progress_bar` is set to True. "
115+
"Please install with `pip install tqdm`."
116+
) from e
117+
118+
_iter = tqdm(range(0, len(elements), self.config.get_batch_size()))
119+
else:
120+
_iter = range(0, len(elements), self.config.get_batch_size()) # type: ignore
121+
122+
return _iter

0 commit comments

Comments
 (0)