Skip to content

Commit f4d2cb3

Browse files
committed
feat(targets): add Turbopuffer target connector
Add a new target connector for Turbopuffer vector database. The implementation follows the same pattern as the ChromaDB connector with turbopuffer-specific adaptations: - Namespaces are created implicitly on first write (no explicit setup) - distance_metric is passed on every upsert call - Row-oriented write format via ns.write(upsert_rows=[...]) - Supports cosine, euclidean, and dot product distance metrics Includes 30 unit tests covering all connector logic and helper functions, with no turbopuffer account required. Closes #1562
1 parent fe6b678 commit f4d2cb3

File tree

3 files changed

+560
-1
lines changed

3 files changed

+560
-1
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ colpali = ["colpali-engine"]
7373
lancedb = ["lancedb>=0.25.0", "pyarrow>=19.0.0"]
7474
chromadb = ["chromadb>=0.4.0"]
7575
doris = ["aiohttp>=3.8.0", "aiomysql>=0.2.0", "pymysql>=1.0.0"]
76+
turbopuffer = ["turbopuffer>=1.0.0"]
7677

7778
all = [
7879
"sentence-transformers>=3.3.1",
@@ -83,6 +84,7 @@ all = [
8384
"aiohttp>=3.8.0",
8485
"aiomysql>=0.2.0",
8586
"pymysql>=1.0.0",
87+
"turbopuffer>=1.0.0",
8688
]
8789

8890
[dependency-groups]
@@ -132,7 +134,7 @@ disable_error_code = ["unused-ignore"]
132134

133135
[[tool.mypy.overrides]]
134136
# Ignore missing imports for optional dependencies from cocoindex library
135-
module = ["sentence_transformers", "torch", "colpali_engine", "PIL", "aiohttp", "aiomysql", "pymysql", "chromadb"]
137+
module = ["sentence_transformers", "torch", "colpali_engine", "PIL", "aiohttp", "aiomysql", "pymysql", "chromadb", "turbopuffer"]
136138
ignore_missing_imports = true
137139

138140
[[tool.mypy.overrides]]
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
import dataclasses
2+
import json
3+
import logging
4+
from typing import Any
5+
6+
import turbopuffer # type: ignore
7+
8+
from cocoindex import op
9+
from cocoindex.engine_type import FieldSchema, BasicValueType
10+
from cocoindex.index import IndexOptions, VectorSimilarityMetric
11+
12+
_logger = logging.getLogger(__name__)

# Maps CocoIndex similarity metrics onto Turbopuffer distance_metric names.
# L2 maps onto Turbopuffer's "euclidean_squared" metric (squared euclidean
# distance — NOTE(review): assumed acceptable since ranking order matches
# plain euclidean; confirm no caller depends on absolute distance values).
_TURBOPUFFER_DISTANCE_METRIC: dict[VectorSimilarityMetric, str] = {
    VectorSimilarityMetric.COSINE_SIMILARITY: "cosine_distance",
    VectorSimilarityMetric.L2_DISTANCE: "euclidean_squared",
    VectorSimilarityMetric.INNER_PRODUCT: "dot_product",
}
19+
20+
21+
class Turbopuffer(op.TargetSpec):
    """Target spec for exporting rows to a Turbopuffer namespace.

    The namespace does not need to exist beforehand; Turbopuffer creates
    namespaces implicitly on first write.
    """

    # Name of the Turbopuffer namespace to write into.
    namespace_name: str
    # Turbopuffer API key used to authenticate the client.
    api_key: str
    # Turbopuffer region identifier.
    region: str = "gcp-us-central1"
25+
26+
27+
@dataclasses.dataclass
class _NamespaceKey:
    """Persistent identity of a target: one namespace within one region."""

    # Turbopuffer region the namespace lives in.
    region: str
    # Name of the namespace.
    namespace_name: str
31+
32+
33+
@dataclasses.dataclass
class _State:
    """Setup state persisted for a namespace target between runs."""

    # Schema of the single key field (Turbopuffer supports exactly one).
    key_field_schema: FieldSchema
    # Schemas of all value fields, including the one vector field.
    value_fields_schema: list[FieldSchema]
    # Turbopuffer distance metric name (e.g. "cosine_distance").
    distance_metric: str
    # API key, kept so setup changes can delete stale data later.
    api_key: str
39+
40+
41+
@dataclasses.dataclass
class _MutateContext:
    """Per-target context handed from prepare() to mutate()."""

    client: Any  # turbopuffer.Turbopuffer
    namespace: Any  # turbopuffer.lib.namespace.Namespace
    # Schema of the single key field.
    key_field_schema: FieldSchema
    # Schemas of all value fields.
    value_fields_schema: list[FieldSchema]
    # Distance metric passed on every upsert call.
    distance_metric: str
48+
49+
50+
def _get_client(spec: Turbopuffer) -> Any:
    """Build a Turbopuffer client from the spec's API key and region."""
    return turbopuffer.Turbopuffer(
        api_key=spec.api_key,
        region=spec.region,
    )
55+
56+
57+
def _convert_key_to_id(key: Any) -> str:
58+
if isinstance(key, str):
59+
return key
60+
elif isinstance(key, (int, float, bool)):
61+
return str(key)
62+
else:
63+
return json.dumps(key, sort_keys=True, default=str)
64+
65+
66+
def _convert_value_to_attribute(value: Any) -> str | int | float | bool | None:
67+
if value is None:
68+
return None
69+
if isinstance(value, (str, int, float, bool)):
70+
return value
71+
return json.dumps(value, sort_keys=True, default=str)
72+
73+
74+
def _is_vector_field(field: FieldSchema) -> bool:
    """Return True when *field* holds a basic value of kind "Vector"."""
    vt = field.value_type.type
    return isinstance(vt, BasicValueType) and vt.kind == "Vector"
79+
80+
81+
@op.target_connector(
    spec_cls=Turbopuffer, persistent_key_type=_NamespaceKey, setup_state_cls=_State
)
class _Connector:
    """Target connector that writes rows into a Turbopuffer namespace.

    Turbopuffer creates namespaces implicitly on the first write, so setup
    only ever needs to *delete* stale data; the distance metric is supplied
    on every upsert call rather than configured once up front.
    """

    @staticmethod
    def get_persistent_key(spec: Turbopuffer) -> _NamespaceKey:
        """Identify the target by (region, namespace); both determine where data lives."""
        return _NamespaceKey(
            region=spec.region,
            namespace_name=spec.namespace_name,
        )

    @staticmethod
    def get_setup_state(
        spec: Turbopuffer,
        key_fields_schema: list[FieldSchema],
        value_fields_schema: list[FieldSchema],
        index_options: IndexOptions,
    ) -> _State:
        """Validate the schema/index options and capture state needed to write.

        Raises:
            ValueError: if there is not exactly one key field, not exactly one
                vector-valued field, or more than one vector index configured.
        """
        if len(key_fields_schema) != 1:
            raise ValueError("Turbopuffer only supports a single key field")

        vector_fields = [f for f in value_fields_schema if _is_vector_field(f)]
        if not vector_fields:
            raise ValueError(
                "Turbopuffer requires a vector field in the value schema for embeddings."
            )
        if len(vector_fields) > 1:
            raise ValueError(
                f"Turbopuffer only supports a single vector field per namespace, "
                f"but found {len(vector_fields)}: {[f.name for f in vector_fields]}. "
                f"Consider using LanceDB or Qdrant for multiple vector fields."
            )

        # Fall back to cosine both when no vector index is configured and when
        # the configured metric has no Turbopuffer equivalent.
        distance_metric = "cosine_distance"  # Default
        if index_options.vector_indexes:
            if len(index_options.vector_indexes) > 1:
                raise ValueError(
                    "Turbopuffer only supports a single vector index per namespace"
                )
            vector_index = index_options.vector_indexes[0]
            distance_metric = _TURBOPUFFER_DISTANCE_METRIC.get(
                vector_index.metric, "cosine_distance"
            )

        return _State(
            key_field_schema=key_fields_schema[0],
            value_fields_schema=value_fields_schema,
            distance_metric=distance_metric,
            api_key=spec.api_key,
        )

    @staticmethod
    def describe(key: _NamespaceKey) -> str:
        """Human-readable label for this target, used in logs/UI."""
        return f"Turbopuffer namespace {key.namespace_name}@{key.region}"

    @staticmethod
    def _requires_rebuild(previous: _State, current: _State) -> bool:
        """True when existing data cannot be reused under the new state.

        A changed key field alters document ids; a changed distance metric
        alters what every stored vector means to queries. Value-field changes
        are additive attributes and do NOT force a rebuild.
        """
        return (
            previous.key_field_schema != current.key_field_schema
            or previous.distance_metric != current.distance_metric
        )

    @staticmethod
    def check_state_compatibility(
        previous: _State, current: _State
    ) -> op.TargetStateCompatibility:
        """Report whether existing namespace data survives a state change."""
        # Same predicate that apply_setup_change uses to decide deletion, so
        # the two methods can never disagree.
        if _Connector._requires_rebuild(previous, current):
            return op.TargetStateCompatibility.NOT_COMPATIBLE
        return op.TargetStateCompatibility.COMPATIBLE

    @staticmethod
    def apply_setup_change(
        key: _NamespaceKey, previous: _State | None, current: _State | None
    ) -> None:
        """Apply a setup transition: delete stale data when needed.

        Turbopuffer namespaces are created implicitly on first write, so the
        only real work is clearing data when the target is removed or its
        state is incompatible with what was stored before.
        """
        if previous is None:
            # Nothing was ever written for this target; nothing to clean up,
            # and creation happens implicitly on the first write.
            return

        # Prefer the current state's credentials when both exist.
        state = current or previous
        if current is None or _Connector._requires_rebuild(previous, current):
            # Best-effort cleanup: a missing namespace (never written to) is
            # expected and only worth a debug log.
            try:
                client = turbopuffer.Turbopuffer(
                    api_key=state.api_key,
                    region=key.region,
                )
                ns = client.namespace(key.namespace_name)
                ns.delete_all()
            except Exception as e:  # pylint: disable=broad-exception-caught
                _logger.debug(
                    "Namespace %s not found for deletion: %s",
                    key.namespace_name,
                    e,
                )

    @staticmethod
    def prepare(
        spec: Turbopuffer,
        setup_state: _State,
    ) -> _MutateContext:
        """Build the client/namespace handles reused across mutate() calls."""
        client = _get_client(spec)
        ns = client.namespace(spec.namespace_name)

        return _MutateContext(
            client=client,
            namespace=ns,
            key_field_schema=setup_state.key_field_schema,
            value_fields_schema=setup_state.value_fields_schema,
            distance_metric=setup_state.distance_metric,
        )

    @staticmethod
    def mutate(
        *all_mutations: tuple[_MutateContext, dict[Any, dict[str, Any] | None]],
    ) -> None:
        """Apply a batch of upserts (value dict) and deletes (None value).

        Raises:
            ValueError: if a row is missing its embedding (vector field absent
                or None) when the schema declares a vector field.
        """
        for context, mutations in all_mutations:
            if not mutations:
                continue

            # Locate the single vector field (uniqueness enforced in
            # get_setup_state).
            vector_field_name: str | None = next(
                (f.name for f in context.value_fields_schema if _is_vector_field(f)),
                None,
            )

            ids_to_delete: list[str] = []
            rows_to_upsert: list[dict[str, Any]] = []

            for key, value in mutations.items():
                doc_id = _convert_key_to_id(key)

                if value is None:
                    ids_to_delete.append(doc_id)
                    continue

                row: dict[str, Any] = {"id": doc_id}

                # Extract the embedding. BUG FIX: a vector field that was
                # absent from the value dict used to be silently skipped,
                # producing a row with no "vector"; treat absent the same as
                # an explicit None and raise.
                if vector_field_name is not None:
                    embedding = value.get(vector_field_name)
                    if embedding is None:
                        raise ValueError(
                            f"Missing embedding for document {doc_id}. "
                            f"Turbopuffer requires an embedding for each document."
                        )
                    row["vector"] = embedding

                # Build attributes from non-vector fields; None attributes are
                # simply omitted from the row.
                for field in context.value_fields_schema:
                    if field.name == vector_field_name:
                        continue
                    if field.name in value:
                        converted = _convert_value_to_attribute(value[field.name])
                        if converted is not None:
                            row[field.name] = converted

                rows_to_upsert.append(row)

            # Row-oriented write; the distance metric must accompany every
            # upsert call.
            if rows_to_upsert:
                context.namespace.write(
                    upsert_rows=rows_to_upsert,
                    distance_metric=context.distance_metric,
                )

            # Deletes need no metric.
            if ids_to_delete:
                context.namespace.write(
                    deletes=ids_to_delete,
                )

0 commit comments

Comments
 (0)