Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
46ae85b
Refactor: shared PreprocessConfig for gradient processing across buil…
luciaquirke Feb 24, 2026
2c3fc4b
Add gradient preprocessing docs page with case studies
github-actions[bot] Feb 24, 2026
6f079d9
Convert preprocessing docs to RST for Sphinx/ReadTheDocs compatibility
github-actions[bot] Feb 24, 2026
e2ab453
Move preprocessing and experiments docs above API reference
github-actions[bot] Feb 24, 2026
6ecd74d
Fix optimizer normalizer description in preprocessing docs
github-actions[bot] Feb 24, 2026
74c1fd1
Fix wording in preprocessing docs per review feedback
github-actions[bot] Feb 24, 2026
e1e9909
Extract preconditioner mixing into standalone mix_preconditioners fun…
luciaquirke Feb 26, 2026
add3d62
Apply preconditioners after concatenation to reduce VRAM usage
luciaquirke Feb 26, 2026
779cf4a
Address last three review comments
github-actions[bot] Feb 26, 2026
a2a534d
Enable both preconditioners
luciaquirke Feb 26, 2026
867eedd
update
luciaquirke Feb 26, 2026
9ff9574
Move all preconditioning logic into Scorer
luciaquirke Feb 26, 2026
7af663c
Address last three review comments
luciaquirke Feb 26, 2026
098ea2e
unify inverse computation
luciaquirke Feb 28, 2026
80d76ee
remove gradient preprocess during score
luciaquirke Feb 28, 2026
0eae5d9
test split precond
luciaquirke Feb 28, 2026
86105f4
fix bfloat16 grad buffer conversion in InMemoryCollector
luciaquirke Feb 28, 2026
b52923b
rename dist_reduce to teardown; compile gradient processing; move bui…
luciaquirke Mar 1, 2026
6a14e53
fix: standardize trace collector preconditioning
luciaquirke Mar 1, 2026
3b3f17a
Make score fast
luciaquirke Mar 2, 2026
2dd26d3
feat: enable trackstar
luciaquirke Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions bergson/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
__version__ = "0.6.2"

from .builders import (
Builder,
create_builder,
)
from .collection import collect_gradients
from .collector.collector import CollectorComputer
from .collector.gradient_collectors import GradientCollector
Expand All @@ -8,22 +12,20 @@
AttentionConfig,
DataConfig,
IndexConfig,
PreprocessConfig,
QueryConfig,
ReduceConfig,
ScoreConfig,
)
from .data import (
Builder,
InMemorySequenceBuilder,
InMemoryTokenBuilder,
TokenGradients,
create_builder,
load_gradient_dataset,
load_gradients,
load_token_gradients,
)
from .gradients import GradientProcessor
from .normalizer.fit_normalizers import fit_normalizers
from .process_grads import mix_preconditioners
from .query.attributor import Attributor
from .query.faiss_index import FaissConfig
from .score.scorer import Scorer
Expand All @@ -36,8 +38,6 @@
"load_token_gradients",
"TokenGradients",
"Builder",
"InMemorySequenceBuilder",
"InMemoryTokenBuilder",
"create_builder",
"fit_normalizers",
"Attributor",
Expand All @@ -50,8 +50,10 @@
"IndexConfig",
"DataConfig",
"AttentionConfig",
"PreprocessConfig",
"Scorer",
"ScoreConfig",
"ReduceConfig",
"QueryConfig",
"mix_preconditioners",
]
95 changes: 22 additions & 73 deletions bergson/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import shutil
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

from simple_parsing import ArgumentParser, ConflictResolution
Expand All @@ -10,6 +7,7 @@
from .config import (
HessianConfig,
IndexConfig,
PreprocessConfig,
QueryConfig,
ReduceConfig,
ScoreConfig,
Expand All @@ -19,23 +17,8 @@
from .query.query_index import query
from .reduce import reduce
from .score.score import score_dataset


def validate_run_path(index_cfg: IndexConfig):
"""Validate the run path."""
if index_cfg.distributed.rank != 0:
return

for path in [Path(index_cfg.run_path), Path(index_cfg.partial_run_path)]:
if not path.exists():
continue

if index_cfg.overwrite:
shutil.rmtree(path)
else:
raise FileExistsError(
f"Run path {path} already exists. Use --overwrite to overwrite it."
)
from .trackstar import trackstar
from .utils.worker_utils import validate_run_path


@dataclass
Expand All @@ -44,14 +27,16 @@ class Build:

index_cfg: IndexConfig

preprocess_cfg: PreprocessConfig

def execute(self):
"""Build the gradient index."""
if self.index_cfg.skip_index and self.index_cfg.skip_preconditioners:
raise ValueError("Either skip_index or skip_preconditioners must be False")

validate_run_path(self.index_cfg)

build(self.index_cfg)
build(self.index_cfg, self.preprocess_cfg)


@dataclass
Expand All @@ -60,12 +45,14 @@ class Preconditioners:

index_cfg: IndexConfig

preprocess_cfg: PreprocessConfig

def execute(self):
"""Compute normalizers and preconditioners."""
self.index_cfg.skip_index = True
self.index_cfg.skip_preconditioners = False
validate_run_path(self.index_cfg)
build(self.index_cfg)
build(self.index_cfg, self.preprocess_cfg)


@dataclass
Expand All @@ -76,6 +63,8 @@ class Reduce:

reduce_cfg: ReduceConfig

preprocess_cfg: PreprocessConfig

def execute(self):
"""Reduce a gradient index."""
if self.index_cfg.projection_dim != 0:
Expand All @@ -85,7 +74,7 @@ def execute(self):

validate_run_path(self.index_cfg)

reduce(self.index_cfg, self.reduce_cfg)
reduce(self.index_cfg, self.reduce_cfg, self.preprocess_cfg)


@dataclass
Expand All @@ -96,6 +85,8 @@ class Score:

index_cfg: IndexConfig

preprocess_cfg: PreprocessConfig

def execute(self):
"""Score a dataset against an existing gradient index."""
assert self.score_cfg.query_path
Expand All @@ -107,7 +98,7 @@ def execute(self):

validate_run_path(self.index_cfg)

score_dataset(self.index_cfg, self.score_cfg)
score_dataset(self.index_cfg, self.score_cfg, self.preprocess_cfg)


@dataclass
Expand Down Expand Up @@ -140,58 +131,16 @@ class Trackstar:

index_cfg: IndexConfig

trackstar_cfg: TrackstarConfig

score_cfg: ScoreConfig

preprocess_cfg: PreprocessConfig
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would outsource this to a trackstar.py script


trackstar_cfg: TrackstarConfig

def execute(self):
"""Run the full trackstar pipeline: preconditioners -> build -> score."""
run_path = self.index_cfg.run_path
value_precond_path = f"{run_path}/value_preconditioner"
query_precond_path = f"{run_path}/query_preconditioner"
query_path = f"{run_path}/query"
scores_path = f"{run_path}/scores"

# Step 1: Compute normalizers and preconditioners on value dataset
print("Step 1/4: Computing normalizers and preconditioners on value dataset...")
value_precond_cfg = deepcopy(self.index_cfg)
value_precond_cfg.run_path = value_precond_path
value_precond_cfg.skip_index = True
value_precond_cfg.skip_preconditioners = False
validate_run_path(value_precond_cfg)
build(value_precond_cfg)

# Step 2: Compute normalizers and preconditioners on query dataset
print("Step 2/4: Computing normalizers and preconditioners on query dataset...")
query_precond_cfg = deepcopy(self.index_cfg)
query_precond_cfg.run_path = query_precond_path
query_precond_cfg.data = self.trackstar_cfg.query
query_precond_cfg.skip_index = True
query_precond_cfg.skip_preconditioners = False
validate_run_path(query_precond_cfg)
build(query_precond_cfg)

# Step 3: Build per-item query gradient index
print("Step 3/4: Building query gradient index...")
query_cfg = deepcopy(self.index_cfg)
query_cfg.run_path = query_path
query_cfg.data = self.trackstar_cfg.query
query_cfg.processor_path = query_precond_path
query_cfg.skip_preconditioners = True
validate_run_path(query_cfg)
build(query_cfg)

# Step 4: Score value dataset against query using both preconditioners
print("Step 4/4: Scoring value dataset...")
score_index_cfg = deepcopy(self.index_cfg)
score_index_cfg.run_path = scores_path
score_index_cfg.processor_path = value_precond_path
score_index_cfg.skip_preconditioners = True
self.score_cfg.query_path = query_path
self.score_cfg.query_preconditioner_path = query_precond_path
self.score_cfg.index_preconditioner_path = value_precond_path
validate_run_path(score_index_cfg)
score_dataset(score_index_cfg, self.score_cfg)
trackstar(
self.index_cfg, self.score_cfg, self.preprocess_cfg, self.trackstar_cfg
)


@dataclass
Expand Down
12 changes: 9 additions & 3 deletions bergson/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tqdm.auto import tqdm

from bergson.collection import collect_gradients
from bergson.config import IndexConfig
from bergson.config import IndexConfig, PreprocessConfig
from bergson.data import allocate_batches
from bergson.distributed import launch_distributed_run
from bergson.utils.auto_batch_size import maybe_auto_batch_size
Expand All @@ -27,6 +27,7 @@ def build_worker(
local_rank: int,
world_size: int,
cfg: IndexConfig,
preprocess_cfg: PreprocessConfig,
ds: Dataset | IterableDataset,
):
"""
Expand Down Expand Up @@ -108,7 +109,7 @@ def flush(kwargs):
processor.save(cfg.partial_run_path)


def build(index_cfg: IndexConfig):
def build(index_cfg: IndexConfig, preprocess_cfg: PreprocessConfig):
"""
Build a gradient index by distributing work across all available GPUs.

Expand All @@ -117,6 +118,8 @@ def build(index_cfg: IndexConfig):
index_cfg : IndexConfig
Specifies the run path, dataset, model, tokenizer, PEFT adapters,
and many other gradient collection settings.
preprocess_cfg : PreprocessConfig
Preprocessing configuration for gradient normalization/preconditioning.
"""
if index_cfg.debug:
setup_reproducibility()
Expand All @@ -128,7 +131,10 @@ def build(index_cfg: IndexConfig):
ds = setup_data_pipeline(index_cfg)

launch_distributed_run(
"build", build_worker, [index_cfg, ds], index_cfg.distributed
"build",
build_worker,
[index_cfg, preprocess_cfg, ds],
index_cfg.distributed,
)

rank = index_cfg.distributed.rank
Expand Down
Loading