Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c3ddd40
[DAGE-73] bug fixed by adding all docs to the datastore after saving it
nicolo-rinaldi Sep 12, 2025
d5bc1a9
[DAGE-73] added tests
nicolo-rinaldi Sep 12, 2025
6c11b78
[DAGE-73] Updated version with new DataStore autosave feature
nicolo-rinaldi Sep 29, 2025
0771718
[DAGE-73] added `uv sync --group dev` to README.md
nicolo-rinaldi Sep 30, 2025
3f0e7ed
[DAGE-73] typo in main.py datastore variable
nicolo-rinaldi Sep 30, 2025
f485770
[DAGE-73] added robustness to fetch_all
nicolo-rinaldi Sep 30, 2025
c71af17
[DAGE-73] rolled back to `_write_embeddings_jsonl` old version in Emb…
nicolo-rinaldi Sep 30, 2025
ed6b963
[DAGE-73] 2 big fixes
nicolo-rinaldi Sep 30, 2025
5693ef4
[DAGE-73] reverted back max_tokens to be ignored by ruff
nicolo-rinaldi Oct 1, 2025
d427743
[DAGE-73] workaround modified to avoid changing the datastore persist…
nicolo-rinaldi Oct 1, 2025
52310da
[DAGE-73] fixed description bug
nicolo-rinaldi Oct 6, 2025
7f67595
[DAGE-73] extended fetch_all() method to be used with vespa_search_en…
nicolo-rinaldi Oct 6, 2025
b98a6d2
[DAGE-73] uniformed config param in M1 and M2
nicolo-rinaldi Oct 6, 2025
0041e5f
[DAGE-73] centralized function setup_logging
nicolo-rinaldi Oct 6, 2025
04a3949
[DAGE-73] fixed hardcoded description and coded the streamed version …
nicolo-rinaldi Oct 7, 2025
e63d58e
[DAGE-73] added _get_total_hits to do the streamed version of fetch_a…
nicolo-rinaldi Oct 7, 2025
e4fb400
[DAGE-73] vespa should be fixed, I need to update the tests since fet…
nicolo-rinaldi Oct 7, 2025
41d5d1f
[DAGE-73] fixed vespa tests
nicolo-rinaldi Oct 8, 2025
3c7ca66
[DAGE-73] addressing Naz comments
nicolo-rinaldi Oct 8, 2025
50e9e2c
[DAGE-73] added break in while loop
nicolo-rinaldi Oct 8, 2025
715f2f2
[DAGE-73] addressing Daniele's comments
nicolo-rinaldi Oct 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions rre-tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ cd rre-tools
# install dependencies (for users)
uv sync

# install optional dev dependencies such as mypy/ruff
uv sync --extra dev
# install development dependencies as well (e.g., mypy and ruff)
uv sync --group dev
```

## Running Dataset Generator (DAGE)
Expand All @@ -49,7 +49,7 @@ at the Dataset Generator [README](dataset-generator/README.md).

Execute the main script via CLI, pointing to your DAGE configuration file:
```bash
uv run dataset-generator --config_file <path-to-DAGE-config-yaml>
uv run dataset-generator --config <path-to-DAGE-config-yaml>
```
To know more about all the possible CLI parameters, execute:
```bash
Expand Down
7 changes: 7 additions & 0 deletions rre-tools/commons/src/commons/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def configure_logging(level: Union[str, int] = logging.INFO) -> None:
datefmt='%H:%M:%S'
)

def setup_logging(verbose: bool = False) -> None:
    """Configure application logging from a verbosity flag.

    Args:
        verbose: When True, log at DEBUG level; otherwise at INFO.
    """
    # Single call replaces the duplicated if/else branches; the bare
    # trailing `return` was redundant for a None-returning function.
    configure_logging(logging.DEBUG if verbose else logging.INFO)

# EXAMPLE:
if __name__ == "__main__":
# 1. Configure logging
Expand Down
6 changes: 5 additions & 1 deletion rre-tools/commons/src/commons/writers/mteb_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def _write_corpus(self, corpus_path: Path, datastore: DataStore) -> None:
doc_id = str(doc.id)
fields = doc.fields
title = _to_string(fields.get("title"))
text = _to_string(fields.get("description"))
text = " ".join(
_to_string(value)
for key, value in fields.items()
if key != "title"
)

row = {"id": doc_id, "title": title, "text": text}
file.write(json.dumps(row, ensure_ascii=False) + "\n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def build_openai(config: LLMConfig) -> BaseChatModel:
log.debug("Building OpenAI ChatModel using model=%s", config.model)
return ChatOpenAI(
model=config.model,
# max_tokens=config.max_tokens, # commented due to the fact that ruff is saying there is no max_tokens param
max_tokens=config.max_tokens, # type: ignore[arg-type]
api_key=SecretStr(key),
)

Expand Down
54 changes: 37 additions & 17 deletions rre-tools/dataset-generator/src/dataset_generator/main.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
from __future__ import annotations

# ------ temporary import for corpus.json bug workaround ------
import json
from pathlib import Path
from commons.utils import _to_string
# -------------------------------------------------------------

from typing import List
from langchain_core.language_models import BaseChatModel
from logging import Logger, getLogger, DEBUG, INFO
from logging import Logger, getLogger

# project imports
from dataset_generator.config import Config
from dataset_generator.utils import parse_args
from commons.logger import configure_logging
from commons.logger import setup_logging
from dataset_generator.llm import LLMConfig, LLMService, LLMServiceFactory
from commons.model import Document, Query, LLMQueryResponse, LLMScoreResponse, WriterConfig
from commons.model import Document, Query, LLMQueryResponse, LLMScoreResponse, WriterConfig
from commons.writers import WriterFactory, AbstractWriter
from dataset_generator.search_engine import SearchEngineFactory, BaseSearchEngine
from commons.data_store import DataStore


log: Logger = getLogger(__name__)


def setup_logging(verbose: bool = False) -> None:
    """Configure logging level from the CLI verbosity flag.

    Args:
        verbose: When True, log at DEBUG level; otherwise at INFO.
    """
    # NOTE(review): duplicates commons.logger.setup_logging — this local
    # copy looks slated for removal in favour of the centralized helper.
    if verbose:
        configure_logging(DEBUG)
    else:
        configure_logging(INFO)
    return


def add_user_queries(config: Config, data_store: DataStore) -> None:
"""Loads queries from file (if exists) and adds them as Query objects."""
Expand All @@ -35,7 +34,8 @@ def add_user_queries(config: Config, data_store: DataStore) -> None:
data_store.add_query(clean_line)


def generate_and_add_queries(config: Config, data_store: DataStore, llm_service: LLMService, search_engine: BaseSearchEngine) -> None:
def generate_and_add_queries(config: Config, data_store: DataStore, llm_service: LLMService,
search_engine: BaseSearchEngine) -> None:
"""Retrieve docs and generate queries with LLM Service. Adds docs, queries and ratings to the datastore."""
docs_to_generate_queries: List[Document] = search_engine.fetch_for_query_generation(
documents_filter=config.documents_filter,
Expand Down Expand Up @@ -83,7 +83,7 @@ def add_cartesian_product_scores(config: Config, data_store: DataStore, llm_serv


def expand_docset_with_search_engine_top_k(config: Config, data_store: DataStore,
llm_service: LLMService, search_engine: BaseSearchEngine) -> None:
llm_service: LLMService, search_engine: BaseSearchEngine) -> None:
"""Retrieve docs for each query and score the (q, doc) pairs."""
if config.query_template is not None:
log.debug(f"Searching for documents with query template in {config.query_template}")
Expand All @@ -105,11 +105,10 @@ def expand_docset_with_search_engine_top_k(config: Config, data_store: DataStore
log.warning("Query template not found. Skipping retrieval.")



def main() -> None:
# configuration and logger definition
args = parse_args()
config: Config = Config.load(args.config_file)
config: Config = Config.load(args.config)
writer_config: WriterConfig = config.build_writer_config()
setup_logging(args.verbose)

Expand All @@ -124,7 +123,7 @@ def main() -> None:
llm: BaseChatModel = LLMServiceFactory.build(LLMConfig.load(config.llm_configuration_file))
service: LLMService = LLMService(chat_model=llm)
writer: AbstractWriter = WriterFactory.build(writer_config)

# load user queries
add_user_queries(config, data_store)

Expand All @@ -139,15 +138,36 @@ def main() -> None:

# write results
output_destination = config.output_destination
writer.write(output_destination, data_store)
log.info(f"Synthetic Dataset has been generated in: {output_destination}")
data_store.save()
writer.write(output_destination, data_store)

# save explanation - forced to extract value before invoking export_all_records_with_explanation (mypy)
if config.save_llm_explanation:
if llm_explanation_path := config.llm_explanation_destination:
data_store.export_all_records_with_explanation(llm_explanation_path)
log.info(f"Dataset with LLM explanation is saved into: {llm_explanation_path}")

# TODO:
# work on a better solution, instead of overwriting the corpus.json file, and maybe modify the MtebWriter with the
# fetch from the search engine
if config.output_format == "mteb":
# copy pasted from MtebWriter
corpus_path = Path(output_destination) / "corpus.jsonl"
corpus_path.unlink(missing_ok=True)
with corpus_path.open("a", encoding="utf-8") as file:
for doc in search_engine.fetch_all(doc_fields=config.doc_fields):
doc_id = str(doc.id)
fields = doc.fields
title = _to_string(fields.get("title"))
text = " ".join(
_to_string(value)
for key, value in fields.items()
if key != "title"
)

row = {"id": doc_id, "title": title, "text": text}
file.write(json.dumps(row, ensure_ascii=False) + "\n")

if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,30 @@ def __init__(self, endpoint: HttpUrl):
log.debug(f"Working on endpoint: {self.endpoint}")
self.UNIQUE_KEY = "_id"

    def _get_total_hits(self, payload: Dict[str, Any]) -> int:
        """Return the total number of hits matched by a search request.

        Args:
            payload: Request body posted to the Elasticsearch ``_search``
                endpoint.

        Returns:
            ``hits.total.value`` from the response, or 0 when the key is
            missing.

        Raises:
            The originating requests exception, after logging it.
        """
        search_url = urljoin(self.endpoint.encoded_string(), '_search')

        log.debug(f"Search url: {search_url}")
        log.debug(f"Payload: {payload}")

        try:
            response = requests.post(search_url, headers=self.HEADERS, json=payload)
            response.raise_for_status()
        except (ConnectionError, Timeout, RequestException, HTTPError) as e:
            # NOTE(review): if all four come from requests.exceptions,
            # RequestException alone already covers the other three — confirm
            # whether ConnectionError/Timeout here are the builtins instead.
            log.error(f"ElasticSearch query failed: {e}")
            raise
        # NOTE(review): Elasticsearch caps hits.total.value at 10_000 unless
        # the request sets "track_total_hits" — confirm indices larger than
        # that are handled by callers (e.g. fetch_all pagination).
        return int(response.json().get('hits', {}).get('total', {}).get('value', 0))

    @property
    def _fetch_all_payload(self) -> Dict[str, Any]:
        """Query fragment matching every document (Elasticsearch match_all)."""
        return {"match_all": {}}

def fetch_for_query_generation(self,
documents_filter: Union[None, List[Dict[str, List[str]]]],
doc_number: int,
doc_fields: List[str]) -> List[Document]:
doc_fields: List[str],
start: int = 0) -> List[Document]:
"""
Fetches a set of documents from Elasticsearch for query generation purposes.

Expand All @@ -37,12 +57,13 @@ def fetch_for_query_generation(self,
Each filter is a dictionary mapping field names to allowed values.
doc_number (int): Number of documents to retrieve.
doc_fields (List[str]): List of field names to include in the output.
start (int, optional): Starting index. Defaults to 0.

Returns:
List[Document]: A list of documents formatted as `Document` instances.
"""
# Build base query
query: Dict[str, Any] = {"match_all": {}}
query: Dict[str, Any] = self._fetch_all_payload

# Add filters, if provided
filter_clauses = []
Expand All @@ -66,6 +87,7 @@ def fetch_for_query_generation(self,
payload = {
"size": doc_number,
"query": query,
"from": start,
"_source": doc_fields
}

Expand All @@ -88,10 +110,6 @@ def fetch_for_evaluation(self, query_template: Path | str, doc_fields: List[str]
payload: Dict[str, Any] = self._parse_query_template(query_template)
payload = self._replace_placeholder(payload, self.QUERY_PLACEHOLDER, keyword)

# query_string_obj = payload.get("query", {}).get("query_string", {})
# if "query" in query_string_obj:
# query_string_obj["query"] = query_string_obj["query"].replace(self.QUERY_PLACEHOLDER, keyword)

fields = doc_fields if self.UNIQUE_KEY in doc_fields else doc_fields + [self.UNIQUE_KEY]
payload["_source"] = fields
return self._search(payload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,29 @@ def __init__(self, endpoint: HttpUrl):
self.HEADERS = {'Content-Type': 'application/json'}
self.UNIQUE_KEY = "id"

    def _get_total_hits(self, payload: Dict[str, Any]) -> int:
        """Return the total number of hits matched by a search request.

        Args:
            payload: Request body posted to the OpenSearch ``_search``
                endpoint.

        Returns:
            ``hits.total.value`` from the response, or 0 when the key is
            missing.

        Raises:
            The originating requests exception, after logging it.
        """
        search_url = f"{self.endpoint}/_search"
        log.debug(f"User-specified fields: {payload.get('_source')}")
        log.debug(f"Search url: {search_url}")
        log.debug(f"Payload: {payload}")
        try:
            response = requests.post(search_url, headers=self.HEADERS, json=payload)
            response.raise_for_status()
        except (ConnectionError, Timeout, RequestException, HTTPError) as e:
            # NOTE(review): if all four come from requests.exceptions,
            # RequestException alone already covers the other three — confirm
            # whether ConnectionError/Timeout here are the builtins instead.
            log.error(f"OpenSearch query failed: {e}")
            raise
        # NOTE(review): hits.total.value is capped at 10_000 unless the request
        # sets "track_total_hits" — confirm larger indices are handled.
        return int(response.json().get('hits', {}).get('total', {}).get('value', 0))

    @property
    def _fetch_all_payload(self) -> Dict[str, Any]:
        """Query fragment matching every document (OpenSearch match_all)."""
        return {"match_all": {}}

def fetch_for_query_generation(self,
documents_filter: Union[None, List[Dict[str, List[str]]]],
doc_number: int,
doc_fields: List[str]) -> List[Document]:
doc_fields: List[str],
start: int = 0) -> List[Document]:
"""Fetches a list of documents for query generation based on optional filters."""
filters: List[Dict[str, Any]] = []
if documents_filter:
Expand All @@ -42,20 +61,20 @@ def fetch_for_query_generation(self,

fields = doc_fields if self.UNIQUE_KEY in doc_fields else doc_fields + [self.UNIQUE_KEY]

query: Dict[str, Any] = {}
if filters:
query = {
"bool": {
"filter": filters
}
}
else:
query = {
"match_all": {}
}
query = self._fetch_all_payload

payload = {
"query": query,
"_source": fields,
"from": start,
"size": doc_number
}

Expand All @@ -67,10 +86,6 @@ def fetch_for_evaluation(self, query_template: Path | str, doc_fields: List[str]
payload: Dict[str, Any] = self._parse_query_template(query_template)
payload = self._replace_placeholder(payload, self.QUERY_PLACEHOLDER, keyword)

# query_string_obj = payload.get("query", {}).get("query_string", {})
# if "query" in query_string_obj:
# query_string_obj["query"] = query_string_obj["query"].replace(self.QUERY_PLACEHOLDER, keyword)

fields = doc_fields if self.UNIQUE_KEY in doc_fields else doc_fields + [self.UNIQUE_KEY]
payload["_source"] = fields

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,48 @@
from abc import ABC, abstractmethod
from json import JSONDecodeError
from pathlib import Path
from typing import List, Dict, Any, Union
from typing import List, Dict, Any, Union, Iterator
from pydantic import HttpUrl
from commons.model.document import Document

DOC_NUMBER_EACH_FETCH = 100

class BaseSearchEngine(ABC):
    """Common interface for search-engine backends used by the dataset generator."""

    def __init__(self, endpoint: HttpUrl):
        # Re-wrap to validate/normalise even when a plain string is passed.
        self.endpoint = HttpUrl(endpoint)
        # Token substituted with the actual keyword inside query templates.
        self.QUERY_PLACEHOLDER = "$query"
        # Name of the unique-id field; subclasses override (e.g. "_id" for
        # Elasticsearch, "id" for OpenSearch).
        self.UNIQUE_KEY = 'id'

def fetch_all(self, doc_fields: List[str]) -> Iterator[Document]:
"""Extract all documents from search engine in batches.

Yields batches of documents instead of loading everything in memory.

Args:
doc_fields: Fields to extract from documents

Yields:
List[Document]: Batch of documents
"""
# Now this is relying on fetch_for_query_generation to avoid duplicate code. Might be changed in the future
start: int = 0
total_hits: int = self._get_total_hits(self._fetch_all_payload)
while start < total_hits:
batch = self.fetch_for_query_generation(
documents_filter=None,
doc_number=DOC_NUMBER_EACH_FETCH,
doc_fields=doc_fields,
start=start
)
if not batch:
break
for doc in batch:
yield doc
# if we didn't reach the end of the docs, then len(batch) == DOC_NUMBER_EACH_FETCH
# if we reached the end of the docs. then len(batch) <= DOC_NUMBER_EACH_FETCH -> next iteration we exit the
# loop since we are adding DOC_NUMBER_EACH_FETCH (not len(batch)) and start becomes greater than total_hits
start += DOC_NUMBER_EACH_FETCH


def _parse_query_template(self, path: Path | str) -> Dict[str, Any]:
"""Return the payload"""
Expand Down Expand Up @@ -41,10 +73,11 @@ def _replace_placeholder(self, obj: Any, placeholder: str, keyword: str | None)
def fetch_for_query_generation(self,
documents_filter: Union[None, List[Dict[str, List[str]]]],
doc_number: int,
doc_fields: List[str]) \
doc_fields: List[str],
start: int = 0) \
-> List[Document]:
"""Extract documents for generating queries."""
raise NotImplementedError
pass

@abstractmethod
def fetch_for_evaluation(self,
Expand All @@ -53,9 +86,21 @@ def fetch_for_evaluation(self,
keyword: str="*:*") \
-> List[Document]:
"""Search for documents based on a keyword and a query template to evaluate the system."""
raise NotImplementedError
pass

@abstractmethod
def _search(self, payload: Dict[str, Any]) -> List[Document]:
"""Search for documents using a query."""
raise NotImplementedError
pass

@abstractmethod
def _get_total_hits(self, payload: Dict[str, Any]) -> int:
"""Get the total number of documents returned by a query."""
pass

@property
@abstractmethod
def _fetch_all_payload(self) -> Dict[str, Any]:
"""Payload to fetch all documents from the search engine."""
pass

Loading