|
# Module-level logger; configured by the application's logging setup.
logger = logging.getLogger(__name__)


# Naive (model-free) baselines: simple truncation / lexical selection strategies.
ALL_NAIVE = ["random", "first_n", "last_n", "bm25"]
# Model-backed baselines; "verbatim_v2" is the Verbatim-RAG ModernBERT v2 extractor.
ALL_MODEL = ["swe_pruner", "zilliz", "gliner2", "verbatim_v2"]
# Every baseline name accepted by the CLI / main() dispatch.
ALL_BASELINES = ALL_NAIVE + ALL_MODEL
48 | 48 |
|
49 | 49 |
|
@@ -272,6 +272,75 @@ def baseline_zilliz(model, task: str, tool_output: str, threshold: float = 0.5) |
272 | 272 | return kept |
273 | 273 |
|
274 | 274 |
|
def _load_verbatim_v2(model_name: str = "KRLabsOrg/verbatim-rag-modern-bert-v2"):
    """Load the Verbatim-RAG ModernBERT v2 model in eval mode.

    Requires ``transformers`` and ``trust_remote_code=True`` (the model ships
    custom code on the Hub).

    Device selection:
      * ``SQUEEZ_VERBATIM_DEVICE`` env var, when set, wins unconditionally —
        e.g. ``SQUEEZ_VERBATIM_DEVICE=mps`` (pair it with
        ``PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0``) to force Metal.
      * otherwise CUDA when available (the intended path on the eval GPU node),
      * otherwise CPU. MPS is deliberately not auto-selected: the long-context
        tool-output forward pass routinely exceeds Metal's per-buffer size cap.

    Returns:
        The loaded model, moved to the chosen device and switched to eval mode.
    """
    import os

    import torch
    from transformers import AutoModel

    # Resolve the target device before the (slow) model download/load.
    override = os.environ.get("SQUEEZ_VERBATIM_DEVICE")
    if override:
        device = override
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    model = model.to(device)
    model.eval()
    return model
| 301 | + |
| 302 | + |
def baseline_verbatim_v2(
    model,
    task: str,
    tool_output: str,
    threshold: float = 0.1,
    min_span_chars: int = 10,
    merge_gap_chars: int = 20,
) -> list[str]:
    """Verbatim-RAG ModernBERT v2 — keep every line touched by an extracted span.

    Defaults to the recall-tuned config (threshold=0.1, min_span_chars=10) which
    handles short structured answers (file paths, line numbers, log lines)
    common in tool output. The model card documents this as the recommended
    config for technical / structured content.

    Args:
        model: Loaded Verbatim-RAG model exposing ``process(question=..., context=...,
            threshold=..., min_span_chars=..., merge_gap_chars=...)`` that returns a
            dict with a ``"spans"`` list of ``{"start", "end"}`` character offsets.
        task: The question/task used to condition span extraction.
        tool_output: Raw tool output to prune.
        threshold: Span-extraction confidence threshold.
        min_span_chars: Minimum span length, in characters.
        merge_gap_chars: Adjacent spans closer than this are merged.

    Returns:
        The non-blank lines of ``tool_output`` (in original order) that overlap
        at least one extracted span; ``[]`` when input is empty or nothing fires.
    """
    if not tool_output:
        return []

    result = model.process(
        question=task,
        context=tool_output,
        threshold=threshold,
        min_span_chars=min_span_chars,
        merge_gap_chars=merge_gap_chars,
    )
    spans = result.get("spans", [])
    if not spans:
        return []

    # Character bounds [start, end) of each line within tool_output.
    lines = tool_output.split("\n")
    bounds = []
    cursor = 0
    for text in lines:
        stop = cursor + len(text)
        bounds.append((cursor, stop))
        cursor = stop + 1  # +1 skips the "\n" separator

    # A line survives when any span overlaps its half-open interval.
    keep = [
        idx
        for idx, (lo, hi) in enumerate(bounds)
        if any(sp["start"] < hi and sp["end"] > lo for sp in spans)
    ]
    return [lines[idx] for idx in keep if lines[idx].strip()]
| 342 | + |
| 343 | + |
275 | 344 | def _load_gliner2(): |
276 | 345 | """Load GLiNER2 model (needs: pip install gliner2).""" |
277 | 346 | from gliner2 import GLiNER2 |
@@ -507,6 +576,22 @@ def main(): |
507 | 576 | except Exception as e: |
508 | 577 | logger.error(f"GLiNER2 failed: {e}") |
509 | 578 |
|
| 579 | + if "verbatim_v2" in baselines: |
| 580 | + logger.info("Loading Verbatim-RAG ModernBERT v2...") |
| 581 | + try: |
| 582 | + model = _load_verbatim_v2() |
| 583 | + logger.info("Running: verbatim_v2") |
| 584 | + results.append( |
| 585 | + evaluate_baseline( |
| 586 | + "Verbatim-RAG ModernBERT v2", |
| 587 | + baseline_verbatim_v2, |
| 588 | + samples, |
| 589 | + model=model, |
| 590 | + ) |
| 591 | + ) |
| 592 | + except Exception as e: |
| 593 | + logger.error(f"verbatim_v2 failed: {type(e).__name__}: {e}") |
| 594 | + |
510 | 595 | # Print and save |
511 | 596 | print_results(results) |
512 | 597 |
|
|
0 commit comments