Skip to content

Commit ce6744b

Browse files
committed
perf(reranker): cap passages before MLX scoring
1 parent 6795010 commit ce6744b

4 files changed

Lines changed: 159 additions & 18 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919

2020
### Changed
2121
- Full-vault indexing now embeds chunks in length-sorted batches instead of one file at a time, improving first-run indexing throughput on real Markdown vaults while preserving single-file indexing behavior and the existing SQLite schema.
22+
- The MLX reranker now caps each passage to the first 200 tokens before scoring, reducing warm-query latency on long chunks while preserving the full result preview and `seeklink get` output.
2223

2324
### Fixed
2425
- `seeklink search --rerank-k N` now limits the number of candidates passed to the cross-encoder even when `N` is lower than `--top-k`; the remaining results keep first-stage RRF order.

TODOS.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ ship if and when they become worth the cost.
66
## Search quality and features
77

88
### Cross-encoder performance optimization
9-
The MLX reranker (`Qwen3-Reranker-0.6B`) runs at ~60 ms per pair on M3-class
10-
Apple Silicon, totaling ~1.2–2.7 s for 20 candidates on realistic vault
11-
chunks. Possible reductions:
12-
13-
- Batch inference (process all pairs in one forward pass instead of
14-
sequentially).
15-
- Passage truncation (cap at ~200 tokens for reranking, use full text only
16-
for final display).
9+
The MLX reranker (`Qwen3-Reranker-0.6B`) is still the main warm-query
10+
latency cost on realistic vault chunks. Passage text is now capped before
11+
reranking; remaining possible reductions:
12+
13+
- Hardware-specific batching or sequence-classification reranker probes.
14+
Gate on real blind-test latency because MLX batch throughput depends on
15+
prompt length and padding.
16+
- Better query routing so only ambiguous queries pay the full rerank budget.
1717

1818
### Additional CLI subcommands
1919
Helpers exist inside `seeklink/app.py` but are not exposed on the CLI:

seeklink/reranker.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
passage independently), at the cost of ~60ms per pair.
99
1010
Implementation uses MLX (Apple's native ML framework) which runs on
11-
Metal GPU, achieving ~1.2s for 20 pairs on M3 Air. The model uses a
12-
yes/no prompt format per Qwen3-Reranker's official usage guide:
11+
Metal GPU. Candidate passages are capped before scoring so the reranker
12+
does not spend most of its time on long chunks whose opening text is
13+
already enough for relevance judgments. The model uses a yes/no prompt
14+
format per Qwen3-Reranker's official usage guide:
1315
the model outputs logits for 'yes' and 'no' tokens, and we convert
1416
the yes-probability to a relevance score.
1517
@@ -31,6 +33,7 @@
3133
_DEFAULT_INSTRUCTION = (
3234
"Given a web search query, retrieve relevant passages that answer the query."
3335
)
36+
_MAX_PASSAGE_TOKENS = 200
3437

3538

3639
class Reranker:
@@ -75,10 +78,39 @@ def _ensure_model(self) -> None:
7578
)
7679
self._disabled = True
7780

81+
def _token_list(self, text: str) -> list[int]:
82+
"""Tokenize text into a flat Python list."""
83+
tokens = self._tokenizer.encode(text, return_tensors=None)
84+
if not isinstance(tokens, list):
85+
tokens = tokens.tolist()
86+
if tokens and isinstance(tokens[0], list):
87+
tokens = tokens[0]
88+
return list(tokens)
89+
90+
def _truncate_passage(self, passage: str) -> str:
91+
"""Cap passage text used by the reranker; display text remains full."""
92+
tokens = self._token_list(passage)
93+
if len(tokens) <= _MAX_PASSAGE_TOKENS:
94+
return passage
95+
96+
head = tokens[:_MAX_PASSAGE_TOKENS]
97+
decode = getattr(self._tokenizer, "decode", None)
98+
if decode is not None:
99+
try:
100+
return decode(head, skip_special_tokens=True)
101+
except TypeError:
102+
return decode(head)
103+
except Exception:
104+
logger.debug("Reranker passage decode failed; using char fallback")
105+
106+
# Conservative fallback for unusual tokenizers without decode().
107+
return passage[:1200]
108+
78109
def _score_one(self, query: str, passage: str) -> float:
79110
"""Score a single (query, passage) pair. Returns 0-1 probability."""
80111
import mlx.core as mx
81112

113+
passage = self._truncate_passage(passage)
82114
prompt = (
83115
f"Instruct: {_DEFAULT_INSTRUCTION}\n"
84116
f"Query: {query}\n"
@@ -90,14 +122,8 @@ def _score_one(self, query: str, passage: str) -> float:
90122
)
91123
text += "<think>\n"
92124

93-
tokens = self._tokenizer.encode(text, return_tensors=None)
94-
if isinstance(tokens, list):
95-
input_ids = mx.array([tokens])
96-
else:
97-
input_ids = mx.array(tokens)
98-
if input_ids.ndim == 1:
99-
input_ids = input_ids[None]
100-
125+
tokens = self._token_list(text)
126+
input_ids = mx.array([tokens])
101127
logits = self._model(input_ids)
102128
last_logits = logits[0, -1, :]
103129
mx.eval(last_logits)

tests/test_reranker.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Tests for the MLX reranker wrapper without loading real MLX models."""
2+
3+
from __future__ import annotations
4+
5+
import math
6+
import sys
7+
import types
8+
9+
import numpy as np
10+
import pytest
11+
12+
import seeklink.reranker as reranker_mod
13+
from seeklink.reranker import Reranker
14+
15+
16+
@pytest.fixture
17+
def fake_mlx(monkeypatch):
18+
mlx_module = types.ModuleType("mlx")
19+
core_module = types.ModuleType("mlx.core")
20+
core_module.array = lambda value: np.array(value, dtype=np.int64)
21+
core_module.eval = lambda *args, **kwargs: None
22+
mlx_module.core = core_module
23+
monkeypatch.setitem(sys.modules, "mlx", mlx_module)
24+
monkeypatch.setitem(sys.modules, "mlx.core", core_module)
25+
26+
27+
class FakeTokenizer:
28+
pad_token_id = 0
29+
eos_token_id = 0
30+
31+
def convert_tokens_to_ids(self, token: str) -> int:
32+
return {"yes": 1, "no": 2}[token]
33+
34+
def apply_chat_template(self, messages, tokenize: bool, add_generation_prompt: bool):
35+
assert tokenize is False
36+
assert add_generation_prompt is True
37+
return messages[0]["content"]
38+
39+
def encode(self, text: str, return_tensors=None) -> list[int]:
40+
assert return_tensors is None
41+
if "Document: " not in text:
42+
return [1] * len(text)
43+
passage = text.split("Document: ", 1)[1].split("<think>", 1)[0]
44+
marker = max(1, len(passage))
45+
return [1] * (3 + len(passage)) + [marker]
46+
47+
def decode(self, tokens: list[int], skip_special_tokens: bool = True) -> str:
48+
assert skip_special_tokens is True
49+
return "x" * len(tokens)
50+
51+
52+
class RecordingModel:
53+
def __init__(self, *, fail_all: bool = False):
54+
self.fail_all = fail_all
55+
self.shapes: list[tuple[int, int]] = []
56+
57+
def __call__(self, input_ids):
58+
arr = np.asarray(input_ids)
59+
self.shapes.append(tuple(arr.shape))
60+
if self.fail_all:
61+
raise RuntimeError("fake model failure")
62+
63+
logits = np.zeros((arr.shape[0], arr.shape[1], 3), dtype=np.float32)
64+
for row_index, row in enumerate(arr):
65+
non_padding = np.flatnonzero(row != 0)
66+
last_real = int(non_padding[-1])
67+
marker = float(row[last_real])
68+
logits[row_index, last_real, 1] = marker
69+
logits[row_index, last_real, 2] = 0.0
70+
if last_real != arr.shape[1] - 1:
71+
logits[row_index, -1, 1] = -100.0
72+
logits[row_index, -1, 2] = 100.0
73+
return logits
74+
75+
76+
def _ready_reranker(model: RecordingModel) -> Reranker:
77+
reranker = Reranker()
78+
reranker._model = model
79+
reranker._tokenizer = FakeTokenizer()
80+
reranker._token_yes = 1
81+
reranker._token_no = 2
82+
return reranker
83+
84+
85+
def _sigmoid(value: float) -> float:
86+
return math.exp(value) / (math.exp(value) + 1.0)
87+
88+
89+
def test_rerank_caps_long_passages_before_scoring(fake_mlx, monkeypatch):
90+
monkeypatch.setattr(reranker_mod, "_MAX_PASSAGE_TOKENS", 2)
91+
model = RecordingModel()
92+
reranker = _ready_reranker(model)
93+
94+
scores = reranker.rerank("query", ["abcdef"])
95+
96+
assert scores == pytest.approx([_sigmoid(2)])
97+
assert model.shapes == [(1, 6)]
98+
99+
100+
def test_rerank_keeps_short_passages_intact(fake_mlx, monkeypatch):
101+
monkeypatch.setattr(reranker_mod, "_MAX_PASSAGE_TOKENS", 10)
102+
model = RecordingModel()
103+
reranker = _ready_reranker(model)
104+
105+
scores = reranker.rerank("query", ["abc"])
106+
107+
assert scores == pytest.approx([_sigmoid(3)])
108+
assert model.shapes == [(1, 7)]
109+
110+
111+
def test_rerank_returns_none_when_inference_fails(fake_mlx):
112+
reranker = _ready_reranker(RecordingModel(fail_all=True))
113+
114+
assert reranker.rerank("query", ["passage"]) is None

0 commit comments

Comments
 (0)