
Commit ce24503

Add check to Truncate Embeddings at Max Input (#2153)
* Add JS test and make sure it fails
* Add in truncation logic
* Fix imports
* ruff
* No object
1 parent dba6824 commit ce24503


9 files changed: +94 −9 lines changed


js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts

Lines changed: 11 additions & 0 deletions
@@ -166,4 +166,15 @@ describe("r2rClient V3 Documents Integration Tests", () => {
     const response = await client.documents.delete({ id: documentId });
     expect(response.results).toBeDefined();
   });
+
+  test("Get an embedding that exceeds the context window", async () => {
+    const longText = "Hello world! ".repeat(8192);
+
+    const response = await client.retrieval.embedding({
+      text: longText,
+    });
+
+    expect(response.results).toBeDefined();
+    expect(response.results.length).toBeGreaterThan(0);
+  }, 30000);
 });

js/sdk/package-lock.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

js/sdk/package.json

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 {
   "name": "r2r-js",
-  "version": "0.4.37",
+  "version": "0.4.38",
   "description": "",
   "main": "dist/index.js",
   "browser": "dist/index.browser.js",

js/sdk/src/types.ts

Lines changed: 1 addition & 0 deletions
@@ -435,6 +435,7 @@ export type WrappedRelationshipsResponse = PaginatedResultsWrapper<
 // Retrieval Responses
 export type WrappedVectorSearchResponse = ResultsWrapper<VectorSearchResult[]>;
 export type WrappedSearchResponse = ResultsWrapper<CombinedSearchResponse>;
+export type WrappedEmbeddingResponse = ResultsWrapper<number[]>;
 
 // System Responses
 export type WrappedSettingsResponse = ResultsWrapper<SettingsResponse>;

js/sdk/src/v3/clients/retrieval.ts

Lines changed: 5 additions & 2 deletions
@@ -4,6 +4,7 @@ import {
   GenerationConfig,
   Message,
   SearchSettings,
+  WrappedEmbeddingResponse,
   WrappedSearchResponse,
 } from "../../types";
 import { ensureSnakeCase } from "../../utils";
@@ -312,9 +313,11 @@ export class RetrievalClient {
    * @param text Text to generate embeddings for
    * @returns Vector embedding of the input text
    */
-  async embedding(text: string): Promise<number[]> {
+  async embedding(options: {
+    text: string;
+  }): Promise<WrappedEmbeddingResponse> {
     return await this.client.makeRequest("POST", "retrieval/embedding", {
-      data: { text },
+      data: options.text,
     });
   }
 }

py/core/providers/embeddings/litellm.py

Lines changed: 21 additions & 4 deletions
@@ -1,3 +1,4 @@
+import contextlib
 import logging
 import math
 import os
@@ -16,6 +17,8 @@
     R2RException,
 )
 
+from .utils import truncate_texts_to_token_limit
+
 logger = logging.getLogger()
 
 
@@ -48,16 +51,16 @@ def __init__(
                 "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
             )
 
-        url = os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url
-        if not url:
+        if url := os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url:
+            self.rerank_url = url
+        else:
             raise ValueError(
                 "LiteLLMEmbeddingProvider requires a valid reranking API url to be set via `embedding.rerank_url` in the r2r.toml, or via the environment variable `HUGGINGFACE_API_BASE`."
             )
-        self.rerank_url = url
 
         self.base_model = config.base_model
         if "amazon" in self.base_model:
-            logger.warn("Amazon embedding model detected, dropping params")
+            logger.warning("Amazon embedding model detected, dropping params")
             litellm.drop_params = True
         self.base_dimension = config.base_dimension
 
@@ -78,6 +81,13 @@ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
             logger.warning("Dropping nan dimensions from kwargs")
 
         try:
+            # Truncate text if it exceeds the model's max input tokens. Some providers do this by default, others do not.
+            if kwargs.get("model"):
+                with contextlib.suppress(Exception):
+                    texts = truncate_texts_to_token_limit(
+                        texts, kwargs["model"]
+                    )
+
             response = await self.litellm_aembedding(
                 input=texts,
                 **kwargs,
@@ -98,6 +108,13 @@ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
         texts = task["texts"]
         kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
         try:
+            # Truncate text if it exceeds the model's max input tokens. Some providers do this by default, others do not.
+            if kwargs.get("model"):
+                with contextlib.suppress(Exception):
+                    texts = truncate_texts_to_token_limit(
+                        texts, kwargs["model"]
+                    )
+
             response = self.litellm_embedding(
                 input=texts,
                 **kwargs,
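The truncation guard added to both providers is deliberately best-effort: the helper itself swallows lookup failures, and contextlib.suppress(Exception) adds a second layer so a truncation problem can never block the embedding request. A minimal sketch of that behavior under an unrecognized model name; the import path is an assumption based on the new utils module introduced later in this commit:

import contextlib

# Hypothetical import path; the helper is defined in the new embeddings utils module below.
from core.providers.embeddings.utils import truncate_texts_to_token_limit

texts = ["Hello world! " * 8192]

with contextlib.suppress(Exception):
    # For a model litellm cannot resolve, get_model_info() raises inside the
    # helper, the helper logs a warning and returns the original texts, and
    # the embedding call that follows still receives usable input.
    texts = truncate_texts_to_token_limit(texts, "not-a-real-model")

assert texts == ["Hello world! " * 8192]  # unchanged when truncation is skipped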

py/core/providers/embeddings/openai.py

Lines changed: 17 additions & 0 deletions
@@ -1,3 +1,4 @@
+import contextlib
 import logging
 import os
 from typing import Any
@@ -12,6 +13,8 @@
     EmbeddingProvider,
 )
 
+from .utils import truncate_texts_to_token_limit
+
 logger = logging.getLogger()
 
 
@@ -101,6 +104,13 @@ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
         kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
 
         try:
+            # Truncate text if it exceeds the model's max input tokens. Some providers do this by default, others do not.
+            if kwargs.get("model"):
+                with contextlib.suppress(Exception):
+                    texts = truncate_texts_to_token_limit(
+                        texts, kwargs["model"]
+                    )
+
             response = await self.async_client.embeddings.create(
                 input=texts,
                 **kwargs,
@@ -119,6 +129,13 @@ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
         texts = task["texts"]
         kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
         try:
+            # Truncate text if it exceeds the model's max input tokens. Some providers do this by default, others do not.
+            if kwargs.get("model"):
+                with contextlib.suppress(Exception):
+                    texts = truncate_texts_to_token_limit(
+                        texts, kwargs["model"]
+                    )
+
             response = self.client.embeddings.create(
                 input=texts,
                 **kwargs,
py/core/providers/embeddings/utils.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import logging
+
+from litellm import get_model_info, token_counter
+
+logger = logging.getLogger(__name__)
+
+
+def truncate_texts_to_token_limit(texts: list[str], model: str) -> list[str]:
+    """
+    Truncate texts to fit within the model's token limit.
+    """
+    try:
+        model_info = get_model_info(model=model)
+        if not model_info.get("max_input_tokens"):
+            return texts  # No truncation needed if no limit specified
+
+        truncated_texts = []
+        for text in texts:
+            text_tokens = token_counter(model=model, text=text)
+            assert model_info["max_input_tokens"]
+            if text_tokens > model_info["max_input_tokens"]:
+                estimated_chars = (
+                    model_info["max_input_tokens"] * 3
+                )  # Estimate 3 chars per token
+                truncated_text = text[:estimated_chars]
+                truncated_texts.append(truncated_text)
+                logger.warning(
+                    f"Truncated text from {text_tokens} to ~{model_info['max_input_tokens']} tokens"
+                )
+            else:
+                truncated_texts.append(text)
+
+        return truncated_texts
+    except Exception as e:
+        logger.warning(f"Failed to truncate texts: {str(e)}")
+        return texts  # Return original texts if truncation fails
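Note that the helper clips by characters rather than tokens, using a rough three-characters-per-token estimate, so the result may still land slightly above or below the true token budget. A small usage sketch mirroring the new JS test; the model name and import path here are illustrative assumptions, not taken from the diff:

from litellm import get_model_info, token_counter

# Hypothetical import path for the helper added in this commit.
from core.providers.embeddings.utils import truncate_texts_to_token_limit

model = "text-embedding-3-small"  # any model with max_input_tokens in litellm's model map
long_text = "Hello world! " * 8192  # same oversized input as the JS test

[truncated] = truncate_texts_to_token_limit([long_text], model)

limit = get_model_info(model=model)["max_input_tokens"]
print(len(long_text), "->", len(truncated))  # clipped to at most limit * 3 characters
print(token_counter(model=model, text=truncated), "tokens vs. limit", limit)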

py/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "r2r"
-version = "3.5.14"
+version = "3.5.15"
 description = "SciPhi R2R"
 readme = "README.md"
 license = {text = "MIT"}
