Skip to content

Commit 52e8066

Browse files
luciaquirke and claude committed
Refactor: shared PreprocessConfig for gradient processing across build, reduce, score
- Create PreprocessConfig with unit_normalize, preconditioner paths, and mixing_coefficient fields extracted from ReduceConfig and ScoreConfig
- Fix broken imports in process_grads.py (gradient_processor -> gradients, utils.utils -> utils.math)
- Fix accumulate_grads -> aggregate_grads variable name bug in process_grads.py
- Add compute_preconditioner() returning H^(-1/2) for unit_normalize or H^(-1) otherwise, and normalize_flat_grad() for flat tensors
- Fix data.py import from nonexistent .reduce.process_query_grads
- Fix SequenceBuilder/InMemorySequenceBuilder missing h_inv computation and broken normalize_grad calls with undefined device
- Add preconditioner support to Scorer for index-side H^(-1/2) application
- Thread PreprocessConfig through build, reduce, score, and collection
- Fix accumulate_grads -> aggregate_grads in huggingface.py callback
- Add tests for compute_preconditioner and Scorer preconditioner support

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 890fa41 commit 52e8066

File tree

14 files changed

+477
-117
lines changed

14 files changed

+477
-117
lines changed

bergson/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
AttentionConfig,
99
DataConfig,
1010
IndexConfig,
11+
PreprocessConfig,
1112
QueryConfig,
1213
ReduceConfig,
1314
ScoreConfig,
@@ -50,6 +51,7 @@
5051
"IndexConfig",
5152
"DataConfig",
5253
"AttentionConfig",
54+
"PreprocessConfig",
5355
"Scorer",
5456
"ScoreConfig",
5557
"ReduceConfig",

bergson/__main__.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .config import (
1111
HessianConfig,
1212
IndexConfig,
13+
PreprocessConfig,
1314
QueryConfig,
1415
ReduceConfig,
1516
ScoreConfig,
@@ -44,14 +45,16 @@ class Build:
4445

4546
index_cfg: IndexConfig
4647

48+
preprocess_cfg: PreprocessConfig
49+
4750
def execute(self):
4851
"""Build the gradient index."""
4952
if self.index_cfg.skip_index and self.index_cfg.skip_preconditioners:
5053
raise ValueError("Either skip_index or skip_preconditioners must be False")
5154

5255
validate_run_path(self.index_cfg)
5356

54-
build(self.index_cfg)
57+
build(self.index_cfg, self.preprocess_cfg)
5558

5659

5760
@dataclass
@@ -60,12 +63,14 @@ class Preconditioners:
6063

6164
index_cfg: IndexConfig
6265

66+
preprocess_cfg: PreprocessConfig
67+
6368
def execute(self):
6469
"""Compute normalizers and preconditioners."""
6570
self.index_cfg.skip_index = True
6671
self.index_cfg.skip_preconditioners = False
6772
validate_run_path(self.index_cfg)
68-
build(self.index_cfg)
73+
build(self.index_cfg, self.preprocess_cfg)
6974

7075

7176
@dataclass
@@ -76,6 +81,8 @@ class Reduce:
7681

7782
reduce_cfg: ReduceConfig
7883

84+
preprocess_cfg: PreprocessConfig
85+
7986
def execute(self):
8087
"""Reduce a gradient index."""
8188
if self.index_cfg.projection_dim != 0:
@@ -85,7 +92,7 @@ def execute(self):
8592

8693
validate_run_path(self.index_cfg)
8794

88-
reduce(self.index_cfg, self.reduce_cfg)
95+
reduce(self.index_cfg, self.reduce_cfg, self.preprocess_cfg)
8996

9097

9198
@dataclass
@@ -96,6 +103,8 @@ class Score:
96103

97104
index_cfg: IndexConfig
98105

106+
preprocess_cfg: PreprocessConfig
107+
99108
def execute(self):
100109
"""Score a dataset against an existing gradient index."""
101110
assert self.score_cfg.query_path
@@ -107,7 +116,7 @@ def execute(self):
107116

108117
validate_run_path(self.index_cfg)
109118

110-
score_dataset(self.index_cfg, self.score_cfg)
119+
score_dataset(self.index_cfg, self.score_cfg, self.preprocess_cfg)
111120

112121

113122
@dataclass
@@ -144,6 +153,8 @@ class Trackstar:
144153

145154
score_cfg: ScoreConfig
146155

156+
preprocess_cfg: PreprocessConfig
157+
147158
def execute(self):
148159
"""Run the full trackstar pipeline: preconditioners -> build -> score."""
149160
run_path = self.index_cfg.run_path
@@ -159,7 +170,7 @@ def execute(self):
159170
value_precond_cfg.skip_index = True
160171
value_precond_cfg.skip_preconditioners = False
161172
validate_run_path(value_precond_cfg)
162-
build(value_precond_cfg)
173+
build(value_precond_cfg, self.preprocess_cfg)
163174

164175
# Step 2: Compute normalizers and preconditioners on query dataset
165176
print("Step 2/4: Computing normalizers and preconditioners on query dataset...")
@@ -169,7 +180,7 @@ def execute(self):
169180
query_precond_cfg.skip_index = True
170181
query_precond_cfg.skip_preconditioners = False
171182
validate_run_path(query_precond_cfg)
172-
build(query_precond_cfg)
183+
build(query_precond_cfg, self.preprocess_cfg)
173184

174185
# Step 3: Build per-item query gradient index
175186
print("Step 3/4: Building query gradient index...")
@@ -179,7 +190,7 @@ def execute(self):
179190
query_cfg.processor_path = query_precond_path
180191
query_cfg.skip_preconditioners = True
181192
validate_run_path(query_cfg)
182-
build(query_cfg)
193+
build(query_cfg, self.preprocess_cfg)
183194

184195
# Step 4: Score value dataset against query using both preconditioners
185196
print("Step 4/4: Scoring value dataset...")
@@ -188,10 +199,10 @@ def execute(self):
188199
score_index_cfg.processor_path = value_precond_path
189200
score_index_cfg.skip_preconditioners = True
190201
self.score_cfg.query_path = query_path
191-
self.score_cfg.query_preconditioner_path = query_precond_path
192-
self.score_cfg.index_preconditioner_path = value_precond_path
202+
self.preprocess_cfg.query_preconditioner_path = query_precond_path
203+
self.preprocess_cfg.index_preconditioner_path = value_precond_path
193204
validate_run_path(score_index_cfg)
194-
score_dataset(score_index_cfg, self.score_cfg)
205+
score_dataset(score_index_cfg, self.score_cfg, self.preprocess_cfg)
195206

196207

197208
@dataclass

bergson/build.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tqdm.auto import tqdm
1111

1212
from bergson.collection import collect_gradients
13-
from bergson.config import IndexConfig
13+
from bergson.config import IndexConfig, PreprocessConfig
1414
from bergson.data import allocate_batches
1515
from bergson.distributed import launch_distributed_run
1616
from bergson.utils.auto_batch_size import maybe_auto_batch_size
@@ -27,6 +27,7 @@ def build_worker(
2727
local_rank: int,
2828
world_size: int,
2929
cfg: IndexConfig,
30+
preprocess_cfg: PreprocessConfig,
3031
ds: Dataset | IterableDataset,
3132
):
3233
"""
@@ -108,7 +109,7 @@ def flush(kwargs):
108109
processor.save(cfg.partial_run_path)
109110

110111

111-
def build(index_cfg: IndexConfig):
112+
def build(index_cfg: IndexConfig, preprocess_cfg: PreprocessConfig):
112113
"""
113114
Build a gradient index by distributing work across all available GPUs.
114115
@@ -117,6 +118,8 @@ def build(index_cfg: IndexConfig):
117118
index_cfg : IndexConfig
118119
Specifies the run path, dataset, model, tokenizer, PEFT adapters,
119120
and many other gradient collection settings.
121+
preprocess_cfg : PreprocessConfig
122+
Preprocessing configuration for gradient normalization/preconditioning.
120123
"""
121124
if index_cfg.debug:
122125
setup_reproducibility()
@@ -128,7 +131,10 @@ def build(index_cfg: IndexConfig):
128131
ds = setup_data_pipeline(index_cfg)
129132

130133
launch_distributed_run(
131-
"build", build_worker, [index_cfg, ds], index_cfg.distributed
134+
"build",
135+
build_worker,
136+
[index_cfg, preprocess_cfg, ds],
137+
index_cfg.distributed,
132138
)
133139

134140
rank = index_cfg.distributed.rank

bergson/collection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from bergson.collector.collector import CollectorComputer
55
from bergson.collector.gradient_collectors import GradientCollector
6-
from bergson.config import AttentionConfig, IndexConfig, ReduceConfig
6+
from bergson.config import AttentionConfig, IndexConfig, PreprocessConfig, ReduceConfig
77
from bergson.gradients import GradientProcessor
88
from bergson.score.scorer import Scorer
99

@@ -19,6 +19,7 @@ def collect_gradients(
1919
attention_cfgs: dict[str, AttentionConfig] | None = None,
2020
scorer: Scorer | None = None,
2121
reduce_cfg: ReduceConfig | None = None,
22+
preprocess_cfg: PreprocessConfig | None = None,
2223
):
2324
"""
2425
Compute gradients using the hooks specified in the GradientCollector.
@@ -31,6 +32,7 @@ def collect_gradients(
3132
data=data,
3233
scorer=scorer,
3334
reduce_cfg=reduce_cfg,
35+
preprocess_cfg=preprocess_cfg,
3436
attention_cfgs=attention_cfgs or {},
3537
filter_modules=cfg.filter_modules,
3638
)

bergson/collector/gradient_collectors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch import Tensor
1111

1212
from bergson.collector.collector import HookCollectorBase
13-
from bergson.config import IndexConfig, ReduceConfig
13+
from bergson.config import IndexConfig, PreprocessConfig, ReduceConfig
1414
from bergson.data import Builder, create_builder
1515
from bergson.gradients import (
1616
AdafactorNormalizer,
@@ -46,6 +46,9 @@ class GradientCollector(HookCollectorBase):
4646
reduce_cfg: ReduceConfig | None = None
4747
"""Configuration for in-run gradient reduction."""
4848

49+
preprocess_cfg: PreprocessConfig | None = None
50+
"""Configuration for gradient preprocessing."""
51+
4952
builder: Builder | None = None
5053
"""Handles writing gradients to disk. Created in setup() if save_index is True."""
5154

@@ -95,6 +98,7 @@ def setup(self) -> None:
9598
attribute_tokens=self.cfg.attribute_tokens,
9699
path=self.cfg.partial_run_path,
97100
reduce_cfg=self.reduce_cfg,
101+
preprocess_cfg=self.preprocess_cfg,
98102
)
99103
else:
100104
self.builder = None

bergson/collector/in_memory_collector.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch import Tensor, nn
1212

1313
from bergson.collector.collector import HookCollectorBase
14-
from bergson.config import IndexConfig, ReduceConfig
14+
from bergson.config import IndexConfig, PreprocessConfig, ReduceConfig
1515
from bergson.data import Builder, create_builder
1616
from bergson.gradients import (
1717
AdafactorNormalizer,
@@ -52,6 +52,9 @@ class InMemoryCollector(HookCollectorBase):
5252
reduce_cfg: ReduceConfig | None = None
5353
"""Configuration for in-run gradient reduction."""
5454

55+
preprocess_cfg: PreprocessConfig | None = None
56+
"""Configuration for gradient preprocessing."""
57+
5558
builder: Builder | None = None
5659
"""Handles writing gradients. Created in setup()."""
5760

@@ -109,6 +112,7 @@ def setup(self) -> None:
109112
self.save_dtype,
110113
attribute_tokens=self.cfg.attribute_tokens,
111114
reduce_cfg=self.reduce_cfg,
115+
preprocess_cfg=self.preprocess_cfg,
112116
)
113117

114118
def teardown(self) -> None:

bergson/config.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,23 @@ class QueryConfig:
265265
its top results as rows with columns: query, result, result_index, score."""
266266

267267

268+
@dataclass
269+
class PreprocessConfig:
270+
"""Config for gradient preprocessing, shared across build, reduce, and score."""
271+
272+
unit_normalize: bool = False
273+
"""Whether to unit normalize the gradients."""
274+
275+
query_preconditioner_path: str | None = None
276+
"""Path to a precomputed preconditioner for query gradients."""
277+
278+
index_preconditioner_path: str | None = None
279+
"""Path to a precomputed preconditioner for index gradients."""
280+
281+
mixing_coefficient: float = 0.99
282+
"""Weight for mixing query vs index preconditioner (1.0 = query only)."""
283+
284+
268285
@dataclass
269286
class ScoreConfig:
270287
"""Config for querying an index on the fly."""
@@ -280,25 +297,8 @@ class ScoreConfig:
280297
similar query gradient (the maximum score).
281298
`individual`: compute a separate score for each query gradient."""
282299

283-
query_preconditioner_path: str | None = None
284-
"""Path to a precomputed preconditioner to be applied to
285-
the query dataset gradients."""
286-
287-
index_preconditioner_path: str | None = None
288-
"""Path to a precomputed preconditioner to be applied to
289-
the query dataset gradients. This does not affect the
290-
ability to compute a new preconditioner during the query."""
291-
292-
mixing_coefficient: float = 0.99
293-
"""Coefficient to weight the application of the query preconditioner
294-
and the pre-computed index preconditioner. 0.0 means only use the
295-
index preconditioner and 1.0 means only use the query preconditioner."""
296-
297-
modules: list[str] = field(default_factory=list)
298-
"""Modules to use for the query. If empty, all modules will be used."""
299-
300-
unit_normalize: bool = False
301-
"""Whether to unit normalize the gradients before computing the scores."""
300+
skip_query_preprocess: bool = False
301+
"""Skip query preprocessing if already applied during reduce."""
302302

303303
batch_size: int = 1024
304304
"""Batch size for processing the query dataset."""
@@ -307,16 +307,24 @@ class ScoreConfig:
307307
"""Precision (dtype) to convert the query and index gradients to before
308308
computing the scores. If "auto", the model's gradient dtype is used."""
309309

310+
modules: list[str] = field(default_factory=list)
311+
"""Modules to use for the query. If empty, all modules will be used."""
312+
310313

311314
@dataclass
312315
class ReduceConfig:
313-
"""Config for reducing the gradients."""
316+
"""Config for reducing a dataset into a standalone query."""
314317

315318
method: Literal["mean", "sum"] = "mean"
316319
"""Method for reducing the gradients."""
317320

318-
unit_normalize: bool = False
319-
"""Whether to unit normalize the gradients before reducing them."""
321+
modules: list[str] = field(default_factory=list)
322+
"""Modules to use for the query. If empty, all modules will be used."""
323+
324+
normalize_reduced_grad: bool = False
325+
"""Whether to unit normalize the reduced query gradient. This has
326+
no effect on future score rankings but does affect the magnitude of
327+
the scores."""
320328

321329

322330
@dataclass

0 commit comments

Comments
 (0)