diff --git a/rre-tools/embedding-model-evaluator/src/embedding_model_evaluator/main.py b/rre-tools/embedding-model-evaluator/src/embedding_model_evaluator/main.py index 88b46feb8..c1496ebef 100644 --- a/rre-tools/embedding-model-evaluator/src/embedding_model_evaluator/main.py +++ b/rre-tools/embedding-model-evaluator/src/embedding_model_evaluator/main.py @@ -119,7 +119,6 @@ def main() -> None: cached=model_with_cache, cache_path=CACHE_PATH, task_name=TASKS_NAME_MAPPING.get(config.task_to_evaluate, "CustomRetrievalTask"), - normalize_embeddings=True, batch_size=256, ) log.info(f"Writing embeddings to {config.embeddings_dest} ...") diff --git a/rre-tools/embedding-model-evaluator/src/embedding_model_evaluator/writers/embedding_writer.py b/rre-tools/embedding-model-evaluator/src/embedding_model_evaluator/writers/embedding_writer.py index 2c646674a..bddc6b20e 100644 --- a/rre-tools/embedding-model-evaluator/src/embedding_model_evaluator/writers/embedding_writer.py +++ b/rre-tools/embedding-model-evaluator/src/embedding_model_evaluator/writers/embedding_writer.py @@ -38,14 +38,12 @@ def __init__( cached: CachedEmbeddingWrapper, cache_path: str | Path, task_name: str, - normalize_embeddings: bool, batch_size: int, ): self.config = config self.cached = cached self.cache_path = Path(cache_path) self.task_name = task_name - self.normalize_embeddings = normalize_embeddings self.batch_size = batch_size def write(self, embedding_path: str | Path | None) -> None: @@ -69,10 +67,8 @@ def write(self, embedding_path: str | Path | None) -> None: ] doc_vectors = self.cached.encode( - doc_texts, + texts=doc_texts, task_name=self.task_name, - name=f"{self.task_name}-corpus", - normalize_embeddings=self.normalize_embeddings, batch_size=self.batch_size, ) _write_embeddings_jsonl(documents_path, zip(doc_ids, doc_vectors)) @@ -84,10 +80,8 @@ def write(self, embedding_path: str | Path | None) -> None: query_texts = [query_dict[qid] for qid in query_ids] query_vectors = self.cached.encode( - query_texts, + texts=query_texts, task_name=self.task_name, - name=f"{self.task_name}-queries", - normalize_embeddings=self.normalize_embeddings, batch_size=self.batch_size, ) _write_embeddings_jsonl(queries_path, zip(query_ids, query_vectors)) diff --git a/rre-tools/embedding-model-evaluator/tests/unit/test_embedding_writer.py b/rre-tools/embedding-model-evaluator/tests/unit/test_embedding_writer.py index 783de06d1..6093424c7 100644 --- a/rre-tools/embedding-model-evaluator/tests/unit/test_embedding_writer.py +++ b/rre-tools/embedding-model-evaluator/tests/unit/test_embedding_writer.py @@ -28,18 +28,14 @@ def _encode( texts: list[str], *, task_name: str, - name: str, - normalize_embeddings: bool, batch_size: int, ) -> np.ndarray: - assert task_name == "test_custom_task" - assert normalize_embeddings is True assert batch_size == 32 - if name.endswith("-corpus"): + if task_name.endswith("-corpus"): return np.vstack([np.asarray(vector) for vector in doc_vectors]) - if name.endswith("-queries"): + if task_name.endswith("-queries"): return np.vstack([np.asarray(vector) for vector in query_vectors]) - raise AssertionError(f"Unexpected encode name: {name}") + raise AssertionError(f"Unexpected encode name: {task_name}") cached.encode.side_effect = _encode return cached @@ -59,8 +55,7 @@ def test_embeddings_writer_with_valid_inputs__expects__creates_jsonl_files_with_ config=config, cached=cached, cache_path=tmp_path / "cache", - task_name="test_custom_task", - normalize_embeddings=True, + task_name="test_custom_task-corpus", batch_size=32, ) @@ -68,14 +63,26 @@ def test_embeddings_writer_with_valid_inputs__expects__creates_jsonl_files_with_ embedding_dir = config.embeddings_dest docs_file = embedding_dir / "documents_embeddings.jsonl" - queries_file = embedding_dir / "queries_embeddings.jsonl" - assert docs_file.exists() - assert queries_file.exists() + assert docs_file.exists() with jsonlines.open(docs_file) as r: docs = list(r) + assert docs == [{"id": "doc1", "vector": [0.1, 0.2, 0.3]}] + + # recreating again because of fake cached embedding wrapper for queries and corpus vectors + writer = EmbeddingWriter( + config=config, + cached=cached, + cache_path=tmp_path / "cache", + task_name="test_custom_task-queries", + batch_size=32, + ) + + writer.write(config.embeddings_dest) + queries_file = embedding_dir / "queries_embeddings.jsonl" + assert queries_file.exists() + with jsonlines.open(queries_file) as r: queries = list(r) - - assert docs == [{"id": "doc1", "vector": [0.1, 0.2, 0.3]}] assert queries == [{"id": "query1", "vector": [1.0, 1.1, 1.2]}] + diff --git a/rre-tools/pyproject.toml b/rre-tools/pyproject.toml index b0c551553..a32fc6f4a 100644 --- a/rre-tools/pyproject.toml +++ b/rre-tools/pyproject.toml @@ -35,7 +35,8 @@ dev = [ "ruff>=0.12.10", "typing-extensions>=4.14.1", "types-PyYAML>=6.0.2", - "types-requests>=2.32.4" + "types-requests>=2.32.4", + "setuptools>=80.9.0" ] [tool.pytest.ini_options] diff --git a/rre-tools/tests/test_cross_plataform.py b/rre-tools/tests/test_cross_plataform.py index c9efd4b3c..077d34427 100644 --- a/rre-tools/tests/test_cross_plataform.py +++ b/rre-tools/tests/test_cross_plataform.py @@ -44,12 +44,12 @@ def __init__(self, doc_vectors: Sequence[Sequence[float]], query_vectors: Sequen self._doc_vectors = doc_vectors self._query_vectors = query_vectors - def encode(self, texts, *, task_name: str, name: str, normalize_embeddings: bool, batch_size: int): - if name.endswith("-corpus"): + def encode(self, texts, *, task_name: str, batch_size: int): + if task_name.endswith("-corpus"): return self._doc_vectors - if name.endswith("-queries"): + if task_name.endswith("-queries"): return self._query_vectors - raise AssertionError(f"Unexpected encode name: {name}") + raise AssertionError(f"Unexpected encode name: {task_name}") def close(self) -> None: pass @@ -71,22 +71,30 @@ def test_embedding_writer_with_nested_dirs__expects__creates_files_in_nested_dir config=cfg, cached=cached, cache_path=tmp_path / "cache", - task_name="test_custom_task", - normalize_embeddings=True, + task_name="test_custom_task-corpus", batch_size=32, ) writer.write(dest) docs_file = dest / "documents_embeddings.jsonl" - queries_file = dest / "queries_embeddings.jsonl" - assert docs_file.exists() - assert queries_file.exists() - # sanity read with jsonlines.open(docs_file) as r: _ = list(r) + + writer = EmbeddingWriter( + config=cfg, + cached=cached, + cache_path=tmp_path / "cache", + task_name="test_custom_task-queries", + batch_size=32, + ) + + writer.write(dest) + + queries_file = dest / "queries_embeddings.jsonl" + assert queries_file.exists() with jsonlines.open(queries_file) as r: _ = list(r) diff --git a/rre-tools/uv.lock b/rre-tools/uv.lock index f7720c22d..4947169f0 100644 --- a/rre-tools/uv.lock +++ b/rre-tools/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", @@ -479,7 +479,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -2562,6 +2562,7 @@ dev = [ { name = "pytest" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "setuptools" }, { name = "types-pyyaml" }, { name = "types-requests" }, { name = "typing-extensions" }, @@ -2585,6 +2586,7 @@ dev = [ { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-cov", specifier = ">=6.2.1" }, { name = "ruff", specifier = ">=0.12.10" }, + { name = "setuptools", specifier = ">=80.9.0" }, { name = "types-pyyaml", specifier = ">=6.0.2" }, { name = "types-requests", specifier = ">=2.32.4" }, { name = "typing-extensions", specifier = ">=4.14.1" },