Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions perch_hoplite/db/sqlite_usearch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import collections
from collections.abc import Sequence
import dataclasses
import functools
import json
import sqlite3
from typing import Any
Expand Down Expand Up @@ -164,6 +165,31 @@ def _sqlite_filepath(self) -> epath.Path:
def _usearch_filepath(self) -> epath.Path:
return epath.Path(self.db_path) / UINDEX_FILENAME

@functools.cached_property
def offset_dtype(self) -> type[Any]:
"""The data type of the offsets in the database."""
if self.ui.size == 0:
return np.float32
# Otherwise, check the in-memory index.
# Some old DBs may have float16 offsets, but in this case will always
# have a single offset, so we can detect this and return float16.
idx = self.get_one_embedding_id()
cursor = self._get_cursor()
cursor.execute(
"""
SELECT he.offsets
FROM hoplite_sources hs
JOIN hoplite_embeddings he ON hs.id = he.source_idx
WHERE he.id = ?;
""",
(int(idx),),
)
binary_offsets = cursor.fetchall()[0]
if len(binary_offsets) == 2:
return np.float16
else:
return np.float32

def thread_split(self):
"""Get a new instance of the SQLite DB."""
return self.create(self.db_path)
Expand Down Expand Up @@ -250,7 +276,7 @@ def insert_embedding(
source_id = self._get_source_id(
source.dataset_name, source.source_id, insert=True
)
offset_bytes = serialize_embedding(source.offsets, self.embedding_dtype)
offset_bytes = serialize_embedding(source.offsets, self.offset_dtype)
cursor.execute(
"""
INSERT INTO hoplite_embeddings (source_idx, offsets) VALUES (?, ?);
Expand Down Expand Up @@ -295,7 +321,7 @@ def get_embedding_source(
(int(embedding_id),),
)
dataset, source, offsets = cursor.fetchall()[0]
offsets = deserialize_embedding(offsets, self.embedding_dtype)
offsets = deserialize_embedding(offsets, self.offset_dtype)
return interface.EmbeddingSource(dataset, str(source), offsets)

def get_embeddings(
Expand Down Expand Up @@ -345,7 +371,7 @@ def get_embeddings_by_source(
result_pairs = cursor.fetchall()
outputs = []
for idx, offsets_bytes in result_pairs:
got_offsets = deserialize_embedding(offsets_bytes, self.embedding_dtype)
got_offsets = deserialize_embedding(offsets_bytes, self.offset_dtype)
if offsets is not None and not np.array_equal(got_offsets, offsets):
continue
outputs.append(idx)
Expand Down