Skip to content

Commit 38c0a60

Browse files
Improve vectorizer kwargs and typing (redis#291)
## Changes Made

1. **Expanded Type Support**:
   - Updated return type signatures across all vectorizers to properly reflect the ability to return either data lists (`List[float]`) or binary buffers (`bytes`)
   - Added special handling for Cohere's integer embedding types (`List[int]`)
2. **Standardized Interface**:
   - Uniform type annotations and docstrings across all vectorizer implementations
   - Consistent default batch sizes (10) for better predictability
3. **Improved Provider-Specific Support**:
   - Enhanced kwargs forwarding to allow passing provider-specific parameters
   - Better warnings for deprecated parameters (like Cohere's `embedding_types`)
4. **Fixed Type Checking**:
   - Added strategic type ignores to resolve MyPy errors
   - Made minimal changes to consumer code to handle the expanded return types

## Motivation

These changes create a more consistent and flexible vectorizer interface that:
- Accurately represents what the methods can return
- Accommodates provider-specific features (like Cohere's integer embeddings)
- Provides clearer documentation for users
- Maintains backward compatibility

## Future Improvements

For future consideration:
- Introduce helper methods (like `embed_as_list()`) that guarantee specific return types when needed
- Add more robust type conversion in consumer code that relies on specific types
- Develop a cleaner separation between the base vectorizer interface and provider-specific extensions
- Consider a more structured approach to provider-specific parameters
1 parent ad9bb21 commit 38c0a60

File tree

14 files changed

+363
-118
lines changed

14 files changed

+363
-118
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,15 +310,17 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
310310
if not isinstance(prompt, str):
311311
raise TypeError("Prompt must be a string.")
312312

313-
return self._vectorizer.embed(prompt)
313+
result = self._vectorizer.embed(prompt)
314+
return result # type: ignore
314315

315316
async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
316317
"""Converts a text prompt to its vector representation using the
317318
configured vectorizer."""
318319
if not isinstance(prompt, str):
319320
raise TypeError("Prompt must be a string.")
320321

321-
return await self._vectorizer.aembed(prompt)
322+
result = await self._vectorizer.aembed(prompt)
323+
return result # type: ignore
322324

323325
def _check_vector_dims(self, vector: List[float]):
324326
"""Checks the size of the provided vector and raises an error if it

redisvl/extensions/router/semantic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,14 @@ def __call__(
366366
if not vector:
367367
if not statement:
368368
raise ValueError("Must provide a vector or statement to the router")
369-
vector = self.vectorizer.embed(statement)
369+
vector = self.vectorizer.embed(statement) # type: ignore
370370

371371
aggregation_method = (
372372
aggregation_method or self.routing_config.aggregation_method
373373
)
374374

375375
# perform route classification
376-
top_route_match = self._classify_route(vector, aggregation_method)
376+
top_route_match = self._classify_route(vector, aggregation_method) # type: ignore
377377
return top_route_match
378378

379379
@deprecated_argument("distance_threshold")
@@ -400,7 +400,7 @@ def route_many(
400400
if not vector:
401401
if not statement:
402402
raise ValueError("Must provide a vector or statement to the router")
403-
vector = self.vectorizer.embed(statement)
403+
vector = self.vectorizer.embed(statement) # type: ignore
404404

405405
max_k = max_k or self.routing_config.max_k
406406
aggregation_method = (
@@ -409,7 +409,7 @@ def route_many(
409409

410410
# classify routes
411411
top_route_matches = self._classify_multi_route(
412-
vector, max_k, aggregation_method
412+
vector, max_k, aggregation_method # type: ignore
413413
)
414414

415415
return top_route_matches

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def add_messages(
349349
role=message[ROLE_FIELD_NAME],
350350
content=message[CONTENT_FIELD_NAME],
351351
session_tag=session_tag,
352-
vector_field=content_vector,
352+
vector_field=content_vector, # type: ignore
353353
)
354354

355355
if TOOL_FIELD_NAME in message:

redisvl/utils/vectorize/base.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from enum import Enum
3-
from typing import Callable, List, Optional
3+
from typing import Callable, List, Optional, Union
44

55
from pydantic import BaseModel, Field, field_validator
66

@@ -49,34 +49,69 @@ def check_dims(cls, value):
4949
return value
5050

5151
@abstractmethod
52-
def embed_many(
52+
def embed(
5353
self,
54-
texts: List[str],
54+
text: str,
5555
preprocess: Optional[Callable] = None,
56-
batch_size: int = 1000,
5756
as_buffer: bool = False,
5857
**kwargs,
59-
) -> List[List[float]]:
58+
) -> Union[List[float], bytes]:
59+
"""Embed a chunk of text.
60+
61+
Args:
62+
text: Text to embed
63+
preprocess: Optional function to preprocess text
64+
as_buffer: If True, returns a bytes object instead of a list
65+
66+
Returns:
67+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
68+
object if as_buffer=True
69+
"""
6070
raise NotImplementedError
6171

6272
@abstractmethod
63-
def embed(
73+
def embed_many(
6474
self,
65-
text: str,
75+
texts: List[str],
6676
preprocess: Optional[Callable] = None,
77+
batch_size: int = 10,
6778
as_buffer: bool = False,
6879
**kwargs,
69-
) -> List[float]:
80+
) -> Union[List[List[float]], List[bytes]]:
81+
"""Embed multiple chunks of text.
82+
83+
Args:
84+
texts: List of texts to embed
85+
preprocess: Optional function to preprocess text
86+
batch_size: Number of texts to process in each batch
87+
as_buffer: If True, returns each embedding as a bytes object
88+
89+
Returns:
90+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
91+
or as bytes objects if as_buffer=True
92+
"""
7093
raise NotImplementedError
7194

7295
async def aembed_many(
7396
self,
7497
texts: List[str],
7598
preprocess: Optional[Callable] = None,
76-
batch_size: int = 1000,
99+
batch_size: int = 10,
77100
as_buffer: bool = False,
78101
**kwargs,
79-
) -> List[List[float]]:
102+
) -> Union[List[List[float]], List[bytes]]:
103+
"""Asynchronously embed multiple chunks of text.
104+
105+
Args:
106+
texts: List of texts to embed
107+
preprocess: Optional function to preprocess text
108+
batch_size: Number of texts to process in each batch
109+
as_buffer: If True, returns each embedding as a bytes object
110+
111+
Returns:
112+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
113+
or as bytes objects if as_buffer=True
114+
"""
80115
# Fallback to standard embedding call if no async support
81116
return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs)
82117

@@ -86,7 +121,18 @@ async def aembed(
86121
preprocess: Optional[Callable] = None,
87122
as_buffer: bool = False,
88123
**kwargs,
89-
) -> List[float]:
124+
) -> Union[List[float], bytes]:
125+
"""Asynchronously embed a chunk of text.
126+
127+
Args:
128+
text: Text to embed
129+
preprocess: Optional function to preprocess text
130+
as_buffer: If True, returns a bytes object instead of a list
131+
132+
Returns:
133+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
134+
object if as_buffer=True
135+
"""
90136
# Fallback to standard embedding call if no async support
91137
return self.embed(text, preprocess, as_buffer, **kwargs)
92138

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any, Callable, Dict, List, Optional
2+
from typing import Any, Callable, Dict, List, Optional, Union
33

44
from pydantic import PrivateAttr
55
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -178,7 +178,7 @@ def embed_many(
178178
batch_size: int = 10,
179179
as_buffer: bool = False,
180180
**kwargs,
181-
) -> List[List[float]]:
181+
) -> Union[List[List[float]], List[bytes]]:
182182
"""Embed many chunks of texts using the AzureOpenAI API.
183183
184184
Args:
@@ -191,7 +191,8 @@ def embed_many(
191191
to a byte string. Defaults to False.
192192
193193
Returns:
194-
List[List[float]]: List of embeddings.
194+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
195+
or as bytes objects if as_buffer=True
195196
196197
Raises:
197198
TypeError: If the wrong input type is passed in for the test.
@@ -205,7 +206,9 @@ def embed_many(
205206

206207
embeddings: List = []
207208
for batch in self.batchify(texts, batch_size, preprocess):
208-
response = self._client.embeddings.create(input=batch, model=self.model)
209+
response = self._client.embeddings.create(
210+
input=batch, model=self.model, **kwargs
211+
)
209212
embeddings += [
210213
self._process_embedding(r.embedding, as_buffer, dtype)
211214
for r in response.data
@@ -224,7 +227,7 @@ def embed(
224227
preprocess: Optional[Callable] = None,
225228
as_buffer: bool = False,
226229
**kwargs,
227-
) -> List[float]:
230+
) -> Union[List[float], bytes]:
228231
"""Embed a chunk of text using the AzureOpenAI API.
229232
230233
Args:
@@ -235,7 +238,8 @@ def embed(
235238
to a byte string. Defaults to False.
236239
237240
Returns:
238-
List[float]: Embedding.
241+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
242+
object if as_buffer=True
239243
240244
Raises:
241245
TypeError: If the wrong input type is passed in for the test.
@@ -248,7 +252,9 @@ def embed(
248252

249253
dtype = kwargs.pop("dtype", self.dtype)
250254

251-
result = self._client.embeddings.create(input=[text], model=self.model)
255+
result = self._client.embeddings.create(
256+
input=[text], model=self.model, **kwargs
257+
)
252258
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
253259

254260
@retry(
@@ -261,10 +267,10 @@ async def aembed_many(
261267
self,
262268
texts: List[str],
263269
preprocess: Optional[Callable] = None,
264-
batch_size: int = 1000,
270+
batch_size: int = 10,
265271
as_buffer: bool = False,
266272
**kwargs,
267-
) -> List[List[float]]:
273+
) -> Union[List[List[float]], List[bytes]]:
268274
"""Asynchronously embed many chunks of texts using the AzureOpenAI API.
269275
270276
Args:
@@ -277,7 +283,8 @@ async def aembed_many(
277283
to a byte string. Defaults to False.
278284
279285
Returns:
280-
List[List[float]]: List of embeddings.
286+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
287+
or as bytes objects if as_buffer=True
281288
282289
Raises:
283290
TypeError: If the wrong input type is passed in for the test.
@@ -292,7 +299,7 @@ async def aembed_many(
292299
embeddings: List = []
293300
for batch in self.batchify(texts, batch_size, preprocess):
294301
response = await self._aclient.embeddings.create(
295-
input=batch, model=self.model
302+
input=batch, model=self.model, **kwargs
296303
)
297304
embeddings += [
298305
self._process_embedding(r.embedding, as_buffer, dtype)
@@ -312,7 +319,7 @@ async def aembed(
312319
preprocess: Optional[Callable] = None,
313320
as_buffer: bool = False,
314321
**kwargs,
315-
) -> List[float]:
322+
) -> Union[List[float], bytes]:
316323
"""Asynchronously embed a chunk of text using the OpenAI API.
317324
318325
Args:
@@ -323,7 +330,8 @@ async def aembed(
323330
to a byte string. Defaults to False.
324331
325332
Returns:
326-
List[float]: Embedding.
333+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
334+
object if as_buffer=True
327335
328336
Raises:
329337
TypeError: If the wrong input type is passed in for the test.
@@ -336,7 +344,9 @@ async def aembed(
336344

337345
dtype = kwargs.pop("dtype", self.dtype)
338346

339-
result = await self._aclient.embeddings.create(input=[text], model=self.model)
347+
result = await self._aclient.embeddings.create(
348+
input=[text], model=self.model, **kwargs
349+
)
340350
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
341351

342352
@property

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
from typing import Any, Callable, Dict, List, Optional
3+
from typing import Any, Callable, Dict, List, Optional, Union
44

55
from pydantic import PrivateAttr
66
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -135,16 +135,17 @@ def embed(
135135
preprocess: Optional[Callable] = None,
136136
as_buffer: bool = False,
137137
**kwargs,
138-
) -> List[float]:
139-
"""Embed a chunk of text using Amazon Bedrock.
138+
) -> Union[List[float], bytes]:
139+
"""Embed a chunk of text using the AWS Bedrock Embeddings API.
140140
141141
Args:
142142
text (str): Text to embed.
143143
preprocess (Optional[Callable]): Optional preprocessing function.
144144
as_buffer (bool): Whether to return as byte buffer.
145145
146146
Returns:
147-
List[float]: The embedding vector.
147+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
148+
object if as_buffer=True
148149
149150
Raises:
150151
TypeError: If text is not a string.
@@ -156,7 +157,7 @@ def embed(
156157
text = preprocess(text)
157158

158159
response = self._client.invoke_model(
159-
modelId=self.model, body=json.dumps({"inputText": text})
160+
modelId=self.model, body=json.dumps({"inputText": text}), **kwargs
160161
)
161162
response_body = json.loads(response["body"].read())
162163
embedding = response_body["embedding"]
@@ -177,17 +178,18 @@ def embed_many(
177178
batch_size: int = 10,
178179
as_buffer: bool = False,
179180
**kwargs,
180-
) -> List[List[float]]:
181-
"""Embed multiple texts using Amazon Bedrock.
181+
) -> Union[List[List[float]], List[bytes]]:
182+
"""Embed many chunks of text using the AWS Bedrock Embeddings API.
182183
183184
Args:
184185
texts (List[str]): List of texts to embed.
185186
preprocess (Optional[Callable]): Optional preprocessing function.
186-
batch_size (int): Size of batches for processing.
187+
batch_size (int): Size of batches for processing. Defaults to 10.
187188
as_buffer (bool): Whether to return as byte buffers.
188189
189190
Returns:
190-
List[List[float]]: List of embedding vectors.
191+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
192+
or as bytes objects if as_buffer=True
191193
192194
Raises:
193195
TypeError: If texts is not a list of strings.
@@ -206,7 +208,7 @@ def embed_many(
206208
batch_embeddings = []
207209
for text in batch:
208210
response = self._client.invoke_model(
209-
modelId=self.model, body=json.dumps({"inputText": text})
211+
modelId=self.model, body=json.dumps({"inputText": text}), **kwargs
210212
)
211213
response_body = json.loads(response["body"].read())
212214
batch_embeddings.append(response_body["embedding"])

0 commit comments

Comments (0)