Skip to content

Commit c85e9ea

Browse files
committed
Make score fast
1 parent 2d6bf77 commit c85e9ea

File tree

6 files changed

+152
-138
lines changed

6 files changed

+152
-138
lines changed

bergson/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from .builders import (
44
Builder,
5-
InMemorySequenceBuilder,
6-
InMemoryTokenBuilder,
75
create_builder,
86
)
97
from .collection import collect_gradients
@@ -40,8 +38,6 @@
4038
"load_token_gradients",
4139
"TokenGradients",
4240
"Builder",
43-
"InMemorySequenceBuilder",
44-
"InMemoryTokenBuilder",
4541
"create_builder",
4642
"fit_normalizers",
4743
"Attributor",

bergson/__main__.py

Lines changed: 6 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
import shutil
2-
from copy import deepcopy
31
from dataclasses import dataclass
4-
from pathlib import Path
52
from typing import Optional, Union
63

74
from simple_parsing import ArgumentParser, ConflictResolution
@@ -17,27 +14,11 @@
1714
TrackstarConfig,
1815
)
1916
from .hessians.hessian_approximations import approximate_hessians
20-
from .process_grads import mix_preconditioners
2117
from .query.query_index import query
2218
from .reduce import reduce
2319
from .score.score import score_dataset
24-
25-
26-
def validate_run_path(index_cfg: IndexConfig):
27-
"""Validate the run path."""
28-
if index_cfg.distributed.rank != 0:
29-
return
30-
31-
for path in [Path(index_cfg.run_path), Path(index_cfg.partial_run_path)]:
32-
if not path.exists():
33-
continue
34-
35-
if index_cfg.overwrite:
36-
shutil.rmtree(path)
37-
else:
38-
raise FileExistsError(
39-
f"Run path {path} already exists. Use --overwrite to overwrite it."
40-
)
20+
from .trackstar import trackstar
21+
from .utils.worker_utils import validate_run_path
4122

4223

4324
@dataclass
@@ -150,70 +131,17 @@ class Trackstar:
150131

151132
index_cfg: IndexConfig
152133

153-
trackstar_cfg: TrackstarConfig
154-
155134
score_cfg: ScoreConfig
156135

157136
preprocess_cfg: PreprocessConfig
158137

138+
trackstar_cfg: TrackstarConfig
139+
159140
def execute(self):
160-
"""Run the full trackstar pipeline: preconditioners -> mix -> build -> score."""
161-
run_path = self.index_cfg.run_path
162-
value_precond_path = f"{run_path}/value_preconditioner"
163-
query_precond_path = f"{run_path}/query_preconditioner"
164-
mixed_precond_path = f"{run_path}/mixed_preconditioner"
165-
query_path = f"{run_path}/query"
166-
scores_path = f"{run_path}/scores"
167-
168-
# Step 1: Compute normalizers and preconditioners on value dataset
169-
print("Step 1/5: Computing normalizers and preconditioners on value dataset...")
170-
value_precond_cfg = deepcopy(self.index_cfg)
171-
value_precond_cfg.run_path = value_precond_path
172-
value_precond_cfg.skip_index = True
173-
value_precond_cfg.skip_preconditioners = False
174-
validate_run_path(value_precond_cfg)
175-
build(value_precond_cfg, self.preprocess_cfg)
176-
177-
# Step 2: Compute normalizers and preconditioners on query dataset
178-
print("Step 2/5: Computing normalizers and preconditioners on query dataset...")
179-
query_precond_cfg = deepcopy(self.index_cfg)
180-
query_precond_cfg.run_path = query_precond_path
181-
query_precond_cfg.data = self.trackstar_cfg.query
182-
query_precond_cfg.skip_index = True
183-
query_precond_cfg.skip_preconditioners = False
184-
validate_run_path(query_precond_cfg)
185-
build(query_precond_cfg, self.preprocess_cfg)
186-
187-
# Step 3: Mix query and value preconditioners
188-
print("Step 3/5: Mixing preconditioners...")
189-
mix_preconditioners(
190-
query_path=query_precond_path,
191-
index_path=value_precond_path,
192-
output_path=mixed_precond_path,
193-
mixing_coefficient=self.trackstar_cfg.mixing_coefficient,
141+
trackstar(
142+
self.index_cfg, self.score_cfg, self.preprocess_cfg, self.trackstar_cfg
194143
)
195144

196-
# Step 4: Build per-item query gradient index
197-
print("Step 4/5: Building query gradient index...")
198-
query_cfg = deepcopy(self.index_cfg)
199-
query_cfg.run_path = query_path
200-
query_cfg.data = self.trackstar_cfg.query
201-
query_cfg.processor_path = query_precond_path
202-
query_cfg.skip_preconditioners = True
203-
validate_run_path(query_cfg)
204-
build(query_cfg, self.preprocess_cfg)
205-
206-
# Step 5: Score value dataset against query using mixed preconditioner
207-
print("Step 5/5: Scoring value dataset...")
208-
score_index_cfg = deepcopy(self.index_cfg)
209-
score_index_cfg.run_path = scores_path
210-
score_index_cfg.processor_path = value_precond_path
211-
score_index_cfg.skip_preconditioners = True
212-
self.score_cfg.query_path = query_path
213-
self.preprocess_cfg.preconditioner_path = mixed_precond_path
214-
validate_run_path(score_index_cfg)
215-
score_dataset(score_index_cfg, self.score_cfg, self.preprocess_cfg)
216-
217145

218146
@dataclass
219147
class Main:

bergson/builders.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,8 @@ def __init__(
8080
grad_sizes: dict[str, int],
8181
dtype: torch.dtype,
8282
*,
83-
attribute_tokens: bool = False,
84-
path: Path | None = None,
85-
reduce_cfg: ReduceConfig | None = None,
86-
preprocess_cfg: PreprocessConfig | None = None,
83+
path: Path,
8784
):
88-
assert path is not None
8985
self.grad_sizes = grad_sizes
9086
self.num_items = len(data)
9187
np_dtype = convert_dtype_to_np(dtype)
@@ -157,8 +153,6 @@ def __init__(
157153
grad_sizes: dict[str, int],
158154
dtype: torch.dtype,
159155
*,
160-
attribute_tokens: bool = False,
161-
path: Path | None = None,
162156
reduce_cfg: ReduceConfig | None = None,
163157
preprocess_cfg: PreprocessConfig | None = None,
164158
):
@@ -293,11 +287,6 @@ def __init__(
293287
data: Dataset,
294288
grad_sizes: dict[str, int],
295289
dtype: torch.dtype,
296-
*,
297-
attribute_tokens: bool = False,
298-
path: Path | None = None,
299-
reduce_cfg: ReduceConfig | None = None,
300-
preprocess_cfg: PreprocessConfig | None = None,
301290
):
302291
self.grad_sizes = grad_sizes
303292
self.num_items = len(data)
@@ -356,12 +345,10 @@ def __init__(
356345
grad_sizes: dict[str, int],
357346
dtype: torch.dtype,
358347
*,
359-
attribute_tokens: bool = False,
360-
path: Path | None = None,
348+
path: Path,
361349
reduce_cfg: ReduceConfig | None = None,
362350
preprocess_cfg: PreprocessConfig | None = None,
363351
):
364-
assert path is not None
365352
self.grad_sizes = grad_sizes
366353
self.num_items = len(data)
367354
self.reduce_cfg = reduce_cfg
@@ -484,16 +471,22 @@ def create_builder(
484471
* no ``path`` → :class:`InMemorySequenceBuilder`
485472
"""
486473
if path is not None:
487-
cls = TokenBuilder if attribute_tokens else SequenceBuilder
488-
else:
489-
cls = InMemoryTokenBuilder if attribute_tokens else InMemorySequenceBuilder
490-
491-
return cls(
474+
if attribute_tokens:
475+
return TokenBuilder(data, grad_sizes, dtype, path=path)
476+
return SequenceBuilder(
477+
data,
478+
grad_sizes,
479+
dtype,
480+
path=path,
481+
reduce_cfg=reduce_cfg,
482+
preprocess_cfg=preprocess_cfg,
483+
)
484+
if attribute_tokens:
485+
return InMemoryTokenBuilder(data, grad_sizes, dtype)
486+
return InMemorySequenceBuilder(
492487
data,
493488
grad_sizes,
494489
dtype,
495-
attribute_tokens=attribute_tokens,
496-
path=path,
497490
reduce_cfg=reduce_cfg,
498491
preprocess_cfg=preprocess_cfg,
499492
)

bergson/score/scorer.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,6 @@
44
from bergson.score.score_writer import ScoreWriter
55

66

7-
@torch.compile(fullgraph=True)
8-
def _cosine_score(
9-
index_grads: torch.Tensor,
10-
query_grads_t: torch.Tensor,
11-
) -> torch.Tensor:
12-
"""Matmul + unit normalization."""
13-
scores = index_grads @ query_grads_t
14-
i_norm = index_grads.pow(2).sum(dim=1).sqrt().clamp_min_(1e-12).unsqueeze(1)
15-
scores.div_(i_norm)
16-
return scores
17-
18-
197
class Scorer:
208
"""
219
Scores training gradients against query gradients.
@@ -80,17 +68,25 @@ def __init__(
8068
self.writer = writer
8169

8270
# Load preconditioner: H^(-1/2) for split, H^(-1) for one-sided
83-
self.preconditioners = get_trackstar_preconditioner(
71+
preconditioners = get_trackstar_preconditioner(
8472
preconditioner_path,
8573
device=device,
8674
power=-0.5 if unit_normalize else -1,
8775
return_dtype=dtype,
8876
)
77+
78+
# Stack preconditioners for batched matmul in score().
79+
# Shape: [n_modules, dim_per_mod, dim_per_mod]
80+
if preconditioners and unit_normalize:
81+
self.precond_stack = torch.stack([preconditioners[m] for m in modules])
82+
else:
83+
self.precond_stack = None
84+
8985
# Precondition query grads per module, then cat into a single tensor
90-
if self.preconditioners:
86+
if preconditioners:
9187
q_list = [
9288
query_grads[m].to(device=self.device, dtype=self.dtype)
93-
@ self.preconditioners[m]
89+
@ preconditioners[m]
9490
for m in modules
9591
]
9692
else:
@@ -112,27 +108,34 @@ def __call__(
112108
@torch.inference_mode()
113109
def score(self, index_grads: dict[str, torch.Tensor]) -> torch.Tensor:
114110
"""Compute scores for a batch of gradients."""
115-
# Device transfer and (optionally split) preconditioning of index grads.
116-
# One-sided mode (unit_normalize=False) only preconditions the query.
117-
i_list = []
118-
for m in self.modules:
119-
g = index_grads[m].to(self.device, self.dtype, non_blocking=True)
120-
if (
121-
self.unit_normalize
122-
and self.preconditioners
123-
and m in self.preconditioners
124-
):
125-
g = g @ self.preconditioners[m]
126-
i_list.append(g)
127-
128-
all_index = torch.cat(i_list, dim=-1)
111+
if self.precond_stack is not None:
112+
# Batched preconditioning: [batch, n_modules, dim] @ [n_modules, dim, dim]
113+
g = torch.stack(
114+
[
115+
index_grads[m].to(self.device, self.dtype, non_blocking=True)
116+
for m in self.modules
117+
],
118+
dim=1,
119+
)
120+
all_index = (
121+
torch.bmm(g.permute(1, 0, 2), self.precond_stack)
122+
.permute(1, 0, 2)
123+
.reshape(g.shape[0], -1)
124+
)
125+
else:
126+
all_index = torch.cat(
127+
[
128+
index_grads[m].to(self.device, self.dtype, non_blocking=True)
129+
for m in self.modules
130+
],
131+
dim=-1,
132+
)
133+
134+
scores = all_index @ self.query_grads_t
129135

130136
if self.unit_normalize:
131-
scores = _cosine_score(all_index, self.query_grads_t)
132-
else:
133-
# Compiled score adds overhead for dot-product-only
134-
# where the single matmul is already fast.
135-
scores = all_index @ self.query_grads_t
137+
i_norm = all_index.pow(2).sum(dim=1).sqrt().clamp_min_(1e-12).unsqueeze(1)
138+
scores.div_(i_norm)
136139

137140
if self.score_mode == "nearest":
138141
return scores.max(dim=-1).values

bergson/trackstar.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from copy import deepcopy
2+
3+
from .build import build
4+
from .config import (
5+
IndexConfig,
6+
PreprocessConfig,
7+
ScoreConfig,
8+
TrackstarConfig,
9+
)
10+
from .process_grads import mix_preconditioners
11+
from .score.score import score_dataset
12+
from .utils.worker_utils import validate_run_path
13+
14+
15+
def trackstar(
16+
index_cfg: IndexConfig,
17+
score_cfg: ScoreConfig,
18+
preprocess_cfg: PreprocessConfig,
19+
trackstar_cfg: TrackstarConfig,
20+
):
21+
"""Run the full trackstar pipeline: preconditioners -> mix -> build -> score."""
22+
run_path = index_cfg.run_path
23+
value_precond_path = f"{run_path}/value_preconditioner"
24+
query_precond_path = f"{run_path}/query_preconditioner"
25+
mixed_precond_path = f"{run_path}/mixed_preconditioner"
26+
query_path = f"{run_path}/query"
27+
scores_path = f"{run_path}/scores"
28+
29+
# Step 1: Compute normalizers and preconditioners on value dataset
30+
print("Step 1/5: Computing normalizers and preconditioners on value dataset...")
31+
value_precond_cfg = deepcopy(index_cfg)
32+
value_precond_cfg.run_path = value_precond_path
33+
value_precond_cfg.skip_index = True
34+
value_precond_cfg.skip_preconditioners = False
35+
validate_run_path(value_precond_cfg)
36+
build(value_precond_cfg, preprocess_cfg)
37+
38+
# Step 2: Compute normalizers and preconditioners on query dataset
39+
print("Step 2/5: Computing normalizers and preconditioners on query dataset...")
40+
query_precond_cfg = deepcopy(index_cfg)
41+
query_precond_cfg.run_path = query_precond_path
42+
query_precond_cfg.data = trackstar_cfg.query
43+
query_precond_cfg.skip_index = True
44+
query_precond_cfg.skip_preconditioners = False
45+
validate_run_path(query_precond_cfg)
46+
build(query_precond_cfg, preprocess_cfg)
47+
48+
# Step 3: Mix query and value preconditioners
49+
print("Step 3/5: Mixing preconditioners...")
50+
mix_preconditioners(
51+
query_path=query_precond_path,
52+
index_path=value_precond_path,
53+
output_path=mixed_precond_path,
54+
mixing_coefficient=trackstar_cfg.mixing_coefficient,
55+
)
56+
57+
# Step 4: Build per-item query gradient index
58+
print("Step 4/5: Building query gradient index...")
59+
query_cfg = deepcopy(index_cfg)
60+
query_cfg.run_path = query_path
61+
query_cfg.data = trackstar_cfg.query
62+
query_cfg.processor_path = query_precond_path
63+
query_cfg.skip_preconditioners = True
64+
validate_run_path(query_cfg)
65+
build(query_cfg, preprocess_cfg)
66+
67+
# Step 5: Score value dataset against query using mixed preconditioner
68+
print("Step 5/5: Scoring value dataset...")
69+
score_index_cfg = deepcopy(index_cfg)
70+
score_index_cfg.run_path = scores_path
71+
score_index_cfg.processor_path = value_precond_path
72+
score_index_cfg.skip_preconditioners = True
73+
score_cfg.query_path = query_path
74+
preprocess_cfg.preconditioner_path = mixed_precond_path
75+
validate_run_path(score_index_cfg)
76+
score_dataset(score_index_cfg, score_cfg, preprocess_cfg)

0 commit comments

Comments
 (0)