 import pandas as pd
 from collections import defaultdict
 from functools import partial
-from typing import Dict, Optional, Callable
+from typing import Dict, List, Optional, Callable, Any, Union

 from nv_ingest_client.util.milvus import nvingest_retrieval

 from nv_ingest_harness.utils.cases import get_repo_root


+def _compute_beir_metrics(
+    all_retrieved: List[List[Dict]],
+    query_df: pd.DataFrame,
+    k_values: List[int] = [1, 5, 10],
+) -> Optional[Dict[str, Dict[str, float]]]:
25+ """
26+ Compute BEIR metrics from retrieval results.
27+
28+ Args:
29+ all_retrieved: List of retrieval results per query. Each result is a list of
30+ dicts with 'entity' containing source info.
31+ query_df: DataFrame with 'query' and 'expected_pdf' columns, optionally 'query_id'.
32+ k_values: Cutoff values for evaluation (default [1, 5, 10]).
33+
34+ Returns:
35+ Dict with keys 'ndcg', 'map', 'recall', 'precision', each containing
36+ metric values like {'NDCG@1': 0.17, 'NDCG@5': 0.35, ...}, or None if BEIR unavailable.
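+
+    Example (illustrative sketch; file names and values below are hypothetical):
+        >>> retrieved = [
+        ...     [{"entity": {"source": {"source_id": "/data/report_a.pdf"}}}],
+        ...     [{"entity": {"source": {"source_id": "/data/report_b.pdf"}}}],
+        ... ]
+        >>> queries = pd.DataFrame({"query": ["q1", "q2"], "expected_pdf": ["report_a", "report_b"]})
+        >>> metrics = _compute_beir_metrics(retrieved, queries, k_values=[1])
+        >>> metrics is None or "NDCG@1" in metrics["ndcg"]
+        True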
37+ """
+    try:
+        from beir.retrieval.evaluation import EvaluateRetrieval
+    except ImportError:
+        return None
+
+    # Build results dict: {query_id: {doc_id: score}}
+    results = {}
+    for idx, answers in enumerate(all_retrieved):
+        if "query_id" in query_df.columns:
+            query_id = str(query_df.iloc[idx]["query_id"])
+        else:
+            query_id = str(idx)
+
+        results[query_id] = {}
+        num_results = len(answers)
+        for rank, r in enumerate(answers):
+            source_id = r.get("entity", {}).get("source", {}).get("source_id", "")
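+            # Reduce the source path to its bare file stem (text before the first '.')
+            # so it lines up with the values in query_df["expected_pdf"].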
+            doc_id = os.path.basename(source_id).split(".")[0]
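+            # Synthetic relevance score that decays linearly with rank; the BEIR
+            # evaluator only uses these scores to order the retrieved documents.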
+            score = (num_results - rank) / num_results if num_results > 0 else 0
+            results[query_id][doc_id] = score
+
+    # Build qrels dict: {query_id: {doc_id: relevance}}
+    qrels = {}
+    for idx, row in query_df.iterrows():
+        if "query_id" in query_df.columns:
+            query_id = str(row["query_id"])
+        else:
+            query_id = str(idx)
+        qrels[query_id] = {str(row["expected_pdf"]): 1}
+
+    # Evaluate the ranked results against the qrels at each cutoff in k_values
+    ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, results, k_values, ignore_identical_ids=True)
+
+    return {"ndcg": ndcg, "map": _map, "recall": recall, "precision": precision}
+
+
 def _get_retrieval_func(
     vdb_backend: str,
     table_path: Optional[str] = None,
@@ -176,7 +230,8 @@ def get_recall_scores_pdf_only(
     batch_size: int = 100,
     vdb_backend: str = "milvus",
     table_path: Optional[str] = None,
-) -> Dict[int, float]:
+    enable_beir: bool = False,
+) -> Union[Dict[int, float], Dict[str, Any]]:
     """
     Calculate recall@k scores for queries against a VDB collection using PDF-only matching.

@@ -199,11 +254,14 @@ def get_recall_scores_pdf_only(
         batch_size: Number of queries to process per batch (prevents gRPC size limit errors).
         vdb_backend: VDB backend to use ("milvus" or "lancedb"). Default is "milvus".
         table_path: Path to LanceDB database directory (required if vdb_backend="lancedb").
+        enable_beir: If True, also compute BEIR metrics (NDCG, MAP, Recall, Precision).

     Returns:
-        Dictionary mapping k values (1, 5, 10) to recall scores (float 0.0-1.0).
+        If enable_beir=False: Dictionary mapping k values (1, 5, 10) to recall scores.
+        If enable_beir=True: Dictionary with 'recall' and 'beir' keys containing metrics.
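+        Illustrative shape (hypothetical values): {"recall": {1: 0.62, 5: 0.81, 10: 0.88},
+        "beir": {"ndcg": {"NDCG@1": 0.55, ...}, ...}}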
205262 """
206263 hits = defaultdict (list )
264+ all_retrieved = [] # Collect for BEIR computation
207265
208266 reranker_kwargs = {}
209267 if nv_ranker :
@@ -262,6 +320,10 @@ def get_recall_scores_pdf_only(
             **reranker_kwargs,
         )

+        # Collect results for BEIR if enabled
+        if enable_beir:
+            all_retrieved.extend(batch_answers)
+
         for expected_pdf, retrieved_answers in zip(batch_expected_pdfs, batch_answers):
             # Extract PDF names only (no page numbers)
             retrieved_pdfs = [
@@ -276,6 +338,11 @@ def get_recall_scores_pdf_only(

     recall_scores = {k: np.mean(hits[k]) for k in hits if len(hits[k]) > 0}

+    # Compute BEIR metrics if enabled (beir_metrics is None when beir is not installed)
+    if enable_beir:
+        beir_metrics = _compute_beir_metrics(all_retrieved, query_df, k_values=[1, 5, 10])
+        return {"recall": recall_scores, "beir": beir_metrics}
+
     return recall_scores


@@ -781,7 +848,8 @@ def vidore_recall(
     nv_ranker_model_name: Optional[str] = None,
     vdb_backend: str = "milvus",
     table_path: Optional[str] = None,
-) -> Dict[int, float]:
+    enable_beir: bool = False,
+) -> Union[Dict[int, float], Dict[str, Any]]:
     """
     Evaluate recall@k for Vidore V3 dataset using PDF-only matching.

@@ -803,9 +871,11 @@ def vidore_recall(
         nv_ranker_model_name: Optional custom reranker model name.
         vdb_backend: VDB backend to use ("milvus" or "lancedb"). Default is "milvus".
         table_path: Path to LanceDB database directory (required if vdb_backend="lancedb").
+        enable_beir: If True, also compute BEIR metrics (NDCG, MAP, Recall, Precision).

     Returns:
-        Dictionary mapping k values (1, 5, 10) to recall scores (float 0.0-1.0).
+        If enable_beir=False: Dictionary mapping k values (1, 5, 10) to recall scores.
+        If enable_beir=True: Dictionary with 'recall' and 'beir' keys containing metrics.
     """
     loader = partial(
         vidore_load_ground_truth,
@@ -829,6 +899,7 @@ def vidore_recall(
         nv_ranker_model_name=nv_ranker_model_name,
         vdb_backend=vdb_backend,
         table_path=table_path,
+        enable_beir=enable_beir,
     )

