1
1
from dataclasses import dataclass
2
- from typing import TYPE_CHECKING , List , Optional
2
+ from typing import TYPE_CHECKING , Iterable , List , Optional , cast
3
3
4
4
import numpy as np
5
5
from pydantic import Field , SecretStr
9
9
from unstructured .utils import requires_dependencies
10
10
11
11
if TYPE_CHECKING :
12
- from langchain_voyageai import VoyageAIEmbeddings
12
+ from voyageai import Client
13
+
14
# Default number of texts sent per embedding request when the config's
# `batch_size` is left unset; VoyageAIEmbeddingConfig.get_batch_size picks one
# of these by model name.
# NOTE(review): the specific values presumably reflect per-model request
# limits -- confirm against the VoyageAI documentation.
DEFAULT_VOYAGE_2_BATCH_SIZE = 72  # used for "voyage-2" and "voyage-02"
DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30  # used for "voyage-3-lite"
DEFAULT_VOYAGE_3_BATCH_SIZE = 10  # used for "voyage-3"
DEFAULT_BATCH_SIZE = 7  # fallback for any other model name
13
18
14
19
15
20
class VoyageAIEmbeddingConfig(EmbeddingConfig):
    """Configuration for embedding elements with the VoyageAI python client.

    Holds the credentials and request options, builds the client on demand,
    and resolves the batch size to use when none was configured explicitly.
    """

    # API key; stored as a SecretStr and only unwrapped when the client is built.
    api_key: SecretStr
    # Model identifier; also selects the default batch size in get_batch_size().
    model_name: str
    # When True, the encoder wraps its batch iteration in a tqdm progress bar.
    show_progress_bar: bool = False
    # Number of texts per request; None means "use the model-specific default".
    batch_size: Optional[int] = Field(default=None)
    # Passed through unchanged to the client's embed call as `truncation`.
    truncation: Optional[bool] = Field(default=None)
    # Passed through unchanged to the client's embed call as `output_dimension`.
    output_dimension: Optional[int] = Field(default=None)

    @requires_dependencies(
        # The name must be the exact importable module name (no stray
        # whitespace), otherwise the dependency check can never succeed.
        ["voyageai"],
        extras="embed-voyageai",
    )
    def get_client(self) -> "Client":
        """Creates a VoyageAI python client to embed elements.

        Returns:
            A `voyageai.Client` constructed with the unwrapped API key.
        """
        # Imported lazily so the module can be loaded without the optional
        # `voyageai` dependency installed.
        from voyageai import Client

        return Client(
            api_key=self.api_key.get_secret_value(),
        )

    def get_batch_size(self) -> int:
        """Return the batch size to use for embedding requests.

        If `batch_size` was not configured, a default is chosen from the model
        name and stored back on `self.batch_size`, so later calls return the
        same value without re-resolving it.

        Returns:
            The configured batch size, or the model-specific default.
        """
        if self.batch_size is None:
            model_defaults = {
                "voyage-2": DEFAULT_VOYAGE_2_BATCH_SIZE,
                "voyage-02": DEFAULT_VOYAGE_2_BATCH_SIZE,
                "voyage-3-lite": DEFAULT_VOYAGE_3_LITE_BATCH_SIZE,
                "voyage-3": DEFAULT_VOYAGE_3_BATCH_SIZE,
            }
            # Unknown model names fall back to the generic default.
            self.batch_size = model_defaults.get(self.model_name, DEFAULT_BATCH_SIZE)
        return self.batch_size
51
+
36
52
37
53
@dataclass
38
54
class VoyageAIEmbeddingEncoder (BaseEmbeddingEncoder ):
@@ -56,12 +72,29 @@ def is_unit_vector(self) -> bool:
56
72
57
73
def embed_documents(self, elements: List[Element]) -> List[Element]:
    """Embed the text of every element, batch by batch.

    Each element is converted with `str()`, sent to the client's `embed`
    call with `input_type="document"`, and the collected vectors are handed
    to `_add_embeddings_to_elements` together with the elements.

    Args:
        elements: The elements whose string form should be embedded.

    Returns:
        Whatever `_add_embeddings_to_elements` returns for the given
        elements and their embeddings.
    """
    config = self.config
    voyage_client = config.get_client()
    vectors: List[List[float]] = []

    # The iterator yields the start offset of each batch.
    for start in self._get_batch_iterator(elements):
        stop = start + config.get_batch_size()
        batch_texts = [str(element) for element in elements[start:stop]]
        response = voyage_client.embed(
            texts=batch_texts,
            model=config.model_name,
            input_type="document",
            truncation=config.truncation,
            output_dimension=config.output_dimension,
        )
        vectors.extend(cast(Iterable[List[float]], response.embeddings))

    return self._add_embeddings_to_elements(elements, vectors)
61
88
62
89
def embed_query(self, query: str) -> List[float]:
    """Embed a single query string.

    The query is sent as a one-item batch to the client's `embed` call with
    `input_type="query"`, using the model, truncation and output dimension
    taken from the config.

    Args:
        query: The text to embed.

    Returns:
        The first (and only requested) embedding from the response.
    """
    config = self.config
    voyage_client = config.get_client()
    response = voyage_client.embed(
        texts=[query],
        model=config.model_name,
        input_type="query",
        truncation=config.truncation,
        output_dimension=config.output_dimension,
    )
    # Exactly one text was submitted, so its embedding is at index 0.
    return response.embeddings[0]
65
98
66
99
@staticmethod
67
100
def _add_embeddings_to_elements (elements , embeddings ) -> List [Element ]:
@@ -71,3 +104,19 @@ def _add_embeddings_to_elements(elements, embeddings) -> List[Element]:
71
104
element .embeddings = embeddings [i ]
72
105
elements_w_embedding .append (element )
73
106
return elements
107
+
108
def _get_batch_iterator(self, elements: List[Element]) -> Iterable:
    """Return an iterable over the start offset of each batch.

    Offsets run from 0 up to `len(elements)` in steps of the config's batch
    size. When `show_progress_bar` is enabled on the config, the offsets are
    wrapped in a tqdm progress bar.

    Args:
        elements: The elements that will be split into batches.

    Returns:
        A `range` of batch start offsets, or that range wrapped in `tqdm`.

    Raises:
        ImportError: If a progress bar was requested but tqdm is not
            installed.
    """
    config = self.config

    # Plain iteration: no optional dependency needed.
    if not config.show_progress_bar:
        return range(0, len(elements), config.get_batch_size())  # type: ignore

    try:
        from tqdm.auto import tqdm  # type: ignore
    except ImportError as err:
        raise ImportError(
            "Must have tqdm installed if `show_progress_bar` is set to True. "
            "Please install with `pip install tqdm`."
        ) from err

    return tqdm(range(0, len(elements), config.get_batch_size()))
0 commit comments