Skip to content

Ensure consistent embeddings are used #1214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions lumen/ai/vector_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import importlib
import io
import json
import os
Expand Down Expand Up @@ -68,6 +69,7 @@ class VectorStore(LLMUser):
embeddings = param.ClassSelector(
class_=Embeddings,
default=NumpyEmbeddings(),
allow_None=True,
doc="Embeddings object for text processing.",
)

Expand Down Expand Up @@ -897,8 +899,16 @@ class DuckDBVectorStore(VectorStore):

uri = param.String(default=":memory:", doc="The URI of the DuckDB database")

embeddings = param.ClassSelector(
class_=Embeddings,
default=None,
allow_None=True,
doc="Embeddings object for text processing. If None and a URI is provided, loads from the database; else NumpyEmbeddings.",
)

def __init__(self, **params):
super().__init__(**params)

connection = duckdb.connect(":memory:")
# following the instructions from
# https://duckdb.org/docs/stable/extensions/vss.html#persistence
Expand Down Expand Up @@ -930,6 +940,19 @@ def __init__(self, **params):
)
self._initialized = uri_exists and has_documents

if self.uri != ":memory:" and self._initialized:
config = self._get_embeddings_config()
if config and self.embeddings is None:
module_name, class_name = config["class"].rsplit(".", 1)
module = importlib.import_module(module_name)
embedding_class = getattr(module, class_name)
self.embeddings = embedding_class(**config["params"])
log_debug(f"Loaded embeddings {class_name} from database.")
self._check_embeddings_consistency()

if self.embeddings is None:
self.embeddings = NumpyEmbeddings()

def _setup_database(self, embedding_dim: int) -> None:
"""Set up the DuckDB database with necessary tables and indexes."""
self.connection.execute("CREATE SEQUENCE IF NOT EXISTS documents_id_seq;")
Expand All @@ -945,6 +968,34 @@ def _setup_database(self, embedding_dim: int) -> None:
"""
)

self.connection.execute(
"""
CREATE TABLE IF NOT EXISTS vector_store_metadata (
key VARCHAR PRIMARY KEY,
value JSON
);
"""
)

# Store embedding configuration
embedding_info = {
"class": self.embeddings.__class__.__module__ + "." + self.embeddings.__class__.__name__,
"params": {}
}
for param_name, param_obj in self.embeddings.param.objects().items():
if param_name not in ['name']:
value = getattr(self.embeddings, param_name)
if isinstance(value, (str, int, float, bool, list, dict)) or value is None:
embedding_info["params"][param_name] = value

self.connection.execute(
"""
INSERT OR REPLACE INTO vector_store_metadata (key, value)
VALUES ('embeddings', ?::JSON);
""",
[json.dumps(embedding_info)]
)

self.connection.execute(
"""
CREATE INDEX IF NOT EXISTS embedding_index
Expand All @@ -953,6 +1004,74 @@ def _setup_database(self, embedding_dim: int) -> None:
)
self._initialized = True

def _check_embeddings_consistency(self):
"""
Check if the provided embeddings are consistent with the stored configuration.
Raises ValueError if there's a mismatch that would cause empty query results.
"""
# Check if metadata table exists
stored_config = self._get_embeddings_config() or {"class": "", "params": {}}
stored_class = stored_config["class"]
stored_params = stored_config["params"]

# Get current embeddings class
current_class = self.embeddings.__class__.__module__ + "." + self.embeddings.__class__.__name__

# Check if classes match
if current_class != stored_class:
raise ValueError(
f"Provided embeddings class '{current_class}' does not match the stored class "
f"'{stored_class}' for this vector store. This would result in empty query results. "
f"Use compatible embeddings or create a new vector store."
)

# Check if critical parameters match
for param_name, stored_value in stored_params.items():
if hasattr(self.embeddings, param_name):
current_value = getattr(self.embeddings, param_name)
if current_value != stored_value and param_name in ['model', 'embedding_dim', 'chunk_size']:
raise ValueError(
f"Provided embeddings parameter '{param_name}' value '{current_value}' "
f"does not match stored value '{stored_value}'. This would result in "
f"empty query results. Use compatible embeddings or create a new vector store."
)


def _get_embeddings_config(self):
"""
Get the embeddings configuration stored in the vector store.

Returns
-------
dict or None
The embeddings configuration or None if not available.
"""
if not self._initialized:
return None

# Check if metadata table exists
has_metadata = (
self.connection.execute(
"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'vector_store_metadata';"
).fetchone()[0]
> 0
)

if not has_metadata:
return None

try:
result = self.connection.execute(
"SELECT value FROM vector_store_metadata WHERE key = 'embeddings';"
).fetchone()

if result:
return json.loads(result[0])
return None
except Exception as e:
log_debug(f"Error retrieving embeddings configuration: {e}")
return None

async def _add_items(
self,
texts: list[str],
Expand Down
15 changes: 14 additions & 1 deletion lumen/tests/ai/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
except ModuleNotFoundError:
pytest.skip("lumen.ai could not be imported, skipping tests.", allow_module_level=True)

from lumen.ai.embeddings import NumpyEmbeddings
from lumen.ai.embeddings import Embeddings, NumpyEmbeddings
from lumen.ai.vector_store import DuckDBVectorStore, NumpyVectorStore


Expand Down Expand Up @@ -545,3 +545,16 @@ async def test_not_initalized(self, tmp_path):
results = await store.query("First doc")
assert len(results) == 0
store.close()

async def test_check_embeddings_consistency(self, tmp_path):
db_path = str(tmp_path / "test_duckdb.db")
store = DuckDBVectorStore(uri=db_path, embeddings=NumpyEmbeddings())
await store.add([{"text": "First doc"}])
store.close()

store = DuckDBVectorStore(uri=db_path, embeddings=NumpyEmbeddings())
assert len(await store.query("First doc")) == 1
store.close()

with pytest.raises(ValueError, match="Provided embeddings class"):
DuckDBVectorStore(uri=db_path, embeddings=Embeddings())
Loading