Skip to content

Commit 6595e0e

Browse files
committed
Added batch embedding capability.
1 parent a09d851 commit 6595e0e

9 files changed

Lines changed: 415 additions & 50 deletions

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Changelog
22

33
## Unreleased
4+
- Added batch embedding support: embedding adapters expose `execute_one` and `execute_batch`;
5+
`execute` remains as a deprecated alias for `execute_one` (removed in 1.0). `llmEmbed` uses
6+
`execute_one` / `execute_batch` only, accepts list-shaped items (e.g. after `makeLists`), and
7+
takes an optional `batch_size` for built-in buffering.
48
- Added model2vec embedding support via `Model2VecEmbeddingAdapter` and `llmEmbed`
59
source `model2vec`, using in-process static embeddings with offline-ready HF cache
610
precaching. Install with `pip install talkpipe[model2vec]` or `talkpipe[all]`. Added

docs/guides/model-and-source-configuration.md

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ If a CLI flag is omitted and the matching `DEFAULT_*` key is unset, the value pa
223223

224224
## Segment parameters
225225

226-
Those configuration keys provide the fallback values; the per-segment parameters below override them and take final precedence at construction time.
226+
For `llmPrompt`, `llmVisionPrompt`, and `llmEmbed`, only **`model`** and **`source`** fall back to `default_*` config keys when omitted. Every other segment parameter must be set on the segment (ChatterLang or Python); it is not read from `~/.talkpipe.toml` or `TALKPIPE_*` unless noted below for a specific higher-level segment.
227227

228228
### `llmPrompt` / `LLMPrompt`
229229

@@ -269,7 +269,26 @@ segment = LLMVisionPrompt(
269269

270270
### `llmEmbed` / `LLMEmbed`
271271

272-
Required (directly or via config): `model`, `source`. Optional: `field` (text field to embed), `set_as` (field to store the vector on the item).
272+
| Parameter | From config? | Notes |
273+
|-----------|--------------|--------|
274+
| `model` | Yes — `default_embedding_model_name` | Required if not passed on the segment |
275+
| `source` | Yes — `default_embedding_model_source` | Required if not passed on the segment |
276+
| `field` | No | Text field to embed on structured items |
277+
| `set_as` | No | Field on the item where the vector is stored |
278+
| `batch_size` | No | Scalar (non-list) items buffered per provider call (default `1`; must be at least `1`) |
279+
| `fail_on_error` | No | Raise on an embedding error (default `true`); when `false`, the error is logged and the failed item is skipped |
280+
281+
**Batching (two patterns):**
282+
283+
1. **Built-in buffering** — set `batch_size` greater than `1` on `llmEmbed` to amortize API round-trips without changing upstream segments.
284+
2. **Composable buffering** — group items with `makeLists`, then embed the batch in one call:
285+
286+
```chatterlang
287+
| makeLists[num_items=100, field="_"]
288+
| llmEmbed[model="mxbai-embed-large", source="ollama", field="content", set_as="vector"]
289+
```
290+
291+
List-shaped items are expanded back to one output per document: with `set_as`, each item is updated with its vector and yielded; without `set_as`, each embedding vector is yielded on its own.
273292

274293
```chatterlang
275294
INPUT FROM echo[data="Hello world"]

docs/guides/model2vec-embeddings.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ precache_model("minishlab/potion-base-8M")
113113
embedder = Model2VecEmbedder()
114114
vector = embedder.embed_one("Paragraph text.")
115115
batch = embedder.embed(["first", "second"])
116+
117+
# ChatterLang / llmEmbed use Model2VecEmbeddingAdapter, which batches via the same encode path:
118+
# adapter.execute_batch(["first", "second"]) or adapter(["first", "second"])
116119
```
117120

118121
Pin to a specific HF commit by passing `revision="<commit-sha>"` to either

src/talkpipe/llm/embedding.py

Lines changed: 86 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""Module for embedding text using different models"""
22

3-
from typing import Optional, Annotated
3+
from typing import Optional, Annotated, Iterator, Any, List
44
import logging
5-
from talkpipe.llm.embedding_adapters import OllamaEmbedderAdapter
65
from talkpipe.pipe.core import AbstractSegment
76
from talkpipe.chatterlang.registry import register_segment
87
from talkpipe.util.data_manipulation import extract_property, assign_property
@@ -30,7 +29,8 @@ def __init__(
3029
source: Annotated[Optional[str], "The source of the embedding model (e.g., 'ollama')"] = None,
3130
field: Annotated[Optional[str], "If provided, extract text from this field in the input items"] = None,
3231
set_as: Annotated[Optional[str], "If provided, append embeddings to input items under this field name"] = None,
33-
fail_on_error: Annotated[bool, "Whether to raise an error on failure or to silently ignore it"] = True
32+
fail_on_error: Annotated[bool, "Whether to raise an error on failure or to silently ignore it"] = True,
33+
batch_size: Annotated[int, "Number of texts to embed per provider call when items are scalars"] = 1,
3434
):
3535
"""Initialize the embedding segment with the specified parameters.
3636
@@ -44,10 +44,69 @@ def __init__(
4444
if source not in getEmbeddingSources():
4545
logger.error(f"Source '{source}' is not supported. Supported sources are: {getEmbeddingSources()}")
4646
raise ValueError(f"Source '{source}' is not supported. Supported sources are: {getEmbeddingSources()}")
47+
if batch_size < 1:
48+
raise ValueError("batch_size must be a positive integer")
4749
self.embedder = getEmbeddingAdapter(source)(model=model)
4850
self.field = field
4951
self.set_as = set_as
5052
self.fail_on_error = fail_on_error
53+
self.batch_size = batch_size
54+
55+
def _text_from_item(self, item: Any) -> str:
    """Return the text to embed for ``item`` as a string.

    When ``self.field`` is set, the value is pulled out of the item with
    ``extract_property``; otherwise the item itself is used. Either way the
    result is passed through ``str``.
    """
    source = item if self.field is None else extract_property(item, self.field)
    return str(source)
59+
60+
def _yield_embedded(
    self, items: List[Any], vectors: List[List[float]]
) -> Iterator[Any]:
    """Pair each item with its embedding and yield the segment's output.

    Args:
        items: Input items, in the same order as ``vectors``.
        vectors: One embedding per item.

    Yields:
        The item itself, after the embedding has been stored on it with
        ``assign_property``, when ``self.set_as`` is set; otherwise the bare
        embedding vector.
    """
    for item, ans in zip(items, vectors):
        # Lazy %-style arguments: the vector is only rendered to text when
        # DEBUG logging is actually enabled, instead of once per item.
        logger.debug("Received embedding: %s", ans)
        if self.set_as is not None:
            logger.debug("Appending embedding to field %s", self.set_as)
            assign_property(item, self.set_as, ans)
            yield item
        else:
            logger.debug("Yielding embedding directly")
            yield ans
72+
73+
def _vectors_for_texts(self, texts: List[str]) -> List[List[float]]:
    """Embed ``texts`` and return one vector per text, in order.

    An empty list produces an empty list without touching the embedder. A
    single text goes through ``execute_one``; two or more go through
    ``execute_batch`` in a single call.
    """
    if not texts:
        return []
    if len(texts) > 1:
        return self.embedder.execute_batch(texts)
    only_text = texts[0]
    return [self.embedder.execute_one(only_text)]
79+
80+
def _embed_and_emit(self, items: List[Any], texts: List[str]) -> Iterator[Any]:
    """Embed ``texts`` in one provider call and yield one output per item.

    Args:
        items: Original input items, aligned index-for-index with ``texts``.
        texts: Strings to embed.

    Yields:
        Whatever ``_yield_embedded`` produces for each item/vector pair.

    Raises:
        Exception: The embedding error is re-raised when
            ``self.fail_on_error`` is true. Errors raised while emitting
            results always propagate.

    When ``fail_on_error`` is false, a failed single-text call yields nothing,
    and a failed multi-text call is retried one text at a time, skipping the
    texts that still fail.
    """
    if not items or not texts:
        return
    logger.debug(f"Embedding batch of {len(texts)} texts")
    try:
        vectors = self._vectors_for_texts(texts)
        # zip() would silently drop unmatched items, so a response with the
        # wrong number of vectors is treated as a failed batch.
        if len(vectors) != len(texts):
            raise ValueError(
                f"Expected {len(texts)} embeddings but received {len(vectors)}"
            )
    except Exception as e:
        logger.error(f"Error during batch embedding: {e}")
        if self.fail_on_error:
            raise
        if len(texts) == 1:
            return
        logger.warning(
            "Batch embedding failed; falling back to per-item embedding"
        )
    else:
        # Emit outside the try block: an error raised here is not an embedding
        # failure, and handling it as one would re-embed the whole batch and
        # re-yield items that were already emitted.
        yield from self._yield_embedded(items, vectors)
        return

    for item, text in zip(items, texts):
        try:
            ans = self.embedder.execute_one(text)
        except Exception as item_error:
            logger.error(f"Error during embedding: {item_error}")
            continue
        yield from self._yield_embedded([item], [ans])
103+
104+
def _embed_list_item(self, list_item: list) -> Iterator[Any]:
    """Embed every member of a list-shaped input item as one batch.

    An empty list yields nothing. Otherwise the members are copied into a new
    list, their texts are derived with ``_text_from_item``, and both lists are
    handed to ``_embed_and_emit``, whose outputs are yielded unchanged.
    """
    if list_item:
        members = list(list_item)
        member_texts = list(map(self._text_from_item, members))
        yield from self._embed_and_emit(members, member_texts)
51110

52111
def transform(self, input_iter):
53112
"""Transform input items by creating embeddings.
@@ -59,30 +118,32 @@ def transform(self, input_iter):
59118
If set_as is specified, yields the original items with embeddings added.
60119
Otherwise, yields the embeddings directly.
61120
"""
121+
buffer_items: List[Any] = []
122+
buffer_texts: List[str] = []
123+
124+
def flush_buffer() -> Iterator[Any]:
125+
if not buffer_items:
126+
return
127+
yield from self._embed_and_emit(buffer_items, buffer_texts)
128+
buffer_items.clear()
129+
buffer_texts.clear()
130+
62131
for item in input_iter:
63132
logging.debug(f"Processing input item: {item}")
64-
if self.field is not None:
65-
text = extract_property(item, self.field)
66-
logging.debug(f"Extracted text from field {self.field}: {text}")
67-
else:
68-
text = item
69-
logging.debug(f"Using item as text: {text}")
133+
if isinstance(item, list):
134+
yield from flush_buffer()
135+
yield from self._embed_list_item(item)
136+
continue
70137

71-
logger.debug(f"Embedding text: {text}")
72-
try:
73-
ans = self.embedder.execute(str(text))
74-
except Exception as e:
75-
logger.error(f"Error during embedding: {e}")
76-
if self.fail_on_error:
77-
raise e
78-
else:
79-
continue
80-
logger.debug(f"Received embedding: {ans}")
138+
text = self._text_from_item(item)
139+
logging.debug(f"Embedding text: {text}")
81140

82-
if self.set_as is not None:
83-
logger.debug(f"Appending embedding to field {self.set_as}")
84-
assign_property(item, self.set_as, ans)
85-
yield item
141+
if self.batch_size <= 1:
142+
yield from self._embed_and_emit([item], [text])
86143
else:
87-
logger.debug("Yielding embedding directly")
88-
yield ans
144+
buffer_items.append(item)
145+
buffer_texts.append(text)
146+
if len(buffer_items) >= self.batch_size:
147+
yield from flush_buffer()
148+
149+
yield from flush_buffer()

src/talkpipe/llm/embedding_adapters.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,35 @@
1-
from typing import List
1+
from __future__ import annotations
2+
3+
import warnings
4+
from typing import List, overload, Sequence, Union
5+
26
import numpy as np
7+
38
from talkpipe.util.config import get_config
49
from talkpipe.util.constants import OLLAMA_SERVER_URL
510

11+
12+
def _vector_to_list(vec) -> List[float]:
    """Convert one vector-like value into a plain list of Python floats."""
    as_array = np.asarray(vec, dtype=float)
    return as_array.tolist()


def _vectors_to_lists(arr) -> List[List[float]]:
    """Convert a batch of vectors into a list of plain float lists.

    Empty input gives an empty list. A one-dimensional input is treated as a
    single vector and wrapped in a one-element list; otherwise each row of the
    array becomes one list.
    """
    matrix = np.asarray(arr, dtype=float)
    if matrix.size == 0:
        return []
    rows = [matrix] if matrix.ndim == 1 else matrix
    return [_vector_to_list(row) for row in rows]
23+
24+
625
class AbstractEmbeddingAdapter:
726
"""Abstract class for embedding text.
827
928
This class represents an abstract adapter to embedding models.
1029
It defines the API and a common way to interact with different embedding models. The
1130
specifics for embedding the text themselves are implemented in subclasses.
1231
"""
32+
1333
_model_name: str
1434
_source: str
1535

@@ -35,11 +55,41 @@ def __str__(self):
3555
def __repr__(self):
3656
return self.__str__()
3757

58+
def execute_one(self, text: str) -> List[float]:
    """Embed a single string; concrete adapters must override this."""
    message = "Subclasses must implement execute_one."
    raise NotImplementedError(message)
60+
61+
def execute_batch(self, texts: Sequence[str]) -> List[List[float]]:
    """Embed several strings by calling ``execute_one`` on each, in order.

    Returns an empty list when ``texts`` is empty.
    """
    if texts:
        return list(map(self.execute_one, texts))
    return []
65+
3866
def execute(self, text: str) -> List[float]:
    """Deprecated single-string entry point; forwards to ``execute_one``.

    Emits a ``DeprecationWarning`` attributed to the caller, then returns the
    result of ``execute_one(text)``.

    .. deprecated::
        Use :meth:`execute_one` or :meth:`execute_batch` instead.
        ``execute`` will be removed in TalkPipe 1.0.
    """
    notice = (
        "EmbeddingAdapter.execute() is deprecated and will be removed in "
        "TalkPipe 1.0. Use execute_one() or execute_batch() instead."
    )
    warnings.warn(notice, category=DeprecationWarning, stacklevel=2)
    return self.execute_one(text)
80+
81+
@overload
def __call__(self, text: str) -> List[float]: ...


@overload
def __call__(self, text: Sequence[str]) -> List[List[float]]: ...


def __call__(
    self, text: Union[str, Sequence[str]]
) -> Union[List[float], List[List[float]]]:
    """Embed a string or a sequence of strings.

    A ``str`` is routed to ``execute_one`` and returns a single vector. Any
    other input is copied into a list and routed to ``execute_batch``, which
    returns one vector per element.
    """
    if not isinstance(text, str):
        return self.execute_batch(list(text))
    return self.execute_one(text)
4393

4494

4595
class OllamaEmbedderAdapter(AbstractEmbeddingAdapter):
@@ -49,21 +99,24 @@ def __init__(self, model: str, server_url: str = None):
4999
super().__init__(model, "ollama")
50100
self._server_url = server_url
51101

52-
def execute(self, text: str) -> List[float]:
102+
def _client(self):
53103
try:
54104
import ollama
55105
except ImportError:
56106
raise ImportError(
57107
"Ollama is not installed. Please install it with: pip install talkpipe[ollama]"
58108
)
59-
60109
server_url = self._server_url
61110
if not server_url:
62111
server_url = get_config().get(OLLAMA_SERVER_URL, None)
63-
client = ollama.Client(server_url) if server_url else ollama
64-
response = client.embed(
65-
model=self.model_name,
66-
input=text
67-
)
68-
result = response["embeddings"][0]
69-
return np.array(result)
112+
return ollama.Client(server_url) if server_url else ollama
113+
114+
def execute_batch(self, texts: Sequence[str]) -> List[List[float]]:
    """Embed several strings with one Ollama ``embed`` request.

    Returns an empty list for empty input without creating a client.
    Otherwise the texts are sent as a list and the ``"embeddings"`` entry of
    the response is converted with ``_vectors_to_lists``.
    """
    if texts:
        ollama_client = self._client()
        reply = ollama_client.embed(model=self.model_name, input=list(texts))
        return _vectors_to_lists(reply["embeddings"])
    return []
120+
121+
def execute_one(self, text: str) -> List[float]:
    """Embed one string by sending it through ``execute_batch`` as a one-item list."""
    vectors = self.execute_batch([text])
    return vectors[0]

src/talkpipe/llm/embedding_adapters_model2vec.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import List, Optional
1+
from typing import List, Optional, Sequence
22

33
from talkpipe.util.config import get_config
44
from talkpipe.util.constants import MODEL2VEC_CACHE_DIR, MODEL2VEC_REVISION
55

6-
from .embedding_adapters import AbstractEmbeddingAdapter
6+
from .embedding_adapters import AbstractEmbeddingAdapter, _vectors_to_lists
77
from .model2vec_embeddings import DEFAULT_MODEL, Model2VecEmbedder
88

99

@@ -20,5 +20,10 @@ def __init__(self, model: Optional[str] = None):
2020
cache_folder=cfg.get(MODEL2VEC_CACHE_DIR),
2121
)
2222

23-
def execute(self, text: str) -> List[float]:
24-
return self._embedder.embed_one(text)
23+
def execute_batch(self, texts: Sequence[str]) -> List[List[float]]:
    """Embed several strings through the wrapped embedder's ``embed`` call.

    Returns an empty list for empty input. Otherwise the texts are passed as
    a list and the result is converted with ``_vectors_to_lists``.
    """
    if texts:
        encoded = self._embedder.embed(list(texts))
        return _vectors_to_lists(encoded)
    return []
27+
28+
def execute_one(self, text: str) -> List[float]:
    """Embed one string by sending it through ``execute_batch`` as a one-item list."""
    vectors = self.execute_batch([text])
    return vectors[0]

src/talkpipe/llm/embedding_adapters_openai.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Sequence
22

33
from .embedding_adapters import AbstractEmbeddingAdapter
44

@@ -21,7 +21,16 @@ def __init__(self, model: str):
2121
openai = _require_openai()
2222
self.client = openai.OpenAI()
2323

24-
def execute(self, text: str) -> List[float]:
24+
def execute_batch(self, texts: Sequence[str]) -> List[List[float]]:
    """Embed several strings with one ``embeddings.create`` request.

    Returns an empty list for empty input without calling the client.
    Otherwise returns the ``embedding`` of each entry in ``response.data``,
    in the order the response lists them, each copied into a new list.
    """
    if texts:
        reply = self.client.embeddings.create(
            model=self.model_name,
            input=list(texts),
        )
        return [list(entry.embedding) for entry in reply.data]
    return []
32+
33+
def execute_one(self, text: str) -> List[float]:
2534
response = self.client.embeddings.create(
2635
model=self.model_name,
2736
input=text,

tests/talkpipe/llm/test_embedding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_fail_on_error_parameter():
101101
"""Test that fail_on_error parameter works correctly to prevent duplicate execute_one() calls."""
102102
# Create a mock embedder that always fails
103103
mock_embedder = Mock()
104-
mock_embedder.execute = Mock(side_effect=RuntimeError("Embedding failed"))
104+
mock_embedder.execute_one = Mock(side_effect=RuntimeError("Embedding failed"))
105105

106106
# Test with fail_on_error=True (should raise exception)
107107
embedder_true = LLMEmbed(model="test-model", source="ollama", fail_on_error=True)
@@ -111,7 +111,7 @@ def test_fail_on_error_parameter():
111111
list(embedder_true(["test input"]))
112112

113113
# Verify execute_one was called only once (not twice due to duplicate line bug)
114-
assert mock_embedder.execute.call_count == 1
114+
assert mock_embedder.execute_one.call_count == 1
115115

116116
# Reset mock
117117
mock_embedder.reset_mock()
@@ -126,5 +126,5 @@ def test_fail_on_error_parameter():
126126
# Result should be empty since the embedding failed and was skipped
127127
assert result == []
128128

129-
# Verify execute was called only once (not twice due to duplicate line bug)
130-
assert mock_embedder.execute.call_count == 1
129+
# Verify execute_one was called only once (not twice due to duplicate line bug)
130+
assert mock_embedder.execute_one.call_count == 1

0 commit comments

Comments
 (0)