diff --git a/lumen/ai/vector_store.py b/lumen/ai/vector_store.py index 930bd63eb..f627ed398 100644 --- a/lumen/ai/vector_store.py +++ b/lumen/ai/vector_store.py @@ -1,4 +1,5 @@ import asyncio +import importlib import io import json import os @@ -68,6 +69,7 @@ class VectorStore(LLMUser): embeddings = param.ClassSelector( class_=Embeddings, default=NumpyEmbeddings(), + allow_None=True, doc="Embeddings object for text processing.", ) @@ -897,8 +899,16 @@ class DuckDBVectorStore(VectorStore): uri = param.String(default=":memory:", doc="The URI of the DuckDB database") + embeddings = param.ClassSelector( + class_=Embeddings, + default=None, + allow_None=True, + doc="Embeddings object for text processing. If None and a URI is provided, loads from the database; else NumpyEmbeddings.", + ) + def __init__(self, **params): super().__init__(**params) + connection = duckdb.connect(":memory:") # following the instructions from # https://duckdb.org/docs/stable/extensions/vss.html#persistence @@ -930,6 +940,19 @@ def __init__(self, **params): ) self._initialized = uri_exists and has_documents + if self.uri != ":memory:" and self._initialized: + config = self._get_embeddings_config() + if config and self.embeddings is None: + module_name, class_name = config["class"].rsplit(".", 1) + module = importlib.import_module(module_name) + embedding_class = getattr(module, class_name) + self.embeddings = embedding_class(**config["params"]) + log_debug(f"Loaded embeddings {class_name} from database.") + self._check_embeddings_consistency() + + if self.embeddings is None: + self.embeddings = NumpyEmbeddings() + def _setup_database(self, embedding_dim: int) -> None: """Set up the DuckDB database with necessary tables and indexes.""" self.connection.execute("CREATE SEQUENCE IF NOT EXISTS documents_id_seq;") @@ -945,6 +968,34 @@ def _setup_database(self, embedding_dim: int) -> None: """ ) + self.connection.execute( + """ + CREATE TABLE IF NOT EXISTS vector_store_metadata ( + key VARCHAR PRIMARY KEY, + value JSON + ); + """ + ) + + # Store embedding configuration + embedding_info = { + "class": self.embeddings.__class__.__module__ + "." + self.embeddings.__class__.__name__, + "params": {} + } + for param_name, param_obj in self.embeddings.param.objects().items(): + if param_name not in ['name']: + value = getattr(self.embeddings, param_name) + if isinstance(value, (str, int, float, bool, list, dict)) or value is None: + embedding_info["params"][param_name] = value + + self.connection.execute( + """ + INSERT OR REPLACE INTO vector_store_metadata (key, value) + VALUES ('embeddings', ?::JSON); + """, + [json.dumps(embedding_info)] + ) + self.connection.execute( """ CREATE INDEX IF NOT EXISTS embedding_index @@ -953,6 +1004,74 @@ def _setup_database(self, embedding_dim: int) -> None: ) self._initialized = True + def _check_embeddings_consistency(self): + """ + Check if the provided embeddings are consistent with the stored configuration. + Raises ValueError if there's a mismatch that would cause empty query results. + """ + # Check if metadata table exists + stored_config = self._get_embeddings_config() or {"class": "", "params": {}} + stored_class = stored_config["class"] + stored_params = stored_config["params"] + + # Get current embeddings class + current_class = self.embeddings.__class__.__module__ + "." + self.embeddings.__class__.__name__ + + # Check if classes match + if current_class != stored_class: + raise ValueError( + f"Provided embeddings class '{current_class}' does not match the stored class " + f"'{stored_class}' for this vector store. This would result in empty query results. " + f"Use compatible embeddings or create a new vector store." + ) + + # Check if critical parameters match + for param_name, stored_value in stored_params.items(): + if hasattr(self.embeddings, param_name): + current_value = getattr(self.embeddings, param_name) + if current_value != stored_value and param_name in ['model', 'embedding_dim', 'chunk_size']: + raise ValueError( + f"Provided embeddings parameter '{param_name}' value '{current_value}' " + f"does not match stored value '{stored_value}'. This would result in " + f"empty query results. Use compatible embeddings or create a new vector store." + ) + + + def _get_embeddings_config(self): + """ + Get the embeddings configuration stored in the vector store. + + Returns + ------- + dict or None + The embeddings configuration or None if not available. + """ + if not self._initialized: + return None + + # Check if metadata table exists + has_metadata = ( + self.connection.execute( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'vector_store_metadata';" + ).fetchone()[0] + > 0 + ) + + if not has_metadata: + return None + + try: + result = self.connection.execute( + "SELECT value FROM vector_store_metadata WHERE key = 'embeddings';" + ).fetchone() + + if result: + return json.loads(result[0]) + return None + except Exception as e: + log_debug(f"Error retrieving embeddings configuration: {e}") + return None + async def _add_items( self, texts: list[str], diff --git a/lumen/tests/ai/test_vector_store.py b/lumen/tests/ai/test_vector_store.py index b9a0356f2..94c5cacd7 100644 --- a/lumen/tests/ai/test_vector_store.py +++ b/lumen/tests/ai/test_vector_store.py @@ -5,7 +5,7 @@ except ModuleNotFoundError: pytest.skip("lumen.ai could not be imported, skipping tests.", allow_module_level=True) -from lumen.ai.embeddings import NumpyEmbeddings +from lumen.ai.embeddings import Embeddings, NumpyEmbeddings from lumen.ai.vector_store import DuckDBVectorStore, NumpyVectorStore @@ -545,3 +545,16 @@ async def test_not_initalized(self, tmp_path): results = await store.query("First doc") assert len(results) == 0 store.close() + + async def test_check_embeddings_consistency(self, tmp_path): + db_path = str(tmp_path / "test_duckdb.db") + store = DuckDBVectorStore(uri=db_path, embeddings=NumpyEmbeddings()) + await store.add([{"text": "First doc"}]) + store.close() + + store = DuckDBVectorStore(uri=db_path, embeddings=NumpyEmbeddings()) + assert len(await store.query("First doc")) == 1 + store.close() + + with pytest.raises(ValueError, match="Provided embeddings class"): + DuckDBVectorStore(uri=db_path, embeddings=Embeddings())