Feat: turbopuffer datasink #58910
base: master
Conversation
Code Review
This pull request introduces a TurbopufferDatasink to enable writing Ray datasets to the Turbopuffer vector database. The implementation is comprehensive, covering both single-namespace and multi-namespace writes, along with robust configuration validation and a thorough test suite.
My review focuses on performance and configurability. I've identified a significant performance issue in the multi-namespace write logic and suggest a more efficient implementation using pyarrow.Table.group_by(). I also recommend making the distance_metric configurable to provide more flexibility to users. Additionally, there are a couple of minor improvements for robustness and code style.
Overall, this is a great contribution that adds valuable functionality to Ray Data.
```python
# Group by namespace column
# Note: PyArrow doesn't have a built-in group_by for tables,
# so we'll use a simpler approach: get unique values and filter
namespace_col = table.column(self.namespace_column)

# Get unique namespace values
unique_namespaces = pc.unique(namespace_col)

logger.debug(f"Writing to {len(unique_namespaces)} namespaces")

# Process each namespace group
for i in range(len(unique_namespaces)):
    namespace_value = unique_namespaces[i].as_py()

    # Filter table for this namespace
    mask = pc.equal(namespace_col, namespace_value)
    group_table = table.filter(mask)

    # Format namespace name
    # Convert bytes to UUID string if needed
    if isinstance(namespace_value, bytes) and len(namespace_value) == 16:
        # This is a UUID in binary format
        namespace_str = str(uuid.UUID(bytes=namespace_value))
    else:
        namespace_str = str(namespace_value)

    namespace_name = self.namespace_format.format(namespace=namespace_str)

    # Write this group
    self._write_single_namespace(client, group_table, namespace_name)
```
The current implementation for grouping by namespace_column iterates through the unique namespace values and filters the entire table for each one. This can be very inefficient when there are many unique namespaces, since it costs roughly O(num_unique_namespaces * num_rows). A more performant approach is to use pyarrow.Table.group_by() with a "list" aggregation over row indices, which collects every group's rows in a single pass and then materializes each group once with take(). The comment on line 291 is also incorrect: PyArrow does support table grouping via Table.group_by().
Suggested change:
```diff
-# Group by namespace column
-# Note: PyArrow doesn't have a built-in group_by for tables,
-# so we'll use a simpler approach: get unique values and filter
-namespace_col = table.column(self.namespace_column)
-# Get unique namespace values
-unique_namespaces = pc.unique(namespace_col)
-logger.debug(f"Writing to {len(unique_namespaces)} namespaces")
-# Process each namespace group
-for i in range(len(unique_namespaces)):
-    namespace_value = unique_namespaces[i].as_py()
-    # Filter table for this namespace
-    mask = pc.equal(namespace_col, namespace_value)
-    group_table = table.filter(mask)
-    # Format namespace name
-    # Convert bytes to UUID string if needed
-    if isinstance(namespace_value, bytes) and len(namespace_value) == 16:
-        # This is a UUID in binary format
-        namespace_str = str(uuid.UUID(bytes=namespace_value))
-    else:
-        namespace_str = str(namespace_value)
-    namespace_name = self.namespace_format.format(namespace=namespace_str)
-    # Write this group
-    self._write_single_namespace(client, group_table, namespace_name)
+# Group by namespace column in a single pass: collect each group's row
+# indices with a "list" aggregation, then take() each group once.
+# (Assumes `import pyarrow as pa` at the top of the file.)
+grouped = (
+    pa.table(
+        {
+            "namespace": table.column(self.namespace_column),
+            "row_idx": pa.array(range(table.num_rows), type=pa.int64()),
+        }
+    )
+    .group_by("namespace")
+    .aggregate([("row_idx", "list")])
+)
+logger.debug(f"Writing to {grouped.num_rows} namespaces")
+# Process each namespace group
+for namespace_value, group_indices in zip(
+    grouped["namespace"].to_pylist(), grouped["row_idx_list"].to_pylist()
+):
+    group_table = table.take(group_indices)
+    # Format namespace name
+    # Convert bytes to UUID string if needed
+    if isinstance(namespace_value, bytes) and len(namespace_value) == 16:
+        # This is a UUID in binary format
+        namespace_str = str(uuid.UUID(bytes=namespace_value))
+    else:
+        namespace_str = str(namespace_value)
+    namespace_name = self.namespace_format.format(namespace=namespace_str)
+    # Write this group
+    self._write_single_namespace(client, group_table, namespace_name)
```
(Note: pyarrow.TableGroupBy is not directly iterable, so the suggestion collects per-group row indices via aggregate() and materializes each group with take().)
```python
rows = table.to_pylist()

# Validate all rows have ID
if rows and "id" not in rows[0]:
    raise ValueError("Table must have 'id' column")
```
The check for the presence of an "id" column is done on the first row of the Python list representation of the table (if rows and "id" not in rows[0]). This is inefficient because it requires converting the table to a list of dictionaries first, and it's not robust as it only checks the first row. A better approach is to check the table's schema directly before converting it to a pylist.
Suggested change:
```diff
-rows = table.to_pylist()
-# Validate all rows have ID
-if rows and "id" not in rows[0]:
-    raise ValueError("Table must have 'id' column")
+if "id" not in table.column_names:
+    raise ValueError("Table must have 'id' column")
+# Convert to list of row dictionaries
+rows = table.to_pylist()
```
```python
    raise ValueError("Table must have 'id' column")

# Convert bytes to proper formats (e.g., UUIDs)
import uuid as uuid_lib
```
The uuid module is imported locally here, but it's already imported at the top of the file. For consistency and to follow best practices (PEP 8), it's better to use the top-level import. The alias uuid_lib is not necessary as there is no local variable named uuid that it would shadow. Please remove this line and change uuid_lib.UUID() to uuid.UUID() in this method.
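A minimal sketch of that cleanup, assuming `import uuid` already appears at the top of the file:
```diff
-# Convert bytes to proper formats (e.g., UUIDs)
-import uuid as uuid_lib
+# Convert bytes to proper formats (e.g., UUIDs), using the module-level
+# `uuid` import instead of a local aliased re-import.
```
with the call sites in this method changed from `uuid_lib.UUID(...)` to `uuid.UUID(...)`.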
```python
namespace.write(
    upsert_rows=batch_data,
    schema=self.schema,
    distance_metric="cosine_distance",
```
The distance_metric is hardcoded to "cosine_distance". While this is the default for Turbopuffer, users might want to use other supported metrics like "euclidean_squared" or "dot_product". This should be a configurable parameter of the TurbopufferDatasink.
I recommend adding a distance_metric parameter to the __init__ method (defaulting to "cosine_distance") and using it here. This change should also be propagated to Dataset.write_turbopuffer.
Suggested change:
```diff
-    distance_metric="cosine_distance",
+    distance_metric=self.distance_metric,
```
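A hedged sketch of the suggested constructor change. The parameter name comes from this review; the metric names are the ones listed above, not an exhaustive check against Turbopuffer's documentation, and the other `__init__` parameters are elided:
```python
# Sketch only: the Datasink base class and remaining parameters are omitted.
_SUPPORTED_DISTANCE_METRICS = {"cosine_distance", "euclidean_squared", "dot_product"}

class TurbopufferDatasink:
    def __init__(self, *args, distance_metric: str = "cosine_distance", **kwargs):
        if distance_metric not in _SUPPORTED_DISTANCE_METRICS:
            raise ValueError(
                f"distance_metric must be one of "
                f"{sorted(_SUPPORTED_DISTANCE_METRICS)}, got {distance_metric!r}"
            )
        # Used later as namespace.write(..., distance_metric=self.distance_metric).
        self.distance_metric = distance_metric
```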
```python
idx = table.column_names.index(self.vector_column)
new_names = list(table.column_names)
new_names[idx] = "vector"
table = table.rename_columns(new_names)
```
Bug: Missing validation for custom vector column
When a custom vector_column is specified but doesn't exist in the table, the code silently skips renaming without raising an error. The condition if self.vector_column != "vector" and self.vector_column in table.column_names: only executes the renaming block when the column exists. Unlike the id_column validation (which explicitly checks and raises ValueError if missing), this allows tables without the specified vector column to proceed, potentially causing silent failures or errors downstream when Turbopuffer expects a "vector" column.
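A possible fix, mirroring the existing id_column check (sketch only; the error message is illustrative):
```python
if self.vector_column != "vector":
    if self.vector_column not in table.column_names:
        raise ValueError(
            f"Vector column {self.vector_column!r} not found in table; "
            f"available columns: {table.column_names}"
        )
    # Rename the custom vector column to the "vector" name Turbopuffer expects.
    idx = table.column_names.index(self.vector_column)
    new_names = list(table.column_names)
    new_names[idx] = "vector"
    table = table.rename_columns(new_names)
```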
```python
    self._write_multi_namespace(client, table)
else:
    # Single namespace mode
    self._write_single_namespace(client, table, self.namespace)
```
Bug: Namespace column lost after column renaming
When namespace_column is the same as id_column or vector_column, the column gets renamed in _prepare_arrow_table before _write_multi_namespace executes. Then _write_multi_namespace attempts to find the original namespace_column name in the table, which no longer exists after renaming, causing a ValueError. For example, if both namespace_column="doc_id" and id_column="doc_id", the column is renamed to "id" but multi-namespace mode still looks for "doc_id".
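One way to close this hole is to reject the ambiguous configuration up front, e.g. in `__init__` (a sketch; an alternative fix would be to track the renamed column name through `_prepare_arrow_table`):
```python
# Sketch: namespace_column must survive the id/vector renames untouched.
if self.namespace_column is not None and self.namespace_column in (
    self.id_column,
    self.vector_column,
):
    raise ValueError(
        f"namespace_column {self.namespace_column!r} cannot equal id_column or "
        "vector_column: those columns are renamed to 'id'/'vector' before "
        "multi-namespace grouping, so the namespace column would be lost."
    )
```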
Additional Locations (1)
```python
idx = table.column_names.index(self.vector_column)
new_names = list(table.column_names)
new_names[idx] = "vector"
table = table.rename_columns(new_names)
```
Bug: Same column for ID and vector causes failure
When id_column and vector_column are set to the same column name, the ID renaming happens first and consumes that column by renaming it to "id". Then the vector column renaming logic checks if self.vector_column in table.column_names, which is now false since the column was already renamed. This causes the vector renaming to be silently skipped, leaving the table without a "vector" column, which will likely cause errors when writing to Turbopuffer.
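Rejecting this configuration in `__init__` (as in the previous sketch) is probably the simplest fix. If sharing one source column is meant to be supported, the column could instead be copied before the ID rename consumes it; a hypothetical sketch:
```python
# Sketch: duplicate the shared column as "vector" before id_column is
# renamed to "id", so both roles keep their data.
if (
    self.id_column == self.vector_column
    and self.id_column in table.column_names
    and "vector" not in table.column_names
):
    table = table.append_column("vector", table[self.id_column])
```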
```python
# Filter table for this namespace
mask = pc.equal(namespace_col, namespace_value)
group_table = table.filter(mask)
```
Bug: Null namespace values silently drop rows
When the namespace column contains null values in multi-namespace mode, pc.equal(namespace_col, namespace_value) returns all false values when namespace_value is None, because PyArrow's equality doesn't match nulls. This causes rows with null namespace values to be filtered into an empty group and silently dropped without any error or warning, leading to unexpected data loss.
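A minimal guard (sketch) would be to fail fast on null namespace values before grouping, rather than dropping those rows silently:
```python
# Sketch: rows with null namespace values cannot be routed to any namespace.
if namespace_col.null_count > 0:
    raise ValueError(
        f"Namespace column {self.namespace_column!r} contains "
        f"{namespace_col.null_count} null values; rows with null namespaces "
        "would otherwise be silently dropped."
    )
```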
Description
This PR adds a `TurbopufferDatasink` for Ray Data, enabling Ray datasets to be written directly into the Turbopuffer vector database. The datasink supports both single-namespace writes and multi-namespace writes, where rows are grouped by a `namespace_column` and written into separate Turbopuffer namespaces derived from that column.

The implementation includes:

- A `TurbopufferDatasink(Datasink)` that:
  - validates configuration (`namespace` vs `namespace_column`, required API key, namespace format);
  - renames `id_column` → `"id"` and `vector_column` → `"vector"`, and filters out rows with null IDs;
  - groups rows by `namespace_column`, formats namespace names via `namespace_format` (e.g., `block_spans__{namespace}`), and writes each group to its own Turbopuffer namespace;
  - converts `bytes` values (including inside lists) → hex strings.
- Validation in `_prepare_arrow_table`:
  - raises `ValueError` if renaming a custom `id_column` to `"id"` would conflict with an existing `"id"` column;
  - raises `ValueError` if renaming a custom `vector_column` to `"vector"` would conflict with an existing `"vector"` column;
  - operates on the renamed columns via `table.column("id")` / `"vector"`.

A comprehensive test suite is added/updated in `python/ray/data/tests/test_turbopuffer_datasink.py` to cover:

- tables missing `"id"` or `"vector"` columns;
- multi-namespace grouping by `namespace_column`;
- rows with null values in the `"id"` column.

This PR is aligned with the design and performance considerations described in `turbopuffer_datasink_spec.md`, including support for multi-tenant (multi-namespace) ingestion patterns and Turbopuffer's performance guidance (schema types, batch sizing, concurrency).
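For context, a hypothetical usage sketch of the API this PR describes. `write_turbopuffer` and its parameter names are taken from the PR text and review comments, so the exact signature may differ from the merged implementation:
```python
import ray

ds = ray.data.from_items(
    [
        {"id": "doc-1", "vector": [0.1, 0.2, 0.3], "tenant": "acme"},
        {"id": "doc-2", "vector": [0.4, 0.5, 0.6], "tenant": "globex"},
    ]
)

# Single-namespace write: every row lands in one Turbopuffer namespace.
ds.write_turbopuffer(
    namespace="my_namespace",
    api_key="tpuf_...",
)

# Multi-namespace write: rows are grouped by the `tenant` column and each
# group is written to a namespace derived from namespace_format.
ds.write_turbopuffer(
    namespace_column="tenant",
    namespace_format="block_spans__{namespace}",
    api_key="tpuf_...",
)
```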