
Commit 8f234ae

Added extractor inference
1 parent 0751e6b

4 files changed: 308 additions & 39 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -33,3 +33,4 @@ wandb/
 runs/
 checkpoints/
 output/
+paper/

README.md

Lines changed: 14 additions & 26 deletions
@@ -12,7 +12,7 @@
 
 - Tool output pruner for LLM coding agents
 - Pipe any tool output (pytest, grep, git log, npm build, kubectl, ...) through squeez with a task description, get back only the relevant lines
-- Fine-tuned Qwen 3.5 2B, 0.80 F1, 92% compression
+- Two models, same CLI: a **generative** Qwen 3.5 2B (0.80 F1, 92% compression) or a smaller **extractive** ModernBERT alternative
 - CLI pipe, Python library, or vLLM server
 
 Existing context pruning tools ([SWE-Pruner](https://github.com/Ayanami1314/swe-pruner), [Zilliz Semantic Highlight](https://huggingface.co/zilliz/semantic-highlight-bilingual-v1), [Provence](https://arxiv.org/abs/2501.16214)) are built for source code or document paragraphs. They don't handle the mixed, unstructured format of tool output (stack traces interleaved with passing tests, grep matches with context lines, build logs with timestamps). Squeez is trained on 27 types of tool output from real SWE-bench workflows and synthetic multi-ecosystem observations.
@@ -236,41 +236,29 @@ Environment variables:
 | `SQUEEZ_LOCAL_MODEL` | Path to local model directory |
 | `SQUEEZ_SERVER_MODEL` | Model name on the server |
 | `SQUEEZ_API_KEY` | API key (if needed) |
-| `SQUEEZ_BACKEND` | Force backend: `transformers`, `vllm`, `encoder` |
+| `SQUEEZ_BACKEND` | Force backend (rarely needed; auto-detected from the model) |
 
 </details>
 
 <details>
-<summary><b>Encoder models</b></summary>
+<summary><b>Use the extractive model instead</b></summary>
 
-Squeez also supports encoder-based extraction (ModernBERT, etc.) as an alternative to the generative model. These are faster but less accurate.
+If you don't need the 2B generative model, point squeez at a smaller
+extractive one — same CLI, same Python API. Configure once, then use
+`squeez` normally:
 
-Two encoder approaches:
-- **Token encoder**: per-token binary classification, aggregated per line via max-pool
-- **Pooled encoder**: single-pass encoder with line-level mean-pool classification
-
-```python
-from squeez.inference.extractor import ToolOutputExtractor
+```bash
+export SQUEEZ_LOCAL_MODEL=KRLabsOrg/verbatim-rag-modern-bert-v2
 
-extractor = ToolOutputExtractor(model_path="./output/squeez_encoder")
-filtered = extractor.extract(task="Find the bug", tool_output=raw_output)
+pytest -q 2>&1 | squeez "find the failing test"
+git log --oneline -50 | squeez "find the auth commit"
 ```
 
-Standalone loading without squeez installed:
+`KRLabsOrg/verbatim-rag-modern-bert-v2` is a 150M ModernBERT span model
+trained on a multi-domain mix that includes Squeez tool-output. See
+[RESULTS.md](RESULTS.md) for the head-to-head with Squeez-2B.
 
-```python
-from transformers import AutoModel, AutoTokenizer
-
-model = AutoModel.from_pretrained("output/squeez_pooled", trust_remote_code=True)
-tokenizer = AutoTokenizer.from_pretrained("output/squeez_pooled")
-
-result = model.process(
-    task="Find the traceback",
-    tool_output=open("output.log").read(),
-    tokenizer=tokenizer,
-)
-print(result["highlighted_lines"])
-```
+To train your own extractive model, see [TRAINING.md](TRAINING.md).
 
 </details>

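The rewritten section drops the old standalone-loading snippet. If you still want to drive the extractive model without squeez installed, here is a minimal sketch, assuming only `transformers` is available; it mirrors the `AutoModel.from_pretrained(..., trust_remote_code=True)` and `model.process(...)` calls that `scripts/evaluate_baselines.py` makes below, and `output.log` stands in for any captured tool output:

```python
from transformers import AutoModel

# The checkpoint ships its own model class, hence trust_remote_code.
model = AutoModel.from_pretrained(
    "KRLabsOrg/verbatim-rag-modern-bert-v2", trust_remote_code=True
)

# process() returns character spans of the context relevant to the question;
# the values below are the recall-tuned defaults used by the eval script.
result = model.process(
    question="Find the traceback",
    context=open("output.log").read(),
    threshold=0.1,
    min_span_chars=10,
    merge_gap_chars=20,
)
for span in result.get("spans", []):
    print(span["start"], span["end"])
```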
scripts/evaluate_baselines.py

Lines changed: 86 additions & 1 deletion
@@ -43,7 +43,7 @@
 logger = logging.getLogger(__name__)
 
 ALL_NAIVE = ["random", "first_n", "last_n", "bm25"]
-ALL_MODEL = ["swe_pruner", "zilliz", "gliner2"]
+ALL_MODEL = ["swe_pruner", "zilliz", "gliner2", "verbatim_v2"]
 ALL_BASELINES = ALL_NAIVE + ALL_MODEL
 
 
@@ -272,6 +272,75 @@ def baseline_zilliz(model, task: str, tool_output: str, threshold: float = 0.5)
     return kept
 
 
+def _load_verbatim_v2(model_name: str = "KRLabsOrg/verbatim-rag-modern-bert-v2"):
+    """Load Verbatim-RAG ModernBERT v2 (needs: transformers + trust_remote_code).
+
+    Device selection:
+    - CUDA when available (intended path on the eval GPU node)
+    - CPU otherwise. We skip MPS by default because the long-context tool-output
+      forward pass routinely bumps into Metal's per-buffer size cap. Set
+      ``SQUEEZ_VERBATIM_DEVICE=mps`` (with ``PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0``)
+      to force MPS anyway.
+    """
+    import os
+
+    import torch
+    from transformers import AutoModel
+
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+    forced = os.environ.get("SQUEEZ_VERBATIM_DEVICE")
+    if forced:
+        device = forced
+    elif torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    model = model.to(device)
+    model.eval()
+    return model
+
+
+def baseline_verbatim_v2(
+    model,
+    task: str,
+    tool_output: str,
+    threshold: float = 0.1,
+    min_span_chars: int = 10,
+    merge_gap_chars: int = 20,
+) -> list[str]:
+    """Verbatim-RAG ModernBERT v2 — keep any line touched by an extracted span.
+
+    Defaults to the recall-tuned config (threshold=0.1, min_span_chars=10) which
+    handles short structured answers (file paths, line numbers, log lines)
+    common in tool output. The model card documents this as the recommended
+    config for technical / structured content.
+    """
+    if not tool_output:
+        return []
+    result = model.process(
+        question=task,
+        context=tool_output,
+        threshold=threshold,
+        min_span_chars=min_span_chars,
+        merge_gap_chars=merge_gap_chars,
+    )
+    spans = result.get("spans", [])
+    if not spans:
+        return []
+    lines = tool_output.split("\n")
+    line_offsets, pos = [], 0
+    for line in lines:
+        line_offsets.append((pos, pos + len(line)))
+        pos += len(line) + 1
+    kept_indices: set[int] = set()
+    for sp in spans:
+        a, b = sp["start"], sp["end"]
+        for i, (lo, hi) in enumerate(line_offsets):
+            if not (b <= lo or a >= hi):
+                kept_indices.add(i)
+    return [lines[i] for i in sorted(kept_indices) if lines[i].strip()]
+
+
 def _load_gliner2():
     """Load GLiNER2 model (needs: pip install gliner2)."""
     from gliner2 import GLiNER2
@@ -507,6 +576,22 @@ def main():
         except Exception as e:
             logger.error(f"GLiNER2 failed: {e}")
 
+    if "verbatim_v2" in baselines:
+        logger.info("Loading Verbatim-RAG ModernBERT v2...")
+        try:
+            model = _load_verbatim_v2()
+            logger.info("Running: verbatim_v2")
+            results.append(
+                evaluate_baseline(
+                    "Verbatim-RAG ModernBERT v2",
+                    baseline_verbatim_v2,
+                    samples,
+                    model=model,
+                )
+            )
+        except Exception as e:
+            logger.error(f"verbatim_v2 failed: {type(e).__name__}: {e}")
+
     # Print and save
     print_results(results)
