
Commit 885978a

jperez999 and jdye64 authored
Add rerank (#1565)
Co-authored-by: Jeremy Dyer <jdye64@gmail.com>
1 parent f3d44a9 commit 885978a

File tree

17 files changed: +1848, -196 lines


nemo_retriever/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ dependencies = [
     "nemotron-ocr>=0.dev0",
     "markitdown",
     "timm==1.0.22",
+    "tqdm",
     "accelerate==1.12.0",
     "albumentations==2.0.8",
     "scikit-learn>=1.6.0",

nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py

Lines changed: 9 additions & 0 deletions
@@ -425,6 +425,14 @@ def main(
         "--runtime-metrics-prefix",
         help="Optional filename prefix for per-run metrics artifacts.",
     ),
+    reranker: Optional[bool] = typer.Option(
+        False, "--reranker/--no-reranker", help="Enable a re-ranking stage with a cross-encoder model."
+    ),
+    reranker_model_name: str = typer.Option(
+        "nvidia/llama-nemotron-rerank-1b-v2",
+        "--reranker-model-name",
+        help="Cross-encoder model name for re-ranking stage (passed to .embed()).",
+    ),
     structured_elements_modality: Optional[str] = typer.Option(
         None,
         "--structured-elements-modality",

@@ -782,6 +790,7 @@ def _extract_params(batch_tuning: dict, **overrides: Any) -> ExtractParams:
         ks=(1, 5, 10),
         hybrid=hybrid,
         match_mode=recall_match_mode,
+        reranker=reranker_model_name if reranker else None,
     )

     # Capture recall only times.
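
For orientation, a minimal self-contained sketch (not part of the commit) of how typer's paired boolean flag and the gating expression above behave together; the app and command names here are invented for illustration.

import typer

app = typer.Typer()

@app.command()
def run(
    reranker: bool = typer.Option(
        False, "--reranker/--no-reranker", help="Enable the re-ranking stage."
    ),
    reranker_model_name: str = typer.Option(
        "nvidia/llama-nemotron-rerank-1b-v2", "--reranker-model-name"
    ),
) -> None:
    # Mirrors the gating in _extract_params: the model name is passed through
    # only when --reranker is set; --no-reranker (the default) yields None.
    effective = reranker_model_name if reranker else None
    typer.echo(f"reranker={effective}")

if __name__ == "__main__":
    app()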

nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py

Lines changed: 0 additions & 1 deletion
@@ -197,7 +197,6 @@ def lancedb_schema(vector_dim: int = 2048) -> Any:
         pa.field("pdf_basename", pa.string()),
         pa.field("page_number", pa.int32()),
         pa.field("source", pa.string()),
-        pa.field("source_id", pa.string()),
         pa.field("path", pa.string()),
         pa.field("text", pa.string()),
         pa.field("metadata", pa.string()),

nemo_retriever/src/nemo_retriever/model/local/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,7 @@
     "NemotronTableStructureV1",
     "NemotronGraphicElementsV1",
     "NemotronParseV12",
+    "NemotronRerankV2",
     "ParakeetCTC1B1ASR",
 ]

@@ -42,6 +43,10 @@ def __getattr__(name: str):
         from .nemotron_parse_v1_2 import NemotronParseV12

         return NemotronParseV12
+    if name == "NemotronRerankV2":
+        from .nemotron_rerank_v2 import NemotronRerankV2
+
+        return NemotronRerankV2
     if name == "ParakeetCTC1B1ASR":
         from .parakeet_ctc_1_1b_asr import ParakeetCTC1B1ASR
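The lazy export above uses the module-level __getattr__ hook (PEP 562), so the heavyweight torch/transformers import in nemotron_rerank_v2.py is deferred until NemotronRerankV2 is first accessed. A standalone illustration of the pattern, assuming nothing beyond what the hunk shows:

__all__ = ["NemotronRerankV2"]

def __getattr__(name: str):
    # Called only when normal attribute lookup on the package fails,
    # so the submodule import runs on first access, not at import time.
    if name == "NemotronRerankV2":
        from .nemotron_rerank_v2 import NemotronRerankV2

        return NemotronRerankV2
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")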

nemo_retriever/src/nemo_retriever/model/local/nemotron_rerank_v2.py (new file)

Lines changed: 210 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Local wrapper for nvidia/llama-nemotron-rerank-1b-v2 cross-encoder reranker."""

from __future__ import annotations

from typing import List, Optional

from nemo_retriever.utils.hf_cache import configure_global_hf_cache_base
from ..model import BaseModel, RunMode


_DEFAULT_MODEL = "nvidia/llama-nemotron-rerank-1b-v2"
_DEFAULT_MAX_LENGTH = 512
_DEFAULT_BATCH_SIZE = 32


def _prompt_template(query: str, passage: str) -> str:
    """Format a (query, passage) pair as the model expects."""
    return f"question:{query} \n \n passage:{passage}"


class NemotronRerankV2(BaseModel):
    """
    Local cross-encoder reranker wrapping nvidia/llama-nemotron-rerank-1b-v2.

    The model scores (query, document) pairs and returns raw logits; higher
    values indicate greater relevance. It is fine-tuned from
    meta-llama/Llama-3.2-1B with bi-directional attention and supports 26
    languages with sequences up to 8192 tokens.

    Example::

        reranker = NemotronRerankV2()
        scores = reranker.score("What is ML?", ["Machine learning is…", "Paris is…"])
        # scores -> [20.6, -23.1] (higher = more relevant)
    """

    def __init__(
        self,
        model_name: str = _DEFAULT_MODEL,
        device: Optional[str] = None,
        hf_cache_dir: Optional[str] = None,
    ) -> None:
        super().__init__()
        import torch
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        configure_global_hf_cache_base()

        self._model_name = model_name
        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        kwargs: dict = {"trust_remote_code": True}
        if hf_cache_dir:
            kwargs["cache_dir"] = hf_cache_dir

        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            padding_side="left",
            **kwargs,
        )
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._model = (
            AutoModelForSequenceClassification.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                **kwargs,
            )
            .eval()
            .to(self._device)
        )

        if self._model.config.pad_token_id is None:
            self._model.config.pad_token_id = self._tokenizer.eos_token_id

    # ------------------------------------------------------------------
    # BaseModel abstract properties
    # ------------------------------------------------------------------

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def model_type(self) -> str:
        return "reranker"

    @property
    def model_runmode(self) -> RunMode:
        return "local"

    @property
    def input(self):
        return "List[Tuple[str, str]]"

    @property
    def output(self):
        return "List[float]"

    @property
    def input_batch_size(self) -> int:
        return _DEFAULT_BATCH_SIZE

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def score(
        self,
        query: str,
        documents: List[str],
        *,
        max_length: int = _DEFAULT_MAX_LENGTH,
        batch_size: int = _DEFAULT_BATCH_SIZE,
    ) -> List[float]:
        """
        Score relevance of *documents* to *query*.

        Parameters
        ----------
        query:
            The search query.
        documents:
            Candidate passages/documents to score.
        max_length:
            Tokenizer truncation length (default 512; max supported 8192).
        batch_size:
            Number of (query, doc) pairs to process per GPU forward pass.

        Returns
        -------
        List[float]
            Raw logit scores aligned with *documents* (higher = more relevant).
        """
        import torch

        if not documents:
            return []

        texts = [_prompt_template(query, d) for d in documents]
        all_scores: List[float] = []

        with torch.inference_mode():
            for start in range(0, len(texts), batch_size):
                chunk = texts[start : start + batch_size]
                batch = self._tokenizer(
                    chunk,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=max_length,
                )
                batch = {k: v.to(self._device) for k, v in batch.items()}
                logits = self._model(**batch).logits
                all_scores.extend(logits.view(-1).cpu().tolist())

        return all_scores

    def score_pairs(
        self,
        pairs: List[tuple],
        *,
        max_length: int = _DEFAULT_MAX_LENGTH,
        batch_size: int = _DEFAULT_BATCH_SIZE,
    ) -> List[float]:
        """
        Score a list of (query, document) pairs.

        Parameters
        ----------
        pairs:
            Sequence of ``(query, document)`` tuples.
        max_length:
            Tokenizer truncation length.
        batch_size:
            GPU forward-pass batch size.

        Returns
        -------
        List[float]
            Raw logit scores (higher = more relevant).
        """
        import torch

        if not pairs:
            return []

        texts = [_prompt_template(q, d) for q, d in pairs]
        all_scores: List[float] = []

        with torch.inference_mode():
            for start in range(0, len(texts), batch_size):
                chunk = texts[start : start + batch_size]
                batch = self._tokenizer(
                    chunk,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=max_length,
                )
                batch = {k: v.to(self._device) for k, v in batch.items()}
                logits = self._model(**batch).logits
                all_scores.extend(logits.view(-1).cpu().tolist())

        return all_scores
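
A usage sketch for the new wrapper (assumed workflow, not part of the commit; the import path is inferred from the package layout in this PR):

from nemo_retriever.model.local import NemotronRerankV2

query = "What is machine learning?"
candidates = [
    "Machine learning is a field of AI that learns patterns from data.",
    "Paris is the capital of France.",
    "Gradient descent iteratively minimizes a loss function.",
]

reranker = NemotronRerankV2()  # downloads the HF checkpoint on first use
scores = reranker.score(query, candidates)

# Scores are raw logits; sort descending so the most relevant passage leads.
for score, text in sorted(zip(scores, candidates), reverse=True):
    print(f"{score:8.2f}  {text}")

# score_pairs() offers the same scoring for explicit (query, document) pairs.
pair_scores = reranker.score_pairs([(query, c) for c in candidates])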
