diff --git a/bergson/__init__.py b/bergson/__init__.py index 3019f686..839068eb 100644 --- a/bergson/__init__.py +++ b/bergson/__init__.py @@ -1,5 +1,9 @@ __version__ = "0.6.2" +from .builders import ( + Builder, + create_builder, +) from .collection import collect_gradients from .collector.collector import CollectorComputer from .collector.gradient_collectors import GradientCollector @@ -8,22 +12,20 @@ AttentionConfig, DataConfig, IndexConfig, + PreprocessConfig, QueryConfig, ReduceConfig, ScoreConfig, ) from .data import ( - Builder, - InMemorySequenceBuilder, - InMemoryTokenBuilder, TokenGradients, - create_builder, load_gradient_dataset, load_gradients, load_token_gradients, ) from .gradients import GradientProcessor from .normalizer.fit_normalizers import fit_normalizers +from .process_grads import mix_preconditioners from .query.attributor import Attributor from .query.faiss_index import FaissConfig from .score.scorer import Scorer @@ -36,8 +38,6 @@ "load_token_gradients", "TokenGradients", "Builder", - "InMemorySequenceBuilder", - "InMemoryTokenBuilder", "create_builder", "fit_normalizers", "Attributor", @@ -50,8 +50,10 @@ "IndexConfig", "DataConfig", "AttentionConfig", + "PreprocessConfig", "Scorer", "ScoreConfig", "ReduceConfig", "QueryConfig", + "mix_preconditioners", ] diff --git a/bergson/__main__.py b/bergson/__main__.py index 79ec970a..e3e7ff35 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -1,7 +1,4 @@ -import shutil -from copy import deepcopy from dataclasses import dataclass -from pathlib import Path from typing import Optional, Union from simple_parsing import ArgumentParser, ConflictResolution @@ -10,6 +7,7 @@ from .config import ( HessianConfig, IndexConfig, + PreprocessConfig, QueryConfig, ReduceConfig, ScoreConfig, @@ -19,23 +17,8 @@ from .query.query_index import query from .reduce import reduce from .score.score import score_dataset - - -def validate_run_path(index_cfg: IndexConfig): - """Validate the run path.""" - if 
index_cfg.distributed.rank != 0: - return - - for path in [Path(index_cfg.run_path), Path(index_cfg.partial_run_path)]: - if not path.exists(): - continue - - if index_cfg.overwrite: - shutil.rmtree(path) - else: - raise FileExistsError( - f"Run path {path} already exists. Use --overwrite to overwrite it." - ) +from .trackstar import trackstar +from .utils.worker_utils import validate_run_path @dataclass @@ -44,6 +27,8 @@ class Build: index_cfg: IndexConfig + preprocess_cfg: PreprocessConfig + def execute(self): """Build the gradient index.""" if self.index_cfg.skip_index and self.index_cfg.skip_preconditioners: @@ -51,7 +36,7 @@ def execute(self): validate_run_path(self.index_cfg) - build(self.index_cfg) + build(self.index_cfg, self.preprocess_cfg) @dataclass @@ -60,12 +45,14 @@ class Preconditioners: index_cfg: IndexConfig + preprocess_cfg: PreprocessConfig + def execute(self): """Compute normalizers and preconditioners.""" self.index_cfg.skip_index = True self.index_cfg.skip_preconditioners = False validate_run_path(self.index_cfg) - build(self.index_cfg) + build(self.index_cfg, self.preprocess_cfg) @dataclass @@ -76,6 +63,8 @@ class Reduce: reduce_cfg: ReduceConfig + preprocess_cfg: PreprocessConfig + def execute(self): """Reduce a gradient index.""" if self.index_cfg.projection_dim != 0: @@ -85,7 +74,7 @@ def execute(self): validate_run_path(self.index_cfg) - reduce(self.index_cfg, self.reduce_cfg) + reduce(self.index_cfg, self.reduce_cfg, self.preprocess_cfg) @dataclass @@ -96,6 +85,8 @@ class Score: index_cfg: IndexConfig + preprocess_cfg: PreprocessConfig + def execute(self): """Score a dataset against an existing gradient index.""" assert self.score_cfg.query_path @@ -107,7 +98,7 @@ def execute(self): validate_run_path(self.index_cfg) - score_dataset(self.index_cfg, self.score_cfg) + score_dataset(self.index_cfg, self.score_cfg, self.preprocess_cfg) @dataclass @@ -140,58 +131,16 @@ class Trackstar: index_cfg: IndexConfig - trackstar_cfg: TrackstarConfig - 
score_cfg: ScoreConfig + preprocess_cfg: PreprocessConfig + + trackstar_cfg: TrackstarConfig + def execute(self): - """Run the full trackstar pipeline: preconditioners -> build -> score.""" - run_path = self.index_cfg.run_path - value_precond_path = f"{run_path}/value_preconditioner" - query_precond_path = f"{run_path}/query_preconditioner" - query_path = f"{run_path}/query" - scores_path = f"{run_path}/scores" - - # Step 1: Compute normalizers and preconditioners on value dataset - print("Step 1/4: Computing normalizers and preconditioners on value dataset...") - value_precond_cfg = deepcopy(self.index_cfg) - value_precond_cfg.run_path = value_precond_path - value_precond_cfg.skip_index = True - value_precond_cfg.skip_preconditioners = False - validate_run_path(value_precond_cfg) - build(value_precond_cfg) - - # Step 2: Compute normalizers and preconditioners on query dataset - print("Step 2/4: Computing normalizers and preconditioners on query dataset...") - query_precond_cfg = deepcopy(self.index_cfg) - query_precond_cfg.run_path = query_precond_path - query_precond_cfg.data = self.trackstar_cfg.query - query_precond_cfg.skip_index = True - query_precond_cfg.skip_preconditioners = False - validate_run_path(query_precond_cfg) - build(query_precond_cfg) - - # Step 3: Build per-item query gradient index - print("Step 3/4: Building query gradient index...") - query_cfg = deepcopy(self.index_cfg) - query_cfg.run_path = query_path - query_cfg.data = self.trackstar_cfg.query - query_cfg.processor_path = query_precond_path - query_cfg.skip_preconditioners = True - validate_run_path(query_cfg) - build(query_cfg) - - # Step 4: Score value dataset against query using both preconditioners - print("Step 4/4: Scoring value dataset...") - score_index_cfg = deepcopy(self.index_cfg) - score_index_cfg.run_path = scores_path - score_index_cfg.processor_path = value_precond_path - score_index_cfg.skip_preconditioners = True - self.score_cfg.query_path = query_path - 
self.score_cfg.query_preconditioner_path = query_precond_path - self.score_cfg.index_preconditioner_path = value_precond_path - validate_run_path(score_index_cfg) - score_dataset(score_index_cfg, self.score_cfg) + trackstar( + self.index_cfg, self.score_cfg, self.preprocess_cfg, self.trackstar_cfg + ) @dataclass diff --git a/bergson/build.py b/bergson/build.py index a2022562..cd5bc03d 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -10,7 +10,7 @@ from tqdm.auto import tqdm from bergson.collection import collect_gradients -from bergson.config import IndexConfig +from bergson.config import IndexConfig, PreprocessConfig from bergson.data import allocate_batches from bergson.distributed import launch_distributed_run from bergson.utils.auto_batch_size import maybe_auto_batch_size @@ -27,6 +27,7 @@ def build_worker( local_rank: int, world_size: int, cfg: IndexConfig, + preprocess_cfg: PreprocessConfig, ds: Dataset | IterableDataset, ): """ @@ -108,7 +109,7 @@ def flush(kwargs): processor.save(cfg.partial_run_path) -def build(index_cfg: IndexConfig): +def build(index_cfg: IndexConfig, preprocess_cfg: PreprocessConfig): """ Build a gradient index by distributing work across all available GPUs. @@ -117,6 +118,8 @@ def build(index_cfg: IndexConfig): index_cfg : IndexConfig Specifies the run path, dataset, model, tokenizer, PEFT adapters, and many other gradient collection settings. + preprocess_cfg : PreprocessConfig + Preprocessing configuration for gradient normalization/preconditioning. 
""" if index_cfg.debug: setup_reproducibility() @@ -128,7 +131,10 @@ def build(index_cfg: IndexConfig): ds = setup_data_pipeline(index_cfg) launch_distributed_run( - "build", build_worker, [index_cfg, ds], index_cfg.distributed + "build", + build_worker, + [index_cfg, preprocess_cfg, ds], + index_cfg.distributed, ) rank = index_cfg.distributed.rank diff --git a/bergson/builders.py b/bergson/builders.py new file mode 100644 index 00000000..ba11f3c7 --- /dev/null +++ b/bergson/builders.py @@ -0,0 +1,495 @@ +from abc import ABC, abstractmethod +from pathlib import Path + +import numpy as np +import torch +import torch.distributed as dist +from datasets import Dataset + +from .config import PreprocessConfig, ReduceConfig +from .data import compute_num_token_grads, create_index, create_token_index +from .process_grads import ( + get_trackstar_preconditioner, + normalize_flat_grad, + precondition_grad, +) +from .utils.utils import convert_dtype_to_np, tensor_to_numpy + +_EPS_SQ = torch.finfo(torch.float32).eps ** 2 + + +@torch.compile(fullgraph=True) +def _reduce(grads: torch.Tensor, buffer: torch.Tensor, do_normalize: bool) -> None: + """Normalize + sum grads in a single fused kernel.""" + if do_normalize: + inv_norms = grads.pow(2).sum(dim=-1).clamp_min_(_EPS_SQ).rsqrt().unsqueeze(1) + grads = grads * inv_norms + buffer[0] += grads.sum(dim=0).to(torch.float32) + + +class Builder(ABC): + """Interface for gradient index writers. + + Use :func:`create_builder` to construct the appropriate concrete + subclass based on *attribute_tokens* and *path*. + """ + + grad_buffer: np.ndarray + + @abstractmethod + def __call__( + self, + indices: list[int], + mod_grads: dict[str, torch.Tensor], + ) -> None: ... + + def flush(self) -> None: + if isinstance(self.grad_buffer, np.memmap): + self.grad_buffer.flush() + + def teardown(self) -> None: + """ + Called at the end. 
+ + Override to perform custom cleanup such as: + - Saving results to disk + - Flushing buffers + - Freeing resources + """ + pass + + +class TokenBuilder(Builder): + """Creates and writes per-token gradients to disk. + + Parameters + ---------- + data : Dataset + The dataset being indexed (used only for length). + grad_sizes : dict[str, int] + Per-module gradient dimensions. + dtype : torch.dtype + Torch dtype for the gradients (converted to numpy internally). + path : Path + Root directory for the index artifacts. + """ + + def __init__( + self, + data: Dataset, + grad_sizes: dict[str, int], + dtype: torch.dtype, + *, + path: Path, + ): + self.grad_sizes = grad_sizes + self.num_items = len(data) + np_dtype = convert_dtype_to_np(dtype) + + self.num_token_grads = compute_num_token_grads(data) + self.grad_buffer, self.offsets = create_token_index( + path, + self.num_token_grads, + grad_sizes, + np_dtype, + ) + + def __call__( + self, + indices: list[int], + mod_grads: dict[str, torch.Tensor], + ): + """Write a batch of per-token gradients to the flat buffer. + + ``mod_grads`` values have shape ``[total_valid_in_batch, grad_dim_mod]`` + (already filtered to valid positions). Batch indices may be + non-contiguous, so each example's chunk is written individually. + """ + torch.cuda.synchronize() + + per_example_lengths = self.num_token_grads[indices] + + col_offset = 0 + for module_name in self.grad_sizes.keys(): + g_np = tensor_to_numpy(mod_grads[module_name]) + dim = g_np.shape[1] + row = 0 + for idx, sl in zip(indices, per_example_lengths): + buf_start = int(self.offsets[idx]) + buf_end = int(self.offsets[idx + 1]) + self.grad_buffer[buf_start:buf_end, col_offset : col_offset + dim] = ( + g_np[row : row + sl] + ) + row += sl + col_offset += dim + + def teardown(self): + self.flush() + + +class InMemorySequenceBuilder(Builder): + """Stores per-example gradients in memory. 
+ + Drop-in replacement for :class:`SequenceBuilder` that keeps + gradients in a plain numpy array instead of a memory-mapped + file. Supports optional gradient reduction via + *reduce_cfg*. + + Parameters + ---------- + data : Dataset + The dataset being indexed (used only for length). + grad_sizes : dict[str, int] + Per-module gradient dimensions. + dtype : torch.dtype + Torch dtype for the gradients. + reduce_cfg : ReduceConfig | None + When set, accumulate all gradients into a single + row (mean or sum) instead of storing per-example. + preprocess_cfg : PreprocessConfig | None + When set, apply preconditioning/normalization during reduce. + """ + + def __init__( + self, + data: Dataset, + grad_sizes: dict[str, int], + dtype: torch.dtype, + *, + reduce_cfg: ReduceConfig | None = None, + preprocess_cfg: PreprocessConfig | None = None, + ): + self.grad_sizes = grad_sizes + self.num_items = len(data) + self.reduce_cfg = reduce_cfg + self.preprocess_cfg = preprocess_cfg + total_grad_dim = sum(grad_sizes.values()) + + if reduce_cfg is not None: + np_dtype = np.float32 + num_grads = 1 + device = "cuda" if torch.cuda.is_available() else "cpu" + self.in_memory_grad_buffer = torch.zeros( + (1, total_grad_dim), + dtype=torch.float32, + device=device, + ) + self.h_inv = ( + get_trackstar_preconditioner( + self.preprocess_cfg.preconditioner_path, + power=-0.5 if self.preprocess_cfg.unit_normalize else -1, + device=torch.device(device), + ) + if self.preprocess_cfg is not None + else {} + ) + else: + np_dtype = convert_dtype_to_np(dtype) + num_grads = self.num_items + self.in_memory_grad_buffer = None + self.h_inv: dict[str, torch.Tensor] = {} + + self.grad_buffer = np.zeros( + (num_grads, total_grad_dim), + dtype=np_dtype, + ) + + def reduce( + self, + indices: list[int], + mod_grads: dict[str, torch.Tensor], + ): + """Accumulate batch gradients into the reduce buffer.""" + assert self.reduce_cfg is not None + assert self.in_memory_grad_buffer is not None + + # Precondition 
the gradients + mod_grads = precondition_grad(mod_grads, self.h_inv) + + unit_normalize = ( + self.preprocess_cfg.unit_normalize + if self.preprocess_cfg is not None + else False + ) + + all_grads = torch.cat([mod_grads[m] for m in self.grad_sizes.keys()], dim=-1) + _reduce(all_grads, self.in_memory_grad_buffer, unit_normalize) + + def __call__( + self, + indices: list[int], + mod_grads: dict[str, torch.Tensor], + ): + if self.reduce_cfg is not None: + self.reduce(indices, mod_grads) + return + + if torch.cuda.is_available(): + torch.cuda.synchronize() + offset = 0 + for module_name in self.grad_sizes.keys(): + dim = mod_grads[module_name].shape[1] + self.grad_buffer[ + indices, + offset : offset + dim, + ] = tensor_to_numpy(mod_grads[module_name]) + offset += dim + + def teardown(self): + if self.reduce_cfg is None: + return + + assert self.in_memory_grad_buffer is not None + + if torch.cuda.is_available(): + self.in_memory_grad_buffer = self.in_memory_grad_buffer.cuda() + + if dist.is_initialized(): + dist.reduce( + self.in_memory_grad_buffer, + dst=0, + op=dist.ReduceOp.SUM, + ) + + if self.reduce_cfg.method == "mean": + self.in_memory_grad_buffer /= self.num_items + + if self.reduce_cfg.normalize_reduced_grad: + device = self.in_memory_grad_buffer.device + self.in_memory_grad_buffer = normalize_flat_grad( + self.in_memory_grad_buffer, device + ) + + self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu() + + self.grad_buffer[:] = tensor_to_numpy(self.in_memory_grad_buffer).astype( + self.grad_buffer.dtype + ) + + +class InMemoryTokenBuilder(Builder): + """Stores per-token gradients in memory. + + Drop-in replacement for :class:`TokenBuilder` that keeps + gradients in a plain numpy array instead of a memory-mapped + file. + + Parameters + ---------- + data : Dataset + The dataset being indexed (used only for length and + label information). + grad_sizes : dict[str, int] + Per-module gradient dimensions. 
+ dtype : torch.dtype + Torch dtype for the gradients. + """ + + def __init__( + self, + data: Dataset, + grad_sizes: dict[str, int], + dtype: torch.dtype, + ): + self.grad_sizes = grad_sizes + self.num_items = len(data) + np_dtype = convert_dtype_to_np(dtype) + total_grad_dim = sum(grad_sizes.values()) + + self.num_token_grads = compute_num_token_grads(data) + self.offsets = np.zeros(len(self.num_token_grads) + 1, dtype=np.int64) + np.cumsum(self.num_token_grads, out=self.offsets[1:]) + total_tokens = int(self.offsets[-1]) + + self.grad_buffer = np.zeros((total_tokens, total_grad_dim), dtype=np_dtype) + + def __call__( + self, + indices: list[int], + mod_grads: dict[str, torch.Tensor], + ): + """Write a batch of per-token gradients. + + ``mod_grads`` values have shape + ``[total_valid_in_batch, grad_dim_mod]`` + (already filtered to valid positions). + """ + if torch.cuda.is_available(): + torch.cuda.synchronize() + per_example_lengths = self.num_token_grads[indices] + + col_offset = 0 + for module_name in self.grad_sizes.keys(): + g_np = tensor_to_numpy(mod_grads[module_name]) + dim = g_np.shape[1] + row = 0 + for idx, sl in zip(indices, per_example_lengths): + buf_start = int(self.offsets[idx]) + buf_end = int(self.offsets[idx + 1]) + self.grad_buffer[ + buf_start:buf_end, + col_offset : col_offset + dim, + ] = g_np[row : row + sl] + row += sl + col_offset += dim + + +class SequenceBuilder(Builder): + """Creates and writes gradients to disk, with optional distributed reduction. 
+ Scores are always saved as float32.""" + + num_items: int + + reduce_cfg: ReduceConfig | None + + def __init__( + self, + data: Dataset, + grad_sizes: dict[str, int], + dtype: torch.dtype, + *, + path: Path, + reduce_cfg: ReduceConfig | None = None, + preprocess_cfg: PreprocessConfig | None = None, + ): + self.grad_sizes = grad_sizes + self.num_items = len(data) + self.reduce_cfg = reduce_cfg + self.preprocess_cfg = preprocess_cfg + self.rank = dist.get_rank() if dist.is_initialized() else 0 + if reduce_cfg is not None: + num_grads = 1 + np_dtype = np.float32 + self.in_memory_grad_buffer = torch.zeros( + (num_grads, sum(self.grad_sizes.values())), + dtype=torch.float32, + device=f"cuda:{self.rank}", + ) + device = torch.device(f"cuda:{self.rank}") + self.h_inv = ( + get_trackstar_preconditioner( + self.preprocess_cfg.preconditioner_path, + power=-0.5 if self.preprocess_cfg.unit_normalize else -1, + device=torch.device(device), + ) + if self.preprocess_cfg is not None + else {} + ) + else: + num_grads = self.num_items + np_dtype = convert_dtype_to_np(dtype) + self.in_memory_grad_buffer = None + self.h_inv: dict[str, torch.Tensor] = {} + + self.grad_buffer = create_index( + path, + num_grads=num_grads, + grad_sizes=self.grad_sizes, + dtype=np_dtype, + with_structure=False, + ) + + def reduce(self, indices: list[int], mod_grads: dict[str, torch.Tensor]): + assert self.reduce_cfg is not None and self.in_memory_grad_buffer is not None + + # Precondition the gradients + mod_grads = precondition_grad(mod_grads, self.h_inv) + + unit_normalize = ( + self.preprocess_cfg.unit_normalize + if self.preprocess_cfg is not None + else False + ) + + all_grads = torch.cat([mod_grads[m] for m in self.grad_sizes.keys()], dim=-1) + _reduce(all_grads, self.in_memory_grad_buffer, unit_normalize) + + def __call__(self, indices: list[int], mod_grads: dict[str, torch.Tensor]): + torch.cuda.synchronize() + + if self.reduce_cfg is not None: + self.reduce(indices, mod_grads) + else: + # It 
turns out that it's very important for efficiency to write the + # gradients sequentially instead of first concatenating them, then + # writing to one vector + offset = 0 + for module_name in self.grad_sizes.keys(): + self.grad_buffer[ + indices, offset : offset + mod_grads[module_name].shape[1] + ] = tensor_to_numpy(mod_grads[module_name]) + offset += mod_grads[module_name].shape[1] + + def teardown(self): + self.flush() + + if self.reduce_cfg is None: + return + + assert self.in_memory_grad_buffer is not None + + self.in_memory_grad_buffer = self.in_memory_grad_buffer.cuda() + + if dist.is_initialized(): + dist.reduce(self.in_memory_grad_buffer, dst=0, op=dist.ReduceOp.SUM) + + if self.reduce_cfg.method == "mean": + self.in_memory_grad_buffer /= self.num_items + + # Unit normalize the reduced gradient + if self.reduce_cfg.normalize_reduced_grad: + device = self.in_memory_grad_buffer.device + self.in_memory_grad_buffer = normalize_flat_grad( + self.in_memory_grad_buffer, device + ) + + self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu() + + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + self.grad_buffer[:] = tensor_to_numpy(self.in_memory_grad_buffer).astype( + self.grad_buffer.dtype + ) + + self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu() + + +def create_builder( + data: Dataset, + grad_sizes: dict[str, int], + dtype: torch.dtype, + *, + attribute_tokens: bool = False, + path: Path | None = None, + reduce_cfg: ReduceConfig | None = None, + preprocess_cfg: PreprocessConfig | None = None, +) -> Builder: + """Create the appropriate :class:`Builder` subclass. 
+ + Dispatches based on *attribute_tokens* and *path*: + + * ``path`` given + ``attribute_tokens`` → :class:`TokenBuilder` + * ``path`` given → :class:`SequenceBuilder` + * no ``path`` + ``attribute_tokens`` → :class:`InMemoryTokenBuilder` + * no ``path`` → :class:`InMemorySequenceBuilder` + """ + if path is not None: + if attribute_tokens: + return TokenBuilder(data, grad_sizes, dtype, path=path) + return SequenceBuilder( + data, + grad_sizes, + dtype, + path=path, + reduce_cfg=reduce_cfg, + preprocess_cfg=preprocess_cfg, + ) + if attribute_tokens: + return InMemoryTokenBuilder(data, grad_sizes, dtype) + return InMemorySequenceBuilder( + data, + grad_sizes, + dtype, + reduce_cfg=reduce_cfg, + preprocess_cfg=preprocess_cfg, + ) diff --git a/bergson/collection.py b/bergson/collection.py index 52567496..8d3af1fc 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -3,7 +3,7 @@ from bergson.collector.collector import CollectorComputer from bergson.collector.gradient_collectors import GradientCollector -from bergson.config import AttentionConfig, IndexConfig, ReduceConfig +from bergson.config import AttentionConfig, IndexConfig, PreprocessConfig, ReduceConfig from bergson.gradients import GradientProcessor from bergson.score.scorer import Scorer @@ -19,6 +19,7 @@ def collect_gradients( attention_cfgs: dict[str, AttentionConfig] | None = None, scorer: Scorer | None = None, reduce_cfg: ReduceConfig | None = None, + preprocess_cfg: PreprocessConfig | None = None, ): """ Compute gradients using the hooks specified in the GradientCollector. 
@@ -31,6 +32,7 @@ def collect_gradients( data=data, scorer=scorer, reduce_cfg=reduce_cfg, + preprocess_cfg=preprocess_cfg, attention_cfgs=attention_cfgs or {}, filter_modules=cfg.filter_modules, ) diff --git a/bergson/collector/dist_preconditioners_gradient_collector.py b/bergson/collector/dist_preconditioners_gradient_collector.py index 6d57795c..aadd3fa4 100644 --- a/bergson/collector/dist_preconditioners_gradient_collector.py +++ b/bergson/collector/dist_preconditioners_gradient_collector.py @@ -8,9 +8,9 @@ from jaxtyping import Float from torch import Tensor +from bergson.builders import Builder, create_builder from bergson.collector.collector import HookCollectorBase from bergson.config import IndexConfig, ReduceConfig -from bergson.data import Builder, create_builder from bergson.gradients import ( AdafactorNormalizer, AdamNormalizer, @@ -275,10 +275,8 @@ def teardown(self): self.processor.save(self.cfg.partial_run_path) - # Flush and reduce builder if it exists if self.builder is not None: - self.builder.flush() - self.builder.dist_reduce() + self.builder.teardown() def exchange_preconditioner_gradients( diff --git a/bergson/collector/gradient_collectors.py b/bergson/collector/gradient_collectors.py index e3488d6d..32245815 100644 --- a/bergson/collector/gradient_collectors.py +++ b/bergson/collector/gradient_collectors.py @@ -9,9 +9,9 @@ from jaxtyping import Float from torch import Tensor +from bergson.builders import Builder, create_builder from bergson.collector.collector import HookCollectorBase -from bergson.config import IndexConfig, ReduceConfig -from bergson.data import Builder, create_builder +from bergson.config import IndexConfig, PreprocessConfig, ReduceConfig from bergson.gradients import ( AdafactorNormalizer, AdamNormalizer, @@ -46,6 +46,9 @@ class GradientCollector(HookCollectorBase): reduce_cfg: ReduceConfig | None = None """Configuration for in-run gradient reduction.""" + preprocess_cfg: PreprocessConfig | None = None + """Configuration 
for gradient preprocessing.""" + builder: Builder | None = None """Handles writing gradients to disk. Created in setup() if save_index is True.""" @@ -95,6 +98,7 @@ def setup(self) -> None: attribute_tokens=self.cfg.attribute_tokens, path=self.cfg.partial_run_path, reduce_cfg=self.reduce_cfg, + preprocess_cfg=self.preprocess_cfg, ) else: self.builder = None @@ -250,15 +254,13 @@ def teardown(self): self.rank, ) - # Flush and reduce builder if it exists - if self.builder is not None: - self.builder.flush() - self.builder.dist_reduce() + if self.builder: + self.builder.teardown() if self.rank == 0: - if self.reduce_cfg is not None: + if self.reduce_cfg: # Create a new dataset with one row for each reduced gradient - assert self.builder is not None + assert self.builder self.data = Dataset.from_list( [ {"query_index": i} @@ -294,12 +296,6 @@ class TraceCollector(HookCollectorBase): mod_grads: dict = field(default_factory=lambda: defaultdict(list)) """Accumulated grads per module. Maps module name to list of gradient tensors.""" - eps: float = 1e-6 - """Epsilon for numerical stability in preconditioning.""" - - precondition: bool = False - """Whether to apply preconditioning via autocorrelation Hessian approximation.""" - device: torch.device | str """Device to store collected gradients on.""" @@ -380,14 +376,6 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]): P = P.flatten(1).clamp_(self.lo, self.hi) - # Precondition the gradient using Cholesky solve - # TODO: Should damp here? 
- if self.precondition: - eigval, eigvec = self.processor.preconditioners_eigen[name] - eigval_inverse_sqrt = 1.0 / (eigval + self.eps).sqrt() - prec = eigvec * eigval_inverse_sqrt @ eigvec.mT - P = P.type_as(prec) @ prec # <- apply to P - # Store the gradient for later use self.mod_grads[name].append(P.to(self.device, self.dtype, non_blocking=True)) diff --git a/bergson/collector/in_memory_collector.py b/bergson/collector/in_memory_collector.py index cee70f01..b42b9bd9 100644 --- a/bergson/collector/in_memory_collector.py +++ b/bergson/collector/in_memory_collector.py @@ -10,9 +10,9 @@ from jaxtyping import Float from torch import Tensor, nn +from bergson.builders import Builder, create_builder from bergson.collector.collector import HookCollectorBase -from bergson.config import IndexConfig, ReduceConfig -from bergson.data import Builder, create_builder +from bergson.config import IndexConfig, PreprocessConfig, ReduceConfig from bergson.gradients import ( AdafactorNormalizer, AdamNormalizer, @@ -22,7 +22,7 @@ process_preconditioners, ) from bergson.score.scorer import Scorer -from bergson.utils.utils import assert_type, get_gradient_dtype +from bergson.utils.utils import assert_type, get_gradient_dtype, numpy_to_tensor @dataclass(kw_only=True) @@ -52,6 +52,9 @@ class InMemoryCollector(HookCollectorBase): reduce_cfg: ReduceConfig | None = None """Configuration for in-run gradient reduction.""" + preprocess_cfg: PreprocessConfig | None = None + """Configuration for gradient preprocessing.""" + builder: Builder | None = None """Handles writing gradients. 
Created in setup().""" @@ -109,6 +112,7 @@ def setup(self) -> None: self.save_dtype, attribute_tokens=self.cfg.attribute_tokens, reduce_cfg=self.reduce_cfg, + preprocess_cfg=self.preprocess_cfg, ) def teardown(self) -> None: @@ -127,16 +131,13 @@ def teardown(self) -> None: ) if self.builder is not None: - self.builder.dist_reduce() - self.builder.flush() + self.builder.teardown() # Populate self.gradients from builder buffer buf = self.builder.grad_buffer offset = 0 for name, dim in grad_sizes.items(): - self.gradients[name] = torch.from_numpy( - buf[:, offset : offset + dim].copy() - ) + self.gradients[name] = numpy_to_tensor(buf[:, offset : offset + dim]) offset += dim if self.scorer is not None: diff --git a/bergson/config.py b/bergson/config.py index cf0fdc22..4faddff5 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -155,7 +155,7 @@ class IndexConfig: """Type of normalizer to use for the gradients.""" skip_preconditioners: bool = False - """Whether to skip computing preconditioners for the gradients.""" + """Whether to skip estimating preconditioner statistics""" skip_index: bool = False """Whether to skip building the gradient index.""" @@ -265,6 +265,17 @@ class QueryConfig: its top results as rows with columns: query, result, result_index, score.""" +@dataclass +class PreprocessConfig: + """Config for gradient preprocessing, shared across build, reduce, and score.""" + + unit_normalize: bool = False + """Whether to unit normalize the gradients.""" + + preconditioner_path: str | None = None + """Path to a precomputed preconditioner.""" + + @dataclass class ScoreConfig: """Config for querying an index on the fly.""" @@ -272,34 +283,12 @@ class ScoreConfig: query_path: str = "" """Path to the existing query index.""" - score: Literal["mean", "nearest", "individual"] = "mean" + score: Literal["nearest", "individual"] = "individual" """Method for scoring the gradients with the query. 
- `mean`: compute each gradient's similarity to the mean - query gradient. `nearest`: compute each gradient's similarity to the most similar query gradient (the maximum score). `individual`: compute a separate score for each query gradient.""" - query_preconditioner_path: str | None = None - """Path to a precomputed preconditioner to be applied to - the query dataset gradients.""" - - index_preconditioner_path: str | None = None - """Path to a precomputed preconditioner to be applied to - the query dataset gradients. This does not affect the - ability to compute a new preconditioner during the query.""" - - mixing_coefficient: float = 0.99 - """Coefficient to weight the application of the query preconditioner - and the pre-computed index preconditioner. 0.0 means only use the - index preconditioner and 1.0 means only use the query preconditioner.""" - - modules: list[str] = field(default_factory=list) - """Modules to use for the query. If empty, all modules will be used.""" - - unit_normalize: bool = False - """Whether to unit normalize the gradients before computing the scores.""" - batch_size: int = 1024 """Batch size for processing the query dataset.""" @@ -307,16 +296,25 @@ class ScoreConfig: """Precision (dtype) to convert the query and index gradients to before computing the scores. If "auto", the model's gradient dtype is used.""" + modules: list[str] = field(default_factory=list) + """Modules to use for the query. If empty, all modules will be used.""" + @dataclass class ReduceConfig: - """Config for reducing the gradients.""" + """Config for reducing the gradients of a dataset into a standalone + aggregated gradient.""" method: Literal["mean", "sum"] = "mean" """Method for reducing the gradients.""" - unit_normalize: bool = False - """Whether to unit normalize the gradients before reducing them.""" + modules: list[str] = field(default_factory=list) + """Modules to use for the query. 
If empty, all modules will be used.""" + + normalize_reduced_grad: bool = False + """Whether to unit normalize the reduced query gradient. This has + no effect on future relative score rankings but does affect score + magnitudes.""" @dataclass @@ -379,3 +377,6 @@ class TrackstarConfig: query: DataConfig = field(default_factory=DataConfig) """Query dataset specification.""" + + mixing_coefficient: float = 0.99 + """Weight for mixing query vs index preconditioner (1.0 = query only).""" diff --git a/bergson/data.py b/bergson/data.py index 7b9843d4..689e36ca 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -2,7 +2,6 @@ import math import os import random -from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Sequence @@ -22,12 +21,10 @@ from numpy.lib.recfunctions import structured_to_unstructured from numpy.typing import DTypeLike -from .config import DataConfig, ReduceConfig +from .config import DataConfig from .utils.utils import ( assert_type, - convert_dtype_to_np, simple_parse_args_string, - tensor_to_numpy, ) @@ -183,301 +180,6 @@ def __getitem__(self, i: int) -> np.ndarray: return np.asarray(self.mmap[self._offsets[i] : self._offsets[i + 1]]) -class Builder(ABC): - """Interface for gradient index writers. - - Use :func:`create_builder` to construct the appropriate concrete - subclass based on *attribute_tokens* and *path*. - """ - - grad_buffer: np.ndarray - - @abstractmethod - def __call__( - self, - indices: list[int], - mod_grads: dict[str, torch.Tensor], - ) -> None: ... - - def flush(self) -> None: - if isinstance(self.grad_buffer, np.memmap): - self.grad_buffer.flush() - - def dist_reduce(self) -> None: - pass - - -class TokenBuilder(Builder): - """Creates and writes per-token gradients to disk. - - Parameters - ---------- - data : Dataset - The dataset being indexed (used only for length). - grad_sizes : dict[str, int] - Per-module gradient dimensions. 
- dtype : torch.dtype - Torch dtype for the gradients (converted to numpy internally). - path : Path - Root directory for the index artifacts. - """ - - def __init__( - self, - data: Dataset, - grad_sizes: dict[str, int], - dtype: torch.dtype, - *, - attribute_tokens: bool = False, - path: Path | None = None, - reduce_cfg: ReduceConfig | None = None, - ): - assert path is not None - self.grad_sizes = grad_sizes - self.num_items = len(data) - np_dtype = convert_dtype_to_np(dtype) - - self.num_token_grads = compute_num_token_grads(data) - self.grad_buffer, self.offsets = create_token_index( - path, - self.num_token_grads, - grad_sizes, - np_dtype, - ) - - def __call__( - self, - indices: list[int], - mod_grads: dict[str, torch.Tensor], - ): - """Write a batch of per-token gradients to the flat buffer. - - ``mod_grads`` values have shape ``[total_valid_in_batch, grad_dim_mod]`` - (already filtered to valid positions). Batch indices may be - non-contiguous, so each example's chunk is written individually. - """ - torch.cuda.synchronize() - - per_example_lengths = self.num_token_grads[indices] - - col_offset = 0 - for module_name in self.grad_sizes.keys(): - g_np = tensor_to_numpy(mod_grads[module_name]) - dim = g_np.shape[1] - row = 0 - for idx, sl in zip(indices, per_example_lengths): - buf_start = int(self.offsets[idx]) - buf_end = int(self.offsets[idx + 1]) - self.grad_buffer[buf_start:buf_end, col_offset : col_offset + dim] = ( - g_np[row : row + sl] - ) - row += sl - col_offset += dim - - -class InMemorySequenceBuilder(Builder): - """Stores per-example gradients in memory. - - Drop-in replacement for :class:`SequenceBuilder` that keeps - gradients in a plain numpy array instead of a memory-mapped - file. Supports optional gradient reduction via - *reduce_cfg*. - - Parameters - ---------- - data : Dataset - The dataset being indexed (used only for length). - grad_sizes : dict[str, int] - Per-module gradient dimensions. 
- dtype : torch.dtype - Torch dtype for the gradients. - reduce_cfg : ReduceConfig | None - When set, accumulate all gradients into a single - row (mean or sum) instead of storing per-example. - """ - - def __init__( - self, - data: Dataset, - grad_sizes: dict[str, int], - dtype: torch.dtype, - *, - attribute_tokens: bool = False, - path: Path | None = None, - reduce_cfg: ReduceConfig | None = None, - ): - self.grad_sizes = grad_sizes - self.num_items = len(data) - self.reduce_cfg = reduce_cfg - self.eps = torch.finfo(torch.float32).eps - total_grad_dim = sum(grad_sizes.values()) - - if reduce_cfg is not None: - np_dtype = np.float32 - num_grads = 1 - device = "cuda" if torch.cuda.is_available() else "cpu" - self.in_memory_grad_buffer = torch.zeros( - (1, total_grad_dim), - dtype=torch.float32, - device=device, - ) - else: - np_dtype = convert_dtype_to_np(dtype) - num_grads = self.num_items - self.in_memory_grad_buffer = None - - self.grad_buffer = np.zeros( - (num_grads, total_grad_dim), - dtype=np_dtype, - ) - - def reduce( - self, - indices: list[int], - mod_grads: dict[str, torch.Tensor], - ): - """Accumulate batch gradients into the reduce buffer.""" - assert self.reduce_cfg is not None - assert self.in_memory_grad_buffer is not None - device = next(iter(mod_grads.values())).device - - if self.reduce_cfg.unit_normalize: - ssqs = torch.zeros(len(indices), device=device) - for mod_grad in mod_grads.values(): - ssqs += mod_grad.pow(2).sum(dim=-1) - norms = ssqs.sqrt() - else: - norms = torch.ones(len(indices), device=device) - - offset = 0 - for module_name in self.grad_sizes.keys(): - grads = mod_grads[module_name] - if self.reduce_cfg.unit_normalize: - grads = grads / (norms.unsqueeze(1) + self.eps) - grads = grads.sum(dim=0).to(torch.float32) - self.in_memory_grad_buffer[ - 0, - offset : offset + grads.shape[0], - ] += grads - offset += grads.shape[0] - - def __call__( - self, - indices: list[int], - mod_grads: dict[str, torch.Tensor], - ): - if 
self.reduce_cfg is not None: - self.reduce(indices, mod_grads) - return - - if torch.cuda.is_available(): - torch.cuda.synchronize() - offset = 0 - for module_name in self.grad_sizes.keys(): - dim = mod_grads[module_name].shape[1] - self.grad_buffer[ - indices, - offset : offset + dim, - ] = tensor_to_numpy(mod_grads[module_name]) - offset += dim - - def dist_reduce(self): - if self.reduce_cfg is None: - return - - assert self.in_memory_grad_buffer is not None - - if torch.cuda.is_available(): - self.in_memory_grad_buffer = self.in_memory_grad_buffer.cuda() - - if dist.is_initialized(): - dist.reduce( - self.in_memory_grad_buffer, - dst=0, - op=dist.ReduceOp.SUM, - ) - - if self.reduce_cfg.method == "mean": - self.in_memory_grad_buffer /= self.num_items - - self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu() - - self.grad_buffer[:] = tensor_to_numpy(self.in_memory_grad_buffer).astype( - self.grad_buffer.dtype - ) - - -class InMemoryTokenBuilder(Builder): - """Stores per-token gradients in memory. - - Drop-in replacement for :class:`TokenBuilder` that keeps - gradients in a plain numpy array instead of a memory-mapped - file. - - Parameters - ---------- - data : Dataset - The dataset being indexed (used only for length and - label information). - grad_sizes : dict[str, int] - Per-module gradient dimensions. - dtype : torch.dtype - Torch dtype for the gradients. 
- """ - - def __init__( - self, - data: Dataset, - grad_sizes: dict[str, int], - dtype: torch.dtype, - *, - attribute_tokens: bool = False, - path: Path | None = None, - reduce_cfg: ReduceConfig | None = None, - ): - self.grad_sizes = grad_sizes - self.num_items = len(data) - np_dtype = convert_dtype_to_np(dtype) - total_grad_dim = sum(grad_sizes.values()) - - self.num_token_grads = compute_num_token_grads(data) - self.offsets = np.zeros(len(self.num_token_grads) + 1, dtype=np.int64) - np.cumsum(self.num_token_grads, out=self.offsets[1:]) - total_tokens = int(self.offsets[-1]) - - self.grad_buffer = np.zeros((total_tokens, total_grad_dim), dtype=np_dtype) - - def __call__( - self, - indices: list[int], - mod_grads: dict[str, torch.Tensor], - ): - """Write a batch of per-token gradients. - - ``mod_grads`` values have shape - ``[total_valid_in_batch, grad_dim_mod]`` - (already filtered to valid positions). - """ - if torch.cuda.is_available(): - torch.cuda.synchronize() - per_example_lengths = self.num_token_grads[indices] - - col_offset = 0 - for module_name in self.grad_sizes.keys(): - g_np = tensor_to_numpy(mod_grads[module_name]) - dim = g_np.shape[1] - row = 0 - for idx, sl in zip(indices, per_example_lengths): - buf_start = int(self.offsets[idx]) - buf_end = int(self.offsets[idx + 1]) - self.grad_buffer[ - buf_start:buf_end, - col_offset : col_offset + dim, - ] = g_np[row : row + sl] - row += sl - col_offset += dim - - def ceildiv(a: int, b: int) -> int: """Ceiling division of two integers.""" return -(-a // b) # Equivalent to math.ceil(a / b) but faster for integers @@ -834,148 +536,6 @@ def load_scores( return Scores(mmap, info) -class SequenceBuilder(Builder): - """Creates and writes gradients to disk, with optional distributed reduction. 
- Scores are always saved as float32.""" - - num_items: int - - reduce_cfg: ReduceConfig | None - - def __init__( - self, - data: Dataset, - grad_sizes: dict[str, int], - dtype: torch.dtype, - *, - attribute_tokens: bool = False, - path: Path | None = None, - reduce_cfg: ReduceConfig | None = None, - ): - assert path is not None - self.grad_sizes = grad_sizes - self.num_items = len(data) - self.reduce_cfg = reduce_cfg - self.eps = torch.finfo(torch.float32).eps - self.rank = dist.get_rank() if dist.is_initialized() else 0 - if reduce_cfg is not None: - num_grads = 1 - np_dtype = np.float32 - self.in_memory_grad_buffer = torch.zeros( - (num_grads, sum(self.grad_sizes.values())), - dtype=torch.float32, - device=f"cuda:{self.rank}", - ) - else: - num_grads = self.num_items - np_dtype = convert_dtype_to_np(dtype) - self.in_memory_grad_buffer = None - - self.grad_buffer = create_index( - path, - num_grads=num_grads, - grad_sizes=self.grad_sizes, - dtype=np_dtype, - with_structure=False, - ) - - def reduce(self, indices: list[int], mod_grads: dict[str, torch.Tensor]): - assert self.reduce_cfg is not None and self.in_memory_grad_buffer is not None - device = next(iter(mod_grads.values())).device - - if self.reduce_cfg.unit_normalize: - ssqs = torch.zeros(len(indices), device=device) - for mod_grad in mod_grads.values(): - ssqs += mod_grad.pow(2).sum(dim=-1) - norms = ssqs.sqrt() - else: - norms = torch.ones(len(indices), device=device) - - offset = 0 - for module_name in self.grad_sizes.keys(): - grads = mod_grads[module_name] - if self.reduce_cfg.unit_normalize: - grads = grads / (norms.unsqueeze(1) + self.eps) - - grads = grads.sum(dim=0).to(torch.float32) - - self.in_memory_grad_buffer[0, offset : offset + grads.shape[0]] += grads - offset += grads.shape[0] - - def __call__(self, indices: list[int], mod_grads: dict[str, torch.Tensor]): - torch.cuda.synchronize() - - if self.reduce_cfg is not None: - self.reduce(indices, mod_grads) - else: - # It turns out that it's 
very important for efficiency to write the - # gradients sequentially instead of first concatenating them, then - # writing to one vector - offset = 0 - for module_name in self.grad_sizes.keys(): - self.grad_buffer[ - indices, offset : offset + mod_grads[module_name].shape[1] - ] = tensor_to_numpy(mod_grads[module_name]) - offset += mod_grads[module_name].shape[1] - - def dist_reduce(self): - if self.reduce_cfg is None: - return - - assert self.in_memory_grad_buffer is not None - - self.in_memory_grad_buffer = self.in_memory_grad_buffer.cuda() - - if dist.is_initialized(): - dist.reduce(self.in_memory_grad_buffer, dst=0, op=dist.ReduceOp.SUM) - - if self.reduce_cfg.method == "mean": - self.in_memory_grad_buffer /= self.num_items - - self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu() - - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - self.grad_buffer[:] = tensor_to_numpy(self.in_memory_grad_buffer).astype( - self.grad_buffer.dtype - ) - - self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu() - - -def create_builder( - data: Dataset, - grad_sizes: dict[str, int], - dtype: torch.dtype, - *, - attribute_tokens: bool = False, - path: Path | None = None, - reduce_cfg: ReduceConfig | None = None, -) -> Builder: - """Create the appropriate :class:`Builder` subclass. 
- - Dispatches based on *attribute_tokens* and *path*: - - * ``path`` given + ``attribute_tokens`` → :class:`TokenBuilder` - * ``path`` given → :class:`SequenceBuilder` - * no ``path`` + ``attribute_tokens`` → :class:`InMemoryTokenBuilder` - * no ``path`` → :class:`InMemorySequenceBuilder` - """ - if path is not None: - cls = TokenBuilder if attribute_tokens else SequenceBuilder - else: - cls = InMemoryTokenBuilder if attribute_tokens else InMemorySequenceBuilder - - return cls( - data, - grad_sizes, - dtype, - attribute_tokens=attribute_tokens, - path=path, - reduce_cfg=reduce_cfg, - ) - - def pad_and_tensor( sequences: list[list[int]], labels: list[list[int]] | None = None, diff --git a/bergson/process_grads.py b/bergson/process_grads.py new file mode 100644 index 00000000..27d8b2dd --- /dev/null +++ b/bergson/process_grads.py @@ -0,0 +1,241 @@ +import json +import warnings +from pathlib import Path +from typing import Literal + +import torch + +from bergson.gradients import GradientProcessor +from bergson.utils.math import damped_psd_power + + +def normalize_grad( + grad_dict: dict[str, torch.Tensor], + unit_normalize: bool, + device: torch.device, +) -> dict[str, torch.Tensor]: + """Preprocess a single gradient. 
Optionally unit-normalizes + across all columns, moves to device.""" + final_dtype = next(iter(grad_dict.values())).dtype + grads = { + name: g.to(device=device, dtype=torch.float32) for name, g in grad_dict.items() + } + + if unit_normalize: + norm = torch.sqrt(torch.stack([g.pow(2).sum() for g in grads.values()]).sum()) + if norm > 0: + grads = {k: v / norm for k, v in grads.items()} + else: + warnings.warn("Gradient norm is zero, skipping normalization") + + return {k: v.to(final_dtype) for k, v in grads.items()} + + +def normalize_flat_grad( + grad: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + """Unit-normalize a single gradient tensor.""" + final_dtype = grad.dtype + grad = grad.to(device=device, dtype=torch.float32) + norm = grad.norm() + if norm > 0: + grad /= norm + else: + warnings.warn("Gradient norm is zero, skipping normalization") + return grad.to(final_dtype) + + +def mix_preconditioners( + query_path: str | Path, + index_path: str | Path, + output_path: str | Path, + mixing_coefficient: float = 0.99, +) -> Path: + """Mix query and index preconditioners and save the result to disk. + + Computes ``H_mixed = coeff * H_query + (1 - coeff) * H_index`` for + every module's raw H matrix, then persists a new + :class:`~bergson.gradients.GradientProcessor` at *output_path*. + + A ``mix_config.json`` file is also written alongside for provenance. + + Parameters + ---------- + query_path : str | Path + Directory containing the query GradientProcessor. + index_path : str | Path + Directory containing the index GradientProcessor. + output_path : str | Path + Directory where the mixed GradientProcessor will be saved. + mixing_coefficient : float + Weight for the query preconditioner (1.0 = query only). + + Returns + ------- + Path + The *output_path* as a :class:`pathlib.Path`. 
+ """ + query_path = Path(query_path) + index_path = Path(index_path) + output_path = Path(output_path) + + q_proc = GradientProcessor.load(query_path) + i_proc = GradientProcessor.load(index_path) + + mixed_preconditioners = { + k: q_proc.preconditioners[k] * mixing_coefficient + + i_proc.preconditioners[k] * (1 - mixing_coefficient) + for k in q_proc.preconditioners + } + + # Build a new processor with the mixed preconditioners + mixed_proc = GradientProcessor( + normalizers=q_proc.normalizers, + preconditioners=mixed_preconditioners, + preconditioners_eigen={}, + projection_dim=q_proc.projection_dim, + reshape_to_square=q_proc.reshape_to_square, + projection_type=q_proc.projection_type, + include_bias=q_proc.include_bias, + ) + mixed_proc.save(output_path) + + # Save provenance metadata + mix_config = { + "query_path": str(query_path), + "index_path": str(index_path), + "mixing_coefficient": mixing_coefficient, + } + with (output_path / "mix_config.json").open("w") as f: + json.dump(mix_config, f, indent=2) + + return output_path + + +def get_trackstar_preconditioner( + preconditioner_path: str | None, + device: torch.device, + power: float = -0.5, + return_dtype: torch.dtype | None = None, +) -> dict[str, torch.Tensor]: + """Compute preconditioner matrices from a saved processor file. + + Parameters + ---------- + preconditioner_path : str | None + Directory containing the saved GradientProcessor. + device : torch.device + Device to load the preconditioner onto. + power : float + Matrix power to apply to each H matrix. + + * ``-0.5`` — H^(-1/2), used for split (two-sided) preconditioning + where both query and index gradients are multiplied by H^(-1/2). + * ``-1`` — H^(-1), used for one-sided preconditioning where only + the query gradients are preconditioned. 
+ """ + if preconditioner_path is None: + return {} + + preconditioners = GradientProcessor.load( + Path(preconditioner_path), + map_location=device, + ).preconditioners + + final_dtype = return_dtype or next(iter(preconditioners.values())).dtype + + return { + name: damped_psd_power(H.to(device=device), power=power).to(final_dtype) + for name, H in preconditioners.items() + } + + +def precondition_flat_grads( + grads: torch.Tensor, + h_inv: dict[str, torch.Tensor], + ordered_modules: list[str], + batch_size: int = 8192, +) -> torch.Tensor: + """Precondition flat (concatenated) gradients in-place. + + Uses column offsets to avoid duplicating the full tensor and processes + rows in batches to bound peak memory. Each small ``[batch, d]`` slice is + moved to ``h_inv``'s device for the matmul and written back. + """ + if not h_inv: + return grads + + for start in range(0, grads.shape[0], batch_size): + end = min(start + batch_size, grads.shape[0]) + col = 0 + for name in ordered_modules: + h = h_inv[name] + d = h.shape[0] + grads[start:end, col : col + d] = ( + grads[start:end, col : col + d].to(device=h.device, dtype=h.dtype) @ h + ).to(device=grads.device, dtype=grads.dtype) + col += d + + return grads + + +def precondition_grad( + grad: dict[str, torch.Tensor], + h_inv: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + """Precondition a single example's gradients.""" + if not h_inv: + return grad + + return { + name: grad[name].to(device=h_inv[name].device, dtype=h_inv[name].dtype) + @ h_inv[name] + for name in grad.keys() + } + + +def preprocess_grads( + grad_dict: dict[str, torch.Tensor], + grad_column_names: list[str], + unit_normalize: bool, + device: torch.device, + aggregate_grads: Literal["mean", "sum", "none"] = "none", + normalize_aggregated_grad: bool = False, +) -> dict[str, torch.Tensor]: + """Preprocess the gradients. Returns a dictionary of preprocessed gradients + with shape [N, grad_dim] or [1, grad_dim]. 
Preprocessing includes some + combination of per-item unit normalization, aggregation, aggregated + gradient normalization, and dtype conversion.""" + + # Short-circuit if possible + if aggregate_grads == "none" and not unit_normalize: + return {name: grad_dict[name].to(device=device) for name in grad_column_names} + + grads = { + name: grad_dict[name].to(device=device, dtype=torch.float32) + for name in grad_column_names + } + + # Per-item unit normalization + if unit_normalize: + norms = torch.cat(list(grads.values()), dim=1).norm(dim=1, keepdim=True) + grads = {k: v / norms for k, v in grads.items()} + + # Aggregate across items + if aggregate_grads == "mean": + grads = {name: grads[name].mean(0, keepdim=True) for name in grad_column_names} + elif aggregate_grads == "sum": + grads = {name: grads[name].sum(0, keepdim=True) for name in grad_column_names} + elif aggregate_grads != "none": + raise ValueError(f"Invalid aggregate_grads: {aggregate_grads}") + + # Normalize the aggregated gradient + if normalize_aggregated_grad: + grad_norm = torch.cat( + [grads[name].flatten() for name in grad_column_names], dim=0 + ).norm() + for name in grad_column_names: + grads[name] /= grad_norm + + return grads diff --git a/bergson/query/attributor.py b/bergson/query/attributor.py index 041f530b..2d59a857 100644 --- a/bergson/query/attributor.py +++ b/bergson/query/attributor.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from pathlib import Path -from typing import Generator +from typing import Generator, Literal import torch from torch import Tensor, nn @@ -9,6 +9,7 @@ from bergson.data import load_gradients from bergson.gradients import GradientProcessor from bergson.query.faiss_index import FaissConfig, FaissIndex +from bergson.utils.math import damped_psd_power from bergson.utils.utils import numpy_to_tensor @@ -38,34 +39,62 @@ def scores(self) -> Tensor: class Attributor: + + precondition: Literal["one-sided", "two-sided", "none"] + def __init__( self, index_path: 
str | Path,
         device: str = "cpu",
         dtype: torch.dtype = torch.float32,
         unit_norm: bool = False,
+        precondition: bool = False,
         faiss_cfg: FaissConfig | None = None,
     ):
         self.device = device
         self.dtype = dtype
         self.unit_norm = unit_norm
+
+        if precondition and unit_norm:
+            self.precondition = "two-sided"
+        elif precondition:
+            self.precondition = "one-sided"
+        else:
+            self.precondition = "none"
+
         self.faiss_index = None
         index_path = Path(index_path)
 
         # Load the gradient processor
         self.processor = GradientProcessor.load(index_path, map_location=device)
 
+        # Precompute preconditioners
+        self.h_inv: dict[str, Tensor] = {}
+        for name, H in self.processor.preconditioners.items():
+            if self.precondition == "two-sided":
+                # Two-sided: precompute H^(-1/2) for two-sided application
+                self.h_inv[name] = damped_psd_power(H, power=-0.5).to(device)
+            elif self.precondition == "one-sided":
+                # One-sided: precompute H^(-1) for query-side application in search()
+                self.h_inv[name] = damped_psd_power(H, power=-1.0).to(device)
+
         # Load the gradients into a FAISS index
         if faiss_cfg:
             faiss_index_name = (
                 f"faiss_{faiss_cfg.index_factory.replace(',', '_')}"
                 f"{'_cosine' if unit_norm else ''}"
+                f"{'_precondition' if precondition else ''}"
             )
             faiss_path = index_path / faiss_index_name
 
             if not (faiss_path / "config.json").exists():
                 FaissIndex.create_index(
-                    index_path, faiss_path, faiss_cfg, device, unit_norm
+                    index_path,
+                    faiss_path,
+                    faiss_cfg,
+                    device,
+                    unit_norm,
+                    self.h_inv,
                 )
 
             self.faiss_index = FaissIndex(
@@ -88,6 +117,17 @@ def __init__(
         self.ordered_modules = mmap.dtype.names
 
         if unit_norm:
+            if precondition:
+                # Split: apply H^(-1/2) to index grads before normalization,
+                # for TrackStar
+                for name in self.grads:
+                    if name in self.processor.preconditioners:
+                        h_inv = damped_psd_power(
+                            self.processor.preconditioners[name], power=-0.5
+                        )
+                        self.grads[name] = self.grads[name].float() @ h_inv.to(device)
+                        self.grads[name] = self.grads[name].to(dtype=dtype)
+
             norm = torch.cat(
                 [self.grads[name] for name in 
self.ordered_modules], dim=1 ).norm(dim=1, keepdim=True) @@ -131,6 +171,14 @@ def search( for name in self.ordered_modules } + # One- or two-sided preconditioning: apply H^(-1) or H^(-0.5) to query + if self.h_inv: + for name in q: + if name in self.h_inv: + q[name] = q[name].float() @ self.h_inv[name] + q[name] = q[name].to(self.dtype) + + # Preconditioning is applied inside TraceCollector if self.unit_norm: norm = torch.cat(list(q.values()), dim=1).norm(dim=1, keepdim=True) @@ -172,7 +220,6 @@ def trace( module: nn.Module, k: int | None, *, - precondition: bool = False, modules: set[str] | None = None, reverse: bool = False, ) -> Generator[TraceResult, None, None]: @@ -183,7 +230,6 @@ def trace( Args: module: The module to trace. k: The number of nearest examples to return. - precondition: Whether to apply preconditioning. modules: The modules to trace. If None, all modules will be traced. reverse: If True, return the lowest influence examples instead of highest. """ @@ -193,7 +239,6 @@ def trace( collector = TraceCollector( model=module, processor=self.processor, - precondition=precondition, target_modules=modules, device=self.device, dtype=self.dtype, diff --git a/bergson/query/faiss_index.py b/bergson/query/faiss_index.py index a63798f1..92065b1d 100644 --- a/bergson/query/faiss_index.py +++ b/bergson/query/faiss_index.py @@ -12,6 +12,7 @@ from tqdm import tqdm from bergson.config import FaissConfig +from bergson.process_grads import precondition_flat_grads if TYPE_CHECKING: import faiss # noqa: F401 # pyright: ignore[reportMissingImports] @@ -205,6 +206,7 @@ def create_index( faiss_cfg: FaissConfig, device: str, unit_norm: bool, + preconditioners: dict[str, torch.Tensor], ): faiss = _require_faiss() @@ -258,6 +260,9 @@ def build_shard_from_buffer( print(f"Building shard {shard_idx}...") grads_chunk = np.concatenate(buffer_parts, axis=0) + grads_chunk = precondition_flat_grads( + torch.from_numpy(grads_chunk), preconditioners, ordered_modules + ).numpy() 
buffer_parts.clear() index = faiss.index_factory( diff --git a/bergson/reduce.py b/bergson/reduce.py index 77b0f180..a8229172 100644 --- a/bergson/reduce.py +++ b/bergson/reduce.py @@ -12,7 +12,7 @@ from bergson.collection import collect_gradients from bergson.collector.gradient_collectors import GradientCollector -from bergson.config import IndexConfig, ReduceConfig +from bergson.config import IndexConfig, PreprocessConfig, ReduceConfig from bergson.data import allocate_batches from bergson.utils.utils import assert_type from bergson.utils.worker_utils import setup_model_and_peft @@ -27,6 +27,7 @@ def reduce_worker( world_size: int, index_cfg: IndexConfig, reduce_cfg: ReduceConfig, + preprocess_cfg: PreprocessConfig, ds: Dataset | IterableDataset, ): """ @@ -78,6 +79,7 @@ def reduce_worker( "target_modules": target_modules, "attention_cfgs": attention_cfgs, "reduce_cfg": reduce_cfg, + "preprocess_cfg": preprocess_cfg, } if isinstance(ds, Dataset): @@ -145,7 +147,11 @@ def flush(kwargs): json.dump(metadata, f, indent=2) -def reduce(index_cfg: IndexConfig, reduce_cfg: ReduceConfig): +def reduce( + index_cfg: IndexConfig, + reduce_cfg: ReduceConfig, + preprocess_cfg: PreprocessConfig, +): """ Reduce a dataset to a single aggregated gradient vector. @@ -155,7 +161,9 @@ def reduce(index_cfg: IndexConfig, reduce_cfg: ReduceConfig): Specifies the run path, dataset, model, tokenizer, PEFT adapters, and many other gradient collection settings. reduce_cfg : ReduceConfig - Specifies aggregation strategy (mean/sum, unit normalization). + Specifies aggregation strategy (mean/sum). + preprocess_cfg : PreprocessConfig + Preprocessing configuration for gradient normalization/preconditioning. 
""" index_cfg.partial_run_path.mkdir(parents=True, exist_ok=True) with (index_cfg.partial_run_path / "index_config.json").open("w") as f: @@ -164,7 +172,10 @@ def reduce(index_cfg: IndexConfig, reduce_cfg: ReduceConfig): ds = setup_data_pipeline(index_cfg) launch_distributed_run( - "reduce", reduce_worker, [index_cfg, reduce_cfg, ds], index_cfg.distributed + "reduce", + reduce_worker, + [index_cfg, reduce_cfg, preprocess_cfg, ds], + index_cfg.distributed, ) if index_cfg.distributed.rank == 0: diff --git a/bergson/score/score.py b/bergson/score/score.py index 624030c1..e3dfbe95 100644 --- a/bergson/score/score.py +++ b/bergson/score/score.py @@ -4,7 +4,6 @@ from dataclasses import asdict from datetime import timedelta from pathlib import Path -from typing import Literal import numpy as np import torch @@ -13,19 +12,18 @@ from tqdm.auto import tqdm from bergson.collection import collect_gradients -from bergson.config import IndexConfig, ScoreConfig +from bergson.config import IndexConfig, PreprocessConfig, ScoreConfig from bergson.data import ( allocate_batches, load_gradients, ) from bergson.distributed import launch_distributed_run -from bergson.gradients import GradientProcessor +from bergson.process_grads import preprocess_grads from bergson.score.score_writer import ( MemmapSequenceScoreWriter, MemmapTokenScoreWriter, ) from bergson.score.scorer import Scorer -from bergson.utils.math import compute_damped_inverse from bergson.utils.utils import ( assert_type, convert_precision_to_torch, @@ -43,6 +41,7 @@ def create_scorer( data: Dataset, query_grads: dict[str, torch.Tensor], score_cfg: ScoreConfig, + preprocess_cfg: PreprocessConfig, device: torch.device, dtype: torch.dtype, *, @@ -65,108 +64,13 @@ def create_scorer( writer=writer, device=device, dtype=dtype, - unit_normalize=score_cfg.unit_normalize, - score_mode="nearest" if score_cfg.score == "nearest" else "inner_product", + unit_normalize=preprocess_cfg.unit_normalize, + score_mode=score_cfg.score, 
attribute_tokens=attribute_tokens, + preconditioner_path=preprocess_cfg.preconditioner_path, ) -def preprocess_grads( - grad_dict: dict[str, torch.Tensor], - grad_column_names: list[str], - unit_normalize: bool, - device: torch.device, - accumulate_grads: Literal["mean", "sum", "none"] = "none", - normalize_accumulated_grad: bool = False, -) -> dict[str, torch.Tensor]: - """Preprocess the gradients. Returns a dictionary of preprocessed gradients - with shape [N, grad_dim] or [1, grad_dim]. Preprocessing includes some - combination of per-item unit normalization, accumulation, accumulated - gradient normalization, and dtype conversion.""" - - # Short-circuit if possible - if accumulate_grads == "none" and not unit_normalize: - return {name: grad_dict[name].to(device=device) for name in grad_column_names} - - grads = { - name: grad_dict[name].to(device=device, dtype=torch.float32) - for name in grad_column_names - } - - # Per-item unit normalization - if unit_normalize: - norms = torch.cat(list(grads.values()), dim=1).norm(dim=1, keepdim=True) - grads = {k: v / norms for k, v in grads.items()} - - # Accumulate across items - if accumulate_grads == "mean": - grads = {name: grads[name].mean(0, keepdim=True) for name in grad_column_names} - elif accumulate_grads == "sum": - grads = {name: grads[name].sum(0, keepdim=True) for name in grad_column_names} - elif accumulate_grads != "none": - raise ValueError(f"Invalid accumulate_grads: {accumulate_grads}") - - # Normalize the accumulated gradient - if normalize_accumulated_grad: - grad_norm = torch.cat( - [grads[name].flatten() for name in grad_column_names], dim=0 - ).norm() - for name in grad_column_names: - grads[name] /= grad_norm - - return grads - - -def precondition_grads( - grads: dict[str, torch.Tensor], - score_cfg: ScoreConfig, - target_modules: list[str], - device: torch.device, -) -> dict[str, torch.Tensor]: - """Precondition query gradients with the query and/or index preconditioners.""" - use_q = 
score_cfg.query_preconditioner_path is not None - use_i = score_cfg.index_preconditioner_path is not None - - if use_q or use_i: - q, i = {}, {} - if use_q: - assert score_cfg.query_preconditioner_path is not None - q = GradientProcessor.load( - Path(score_cfg.query_preconditioner_path), - map_location=device, - ).preconditioners - if use_i: - assert score_cfg.index_preconditioner_path is not None - i = GradientProcessor.load( - Path(score_cfg.index_preconditioner_path), map_location=device - ).preconditioners - - mixed_preconditioner = ( - { - k: q[k] * score_cfg.mixing_coefficient - + i[k] * (1 - score_cfg.mixing_coefficient) - for k in q - } - if (q and i) - else (q or i) - ) - - # Compute H^(-1) via eigendecomposition and apply to query gradients - h_inv = { - name: compute_damped_inverse(H.to(device=device)) - for name, H in mixed_preconditioner.items() - } - - grads = { - name: ( - grads[name].to(device=device, dtype=h_inv[name].dtype) @ h_inv[name] - ).cpu() - for name in target_modules - } - - return {name: grads[name] for name in score_cfg.modules} - - def get_query_grads(score_cfg: ScoreConfig) -> dict[str, torch.Tensor]: """ Load query gradients from the mmap index and return as a dict of tensors. @@ -213,6 +117,7 @@ def score_worker( world_size: int, index_cfg: IndexConfig, score_cfg: ScoreConfig, + preprocess_cfg: PreprocessConfig, ds: Dataset | IterableDataset, query_grads: dict[str, torch.Tensor], ): @@ -232,6 +137,8 @@ def score_worker( score_cfg : ScoreConfig Score configuration specifying query path, target modules, and scoring method (mean/nearest/individual). + preprocess_cfg : PreprocessConfig + Preprocessing configuration for gradient normalization/preconditioning. ds : Dataset | IterableDataset The entire dataset to be indexed. A subset is assigned to each worker. 
query_grads : dict[str, torch.Tensor] @@ -285,6 +192,7 @@ def score_worker( ds, query_grads, score_cfg, + preprocess_cfg, device=score_device, dtype=score_dtype, attribute_tokens=index_cfg.attribute_tokens, @@ -311,6 +219,7 @@ def flush(kwargs): ds_shard, query_grads, score_cfg, + preprocess_cfg, device=score_device, dtype=score_dtype, ) @@ -333,6 +242,7 @@ def flush(kwargs): def score_dataset( index_cfg: IndexConfig, score_cfg: ScoreConfig, + preprocess_cfg: PreprocessConfig, preprocess_device=torch.device("cuda:0"), ): """ @@ -346,6 +256,8 @@ def score_dataset( score_cfg : ScoreConfig Specifies the query path, target modules, and scoring method (mean/nearest/individual). + preprocess_cfg : PreprocessConfig + Preprocessing configuration for gradient normalization/preconditioning. """ index_cfg.partial_run_path.mkdir(parents=True, exist_ok=True) with (index_cfg.partial_run_path / "index_config.json").open("w") as f: @@ -356,22 +268,18 @@ def score_dataset( ds = setup_data_pipeline(index_cfg) query_grads = get_query_grads(score_cfg) - query_grads = precondition_grads( - query_grads, score_cfg, score_cfg.modules, preprocess_device - ) + query_grads = preprocess_grads( query_grads, score_cfg.modules, - score_cfg.unit_normalize, + preprocess_cfg.unit_normalize, preprocess_device, - accumulate_grads="mean" if score_cfg.score == "mean" else "none", - normalize_accumulated_grad=score_cfg.score == "mean", ) launch_distributed_run( "score", score_worker, - [index_cfg, score_cfg, ds, query_grads], + [index_cfg, score_cfg, preprocess_cfg, ds, query_grads], index_cfg.distributed, ) diff --git a/bergson/score/scorer.py b/bergson/score/scorer.py index 5a05812d..5ea8a056 100644 --- a/bergson/score/scorer.py +++ b/bergson/score/scorer.py @@ -1,5 +1,6 @@ import torch +from bergson.process_grads import get_trackstar_preconditioner from bergson.score.score_writer import ScoreWriter @@ -7,6 +8,11 @@ class Scorer: """ Scores training gradients against query gradients. 
+ Handles preconditioning internally: + - Loads preconditioner from disk if ``preconditioner_path`` is given. + - Applies to query grads once at init time. + - Applies to index grads per-batch in :meth:`score` (split mode only). + Accepts a ScoreWriter for saving the scores (disk or in-memory). """ @@ -19,8 +25,9 @@ def __init__( dtype: torch.dtype, *, unit_normalize: bool = False, - score_mode: str = "inner_product", + score_mode: str = "individual", attribute_tokens: bool = False, + preconditioner_path: str | None = None, ): """ Initialize the scorer. @@ -40,9 +47,17 @@ def __init__( unit_normalize : bool Whether to unit normalize gradients before scoring. score_mode : str - Scoring mode: "inner_product" or "nearest". + Scoring mode: "individual" or "nearest". attribute_tokens : bool Whether gradients are per-token (rows = total_valid tokens). + preconditioner_path : str | None + Path to a saved GradientProcessor. When provided: + + * ``unit_normalize=True`` — loads H^(-1/2) and applies to both + query (here) and index (in :meth:`score`) for split + (two-sided) preconditioning. + * ``unit_normalize=False`` — loads H^(-1) and applies to query + only for one-sided preconditioning. """ self.device = device self.dtype = dtype @@ -52,33 +67,78 @@ def __init__( self.attribute_tokens = attribute_tokens self.writer = writer - self.query_tensor = torch.cat( - [query_grads[m].to(device=self.device, dtype=self.dtype) for m in modules], - dim=1, + # Load preconditioner: H^(-1/2) for split, H^(-1) for one-sided + preconditioners = get_trackstar_preconditioner( + preconditioner_path, + device=device, + power=-0.5 if unit_normalize else -1, + return_dtype=dtype, ) + # Stack preconditioners for batched matmul in score(). 
+ # Shape: [n_modules, dim_per_mod, dim_per_mod] + if preconditioners and unit_normalize: + self.preconditioners_shapes = {m: preconditioners[m].shape for m in modules} + self.preconditioners = torch.stack([preconditioners[m] for m in modules]) + else: + self.preconditioners = None + + # Precondition query grads per module, then cat into a single tensor + if preconditioners: + q_list = [ + query_grads[m].to(device=self.device, dtype=self.dtype) + @ preconditioners[m] + for m in modules + ] + else: + q_list = [ + query_grads[m].to(device=self.device, dtype=self.dtype) for m in modules + ] + # Pre-transpose for scoring: [total_dim, n_queries] + self.query_grads_t = torch.cat(q_list, dim=-1).T + def __call__( self, indices: list[int], mod_grads: dict[str, torch.Tensor], ): """Score a batch of training gradients against all queries.""" - # Convert the gradients to the scoring dtype - if next(iter(mod_grads.values())).dtype != self.dtype: - mod_grads = {name: grad.to(self.dtype) for name, grad in mod_grads.items()} - scores = self.score(mod_grads) self.writer(indices, scores) @torch.inference_mode() - def score(self, mod_grads: dict[str, torch.Tensor]) -> torch.Tensor: + def score(self, index_grads: dict[str, torch.Tensor]) -> torch.Tensor: """Compute scores for a batch of gradients.""" - grads = torch.cat([mod_grads[m].to(self.device) for m in self.modules], dim=1) + if self.preconditioners is not None: + # Batched preconditioning: [batch, n_modules, dim] @ [n_modules, dim, dim] + g = torch.stack( + [ + index_grads[m].to(self.device, self.dtype, non_blocking=True) + for m in self.modules + ], + dim=1, + ) + all_index = ( + torch.bmm(g.permute(1, 0, 2), self.preconditioners) + .permute(1, 0, 2) + .reshape(g.shape[0], -1) + ) + else: + all_index = torch.cat( + [ + index_grads[m].to(self.device, self.dtype, non_blocking=True) + for m in self.modules + ], + dim=-1, + ) + + scores = all_index @ self.query_grads_t + if self.unit_normalize: - grads = grads / grads.norm(dim=1, 
keepdim=True) + i_norm = all_index.pow(2).sum(dim=1).sqrt().clamp_min_(1e-12).unsqueeze(1) + scores.div_(i_norm) if self.score_mode == "nearest": - all_scores = grads @ self.query_tensor.T - return all_scores.max(dim=-1).values + return scores.max(dim=-1).values - return grads @ self.query_tensor.T + return scores diff --git a/bergson/trackstar.py b/bergson/trackstar.py new file mode 100644 index 00000000..738a32dc --- /dev/null +++ b/bergson/trackstar.py @@ -0,0 +1,76 @@ +from copy import deepcopy + +from .build import build +from .config import ( + IndexConfig, + PreprocessConfig, + ScoreConfig, + TrackstarConfig, +) +from .process_grads import mix_preconditioners +from .score.score import score_dataset +from .utils.worker_utils import validate_run_path + + +def trackstar( + index_cfg: IndexConfig, + score_cfg: ScoreConfig, + preprocess_cfg: PreprocessConfig, + trackstar_cfg: TrackstarConfig, +): + """Run the full trackstar pipeline: preconditioners -> mix -> build -> score.""" + run_path = index_cfg.run_path + value_precond_path = f"{run_path}/value_preconditioner" + query_precond_path = f"{run_path}/query_preconditioner" + mixed_precond_path = f"{run_path}/mixed_preconditioner" + query_path = f"{run_path}/query" + scores_path = f"{run_path}/scores" + + # Step 1: Compute normalizers and preconditioners on value dataset + print("Step 1/5: Computing normalizers and preconditioners on value dataset...") + value_precond_cfg = deepcopy(index_cfg) + value_precond_cfg.run_path = value_precond_path + value_precond_cfg.skip_index = True + value_precond_cfg.skip_preconditioners = False + validate_run_path(value_precond_cfg) + build(value_precond_cfg, preprocess_cfg) + + # Step 2: Compute normalizers and preconditioners on query dataset + print("Step 2/5: Computing normalizers and preconditioners on query dataset...") + query_precond_cfg = deepcopy(index_cfg) + query_precond_cfg.run_path = query_precond_path + query_precond_cfg.data = trackstar_cfg.query + 
query_precond_cfg.skip_index = True + query_precond_cfg.skip_preconditioners = False + validate_run_path(query_precond_cfg) + build(query_precond_cfg, preprocess_cfg) + + # Step 3: Mix query and value preconditioners + print("Step 3/5: Mixing preconditioners...") + mix_preconditioners( + query_path=query_precond_path, + index_path=value_precond_path, + output_path=mixed_precond_path, + mixing_coefficient=trackstar_cfg.mixing_coefficient, + ) + + # Step 4: Build per-item query gradient index + print("Step 4/5: Building query gradient index...") + query_cfg = deepcopy(index_cfg) + query_cfg.run_path = query_path + query_cfg.data = trackstar_cfg.query + query_cfg.processor_path = query_precond_path + query_cfg.skip_preconditioners = True + validate_run_path(query_cfg) + build(query_cfg, preprocess_cfg) + + # Step 5: Score value dataset against query using mixed preconditioner + print("Step 5/5: Scoring value dataset...") + score_index_cfg = deepcopy(index_cfg) + score_index_cfg.run_path = scores_path + score_index_cfg.processor_path = value_precond_path + score_index_cfg.skip_preconditioners = True + score_cfg.query_path = query_path + preprocess_cfg.preconditioner_path = mixed_precond_path + validate_run_path(score_index_cfg) + score_dataset(score_index_cfg, score_cfg, preprocess_cfg) diff --git a/bergson/utils/math.py b/bergson/utils/math.py index 33faf31d..4378a9fc 100644 --- a/bergson/utils/math.py +++ b/bergson/utils/math.py @@ -54,41 +54,52 @@ def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor: @torch.compile -def psd_rsqrt(A: Tensor) -> Tensor: - """Efficiently compute the p.s.d. pseudoinverse sqrt of p.s.d. matrix `A`.""" - L, U = torch.linalg.eigh(A) - L = L[..., None, :].clamp_min(0.0) +def psd_power(H: Tensor, power: float) -> Tensor: + """Compute a pseudoinverse power of p.s.d. matrix `H` via eigendecomposition. - # We actually compute the pseudo-inverse here for numerical stability. 
+ Uses the same tolerance heuristic as `torch.linalg.pinv` to zero out + eigenvalues that are effectively zero, ensuring numerical stability. + + Args: + H: Positive semi-definite matrix. + power: Exponent to apply to eigenvalues (e.g. -0.5 for rsqrt, -1 for inverse). + """ + eigval, eigvec = torch.linalg.eigh(H) + eigval = eigval[..., None, :].clamp_min(0.0) + + # Zero out eigenvalues below the tolerance threshold (pseudoinverse). # Use the same heuristic as `torch.linalg.pinv` to determine the tolerance. - thresh = L[..., None, -1] * A.shape[-1] * torch.finfo(A.dtype).eps - rsqrt = U * torch.where(L > thresh, L.rsqrt(), 0.0) @ U.mH + thresh = eigval[..., None, -1] * H.shape[-1] * torch.finfo(H.dtype).eps + result = eigvec * torch.where(eigval > thresh, eigval.pow(power), 0.0) @ eigvec.mH - return rsqrt + return result -def compute_damped_inverse( +@torch.compile +def damped_psd_power( H: Tensor, + power: float, damping_factor: float = 0.1, dtype: torch.dtype = torch.float64, regularizer: Tensor | None = None, ) -> Tensor: - """Compute H^(-1) with damping for numerical stability. + """Compute a damped power of p.s.d. matrix `H` via eigendecomposition. - Uses eigendecomposition to compute the inverse of a positive semi-definite - matrix with adaptive damping based on the matrix's mean absolute value. + Adds adaptive damping before computing the power to improve numerical stability. Args: - H: Positive semi-definite matrix to invert. - damping_factor: Multiplier for the damping term (default: 0.1). + H: Positive semi-definite matrix. + power: Exponent to apply to eigenvalues (e.g. -0.5 for rsqrt, -1 for inverse). + damping_factor: Multiplier for the damping term (default: 0.1). Set to + 0 to disable damping. dtype: Dtype for intermediate computation (default: float64 for stability). regularizer: Optional matrix to use as regularizer instead of identity. - If provided, computes inv(H + damping_factor * regularizer). 
+ If provided, computes (H + damping_factor * regularizer)^power. If None (default), uses scaled identity: - inv(H + damping_factor * |H|_mean * I). + (H + damping_factor * |H|_mean * I)^power. Returns: - The damped inverse H^(-1) in the original dtype of H. + The damped power of H in the original dtype. """ original_dtype = H.dtype H = H.to(dtype=dtype) @@ -98,8 +109,9 @@ def compute_damped_inverse( else: damping_val = damping_factor * H.abs().mean() H = H + damping_val * torch.eye(H.shape[0], device=H.device, dtype=H.dtype) + eigval, eigvec = torch.linalg.eigh(H) - return (eigvec * (1.0 / eigval) @ eigvec.mT).to(original_dtype) + return (eigvec * eigval.pow(power) @ eigvec.mH).to(original_dtype) def trace(matrices: Tensor) -> Tensor: diff --git a/bergson/utils/worker_utils.py b/bergson/utils/worker_utils.py index 8750aa11..9ae59771 100644 --- a/bergson/utils/worker_utils.py +++ b/bergson/utils/worker_utils.py @@ -1,3 +1,4 @@ +import shutil import warnings from pathlib import Path from typing import cast @@ -26,6 +27,23 @@ from bergson.utils.utils import assert_type, get_layer_list +def validate_run_path(index_cfg: IndexConfig): + """Validate the run path.""" + if index_cfg.distributed.rank != 0: + return + + for path in [Path(index_cfg.run_path), Path(index_cfg.partial_run_path)]: + if not path.exists(): + continue + + if index_cfg.overwrite: + shutil.rmtree(path) + else: + raise FileExistsError( + f"Run path {path} already exists. Use --overwrite to overwrite it." + ) + + def create_normalizers( model: PreTrainedModel, ds: Dataset | IterableDataset, diff --git a/docs/index.rst b/docs/index.rst index c3b833be..b8aa14af 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -47,6 +47,22 @@ Benchmarks benchmarks/index +Preprocessing +------------- + +.. toctree:: + :maxdepth: 2 + + preprocessing + +Experiments +----------- + +.. 
toctree:: + :maxdepth: 2 + + experiments + API Reference -------------- @@ -61,14 +77,6 @@ API Reference api utils -Experiments ------------ - -.. toctree:: - :maxdepth: 2 - - experiments - Content Index ------------------ diff --git a/docs/preprocessing.rst b/docs/preprocessing.rst new file mode 100644 index 00000000..2d2dcb4b --- /dev/null +++ b/docs/preprocessing.rst @@ -0,0 +1,206 @@ +Gradient Preprocessing +====================== + +Bergson supports several gradient preprocessing operations that affect the quality and meaning of similarity scores. This page explains the operations available, when to apply them to query versus index gradients, and walks through concrete use cases. + +Operations +---------- + +**Optimizer normalization** (``--normalizer``): Scales each gradient element by the inverse root-mean-square (RMS) of that parameter's gradient history — i.e., divides by :math:`\sqrt{E[g^2]} + \varepsilon`, where :math:`E[g^2]` is the mean of squared gradients across the dataset. Applied elementwise during gradient collection. Unlike the Adam optimizer used during training, this uses a simple mean over the dataset rather than an exponential moving average. This downweights parameters with large gradient magnitudes and amplifies signal in directions with consistently small gradients. + +**Unit normalization** (``--unit_normalize``): Normalizes each gradient vector to unit L2 norm before similarity computation, enabling cosine similarity when used with inner product scoring. + +**Preconditioning** (``--query_preconditioner_path``, ``--index_preconditioner_path``): Applies a per-module matrix transformation derived from a Hessian approximation (second moment matrix of gradients). For inner product scoring, :math:`H^{-1}` is applied to the query side. For cosine similarity scoring, :math:`H^{-1/2}` must be applied to both sides symmetrically. 
+ +Query vs Index Gradients +------------------------ + +Every similarity computation involves two sides: + +- **Index gradients**: Gradients from the training dataset you want to search. +- **Query gradients**: Gradients from the dataset whose most similar training examples you want to find. + +For a similarity score to be meaningful, preprocessing applied to query and index gradients must be consistent. + +.. list-table:: + :header-rows: 1 + :widths: 35 20 45 + + * - Operation + - Can apply one-sided? + - Notes + * - Optimizer normalization + - Yes + - Apply the same ``--normalizer`` when collecting both query and index gradients + * - Preconditioning (inner product) + - Yes + - :math:`H^{-1}` applied to query only; relative score rankings are preserved + * - Preconditioning (cosine similarity) + - **No** + - :math:`H^{-1/2}` must be applied to **both** sides before unit normalization + * - Unit normalization + - **No** + - Must be applied consistently to both sides + +**Unit normalization is a non-linear operation and does not commute with preconditioning.** When unit normalization is enabled alongside preconditioning, the preconditioner must be applied to both query and index gradients before normalization. Bergson handles this automatically: when ``unit_normalize=True``, it applies :math:`H^{-1/2}` to the query gradient upfront in the ``score`` command and applies :math:`H^{-1/2}` to each index gradient as it is collected during scoring. + +Case Studies +------------ + +Cosine similarity with an optimizer normalizer (full gradients) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Goal:** Rank training examples by cosine similarity to a query, using optimizer-normalized gradients. + +Optimizer normalization scales each parameter's gradient by :math:`1/(\sqrt{v} + \varepsilon)`, where :math:`v = E[g^2]` is the mean of squared gradients across the dataset. 
Applied before cosine similarity, this reweights the gradient space by the inverse RMS of each parameter's gradient history, emphasizing directions with consistently small-magnitude gradients. + +The normalizer is applied during gradient collection, so the same ``--normalizer`` must be set when collecting both query and index gradients. Unit normalization is then applied at scoring time to obtain cosine similarity. + +.. code-block:: bash + + # Reduce query dataset to a single mean gradient with optimizer normalization + bergson reduce runs/query \ + --model EleutherAI/pythia-160m \ + --dataset query_data \ + --projection_dim 0 \ + --normalizer adafactor \ + --method mean \ + --skip_preconditioners + + # Score: collect training gradients with the same normalizer, unit normalize for cosine similarity + bergson score runs/scores \ + --query_path runs/query \ + --model EleutherAI/pythia-160m \ + --dataset training_data \ + --projection_dim 0 \ + --normalizer adafactor \ + --unit_normalize + +Both commands use ``--projection_dim 0`` to preserve the full gradient, and the same ``--normalizer`` to ensure consistent per-parameter scaling. The ``score`` command applies unit normalization to both the loaded query gradient and each training gradient, giving cosine similarity in the optimizer-normalized space. + +Inner product with an optimizer normalizer (full gradients) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Goal:** Rank training examples by inner product with a query gradient in optimizer-normalized space, approximating the classic influence function. + +The influence function estimates the change in query loss from upweighting a training example as :math:`\partial L_q / \partial \varepsilon_t \approx -g_q H^{-1} g_t^T`, where :math:`H` is the Hessian. The optimizer normalizer provides a diagonal approximation to :math:`H^{-1/2}`, so applying it to both query and index gradients approximates the full influence inner product. 
+ +Unlike cosine similarity, inner product preserves gradient magnitude, so training examples with larger gradients contribute more to the score. + +.. code-block:: bash + + # Reduce query dataset to a single mean gradient with optimizer normalization + bergson reduce runs/query \ + --model EleutherAI/pythia-160m \ + --dataset query_data \ + --projection_dim 0 \ + --normalizer adafactor \ + --method mean \ + --skip_preconditioners + + # Score: inner product (no --unit_normalize) + bergson score runs/scores \ + --query_path runs/query \ + --model EleutherAI/pythia-160m \ + --dataset training_data \ + --projection_dim 0 \ + --normalizer adafactor + +**Inner product vs cosine similarity:** Use inner product when gradient magnitude carries information (larger gradients indicate stronger relevance). Use cosine similarity to compare direction independently of magnitude, which is more robust when examples differ systematically in gradient norm (e.g., due to different sequence lengths or loss scales). + +Randomly projected gradients with reduce and score +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Goal:** Select training examples most similar to a query set using random projection, keeping full-batch scoring tractable for large models. + +Random projections approximately preserve inner products and cosine similarities (Johnson-Lindenstrauss) while reducing gradient dimensionality by orders of magnitude. For large models, full gradients may be gigabytes per example; projecting to a few thousand dimensions makes the ``reduce → score`` pipeline tractable while retaining most of the signal. + +``reduce`` aggregates all query gradients into a single vector (mean or sum) without storing per-example gradients. ``score`` then collects each training gradient on-the-fly and scores it against the precomputed query vector, avoiding the need to build or store a full training gradient index. + +.. 
code-block:: bash + + # Reduce query dataset to a single mean gradient vector + bergson reduce runs/query \ + --model EleutherAI/pythia-160m \ + --dataset query_data \ + --projection_dim 4096 \ + --method mean \ + --skip_preconditioners + + # Score training data against the reduced query + bergson score runs/scores \ + --query_path runs/query \ + --model EleutherAI/pythia-160m \ + --dataset training_data \ + --projection_dim 4096 + +Both commands must use the same ``--projection_dim`` and identical model configuration so that both sides are projected into the same random subspace. The random projection matrix is derived deterministically from the model architecture and the projection dimension. + +.. note:: + + **Preprocessing order:** Optimizer normalization must be applied during gradient collection (set ``--normalizer`` at both ``reduce`` and ``score`` time). It cannot be applied after the mean-reduction in ``reduce`` - the normalizer is non-linear so applying it to the mean gradient is not the same as normalizing each gradient then taking the mean. + +Randomly projected gradients with unit normalization, preconditioners, build, and score +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Goal:** Compute preconditioner-weighted cosine similarity using random projections. This is the approach used by the ``trackstar`` command. + +When combining preconditioning with cosine similarity, the preconditioner must be applied before unit normalization to both query and index gradients. Bergson applies :math:`H^{-1/2}` to the query gradient at the start of ``score``, and :math:`H^{-1/2}` to each index gradient as it is collected. The resulting score is: + +.. 
math:: + + g_q^p = g_q \cdot H^{-1/2} + + g_t^p = g_t \cdot H^{-1/2} + + \text{score}(q, t) = \frac{g_q^p}{\|g_q^p\|} \cdot \frac{g_t^p}{\|g_t^p\|} + +This is cosine similarity in the :math:`H^{-1}`-weighted inner product space — the same geometry used by the influence function. + +.. code-block:: bash + + # Step 1: Compute normalizers and preconditioners on the query dataset + bergson preconditioners runs/query_precond \ + --model EleutherAI/pythia-160m \ + --dataset query_data \ + --projection_dim 4096 + + # Step 2: Compute normalizers and preconditioners on the training dataset + bergson preconditioners runs/index_precond \ + --model EleutherAI/pythia-160m \ + --dataset training_data \ + --projection_dim 4096 + + # Step 3: Build per-example query gradient index + # The query normalizer (from runs/query_precond) is applied during collection + bergson build runs/query \ + --model EleutherAI/pythia-160m \ + --dataset query_data \ + --projection_dim 4096 \ + --processor_path runs/query_precond \ + --skip_preconditioners + + # Step 4: Score training data against query + # H^(-1/2) is applied to both query and index gradients, then unit normalized + bergson score runs/scores \ + --query_path runs/query \ + --model EleutherAI/pythia-160m \ + --dataset training_data \ + --projection_dim 4096 \ + --processor_path runs/index_precond \ + --skip_preconditioners \ + --unit_normalize \ + --query_preconditioner_path runs/query_precond \ + --index_preconditioner_path runs/index_precond + +This pipeline is also available as the ``trackstar`` command, which automates the four steps above. See ``bergson trackstar --help`` for the full argument list. + +**Why** :math:`H^{-1/2}` **on both sides?** For inner product scoring, applying :math:`H^{-1}` to one side only is sufficient since the relative ordering of :math:`g_q H^{-1} g_t^T` is preserved. 
For cosine similarity, the unit normalization would undo a one-sided application: normalizing :math:`g_t` to unit norm discards the preconditioner's geometry. Applying :math:`H^{-1/2}` symmetrically to both sides before normalization preserves the preconditioned structure and ensures the normalization operates in the correct space. + +**Mixing query and index preconditioners:** When query and index datasets come from different distributions, ``--mixing_coefficient`` (default 0.99) interpolates between their second moment matrices (i.e. the empirical Fisher information matrices): + +.. math:: + + H_\text{mixed} = \alpha \cdot H_\text{query} + (1 - \alpha) \cdot H_\text{index} + +Values close to 1.0 weight the query distribution more heavily; values close to 0.0 weight the index distribution. Adjust this according to the guidelines in https://arxiv.org/abs/2410.17413 diff --git a/examples/semantic/asymmetric.py b/examples/semantic/asymmetric.py index 9d366494..795d7d5b 100644 --- a/examples/semantic/asymmetric.py +++ b/examples/semantic/asymmetric.py @@ -452,7 +452,7 @@ def score_asymmetric_eval( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -526,8 +526,8 @@ def score_asymmetric_eval( regularizer = None if reg_proc is not None and name in reg_proc.preconditioners: regularizer = reg_proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse( - H, damping_factor=damping_factor, regularizer=regularizer + h_inv[name] = damped_psd_power( + H, power=-1, damping_factor=damping_factor, regularizer=regularizer ) # Concatenate train gradients @@ -823,7 +823,7 @@ def score_asymmetric_eval_with_pca_projection( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from bergson.utils.math import 
compute_damped_inverse + from bergson.utils.math import damped_psd_power from .preconditioners import project_orthogonal_to_style_subspace @@ -883,7 +883,7 @@ def score_asymmetric_eval_with_pca_projection( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H, damping_factor=damping_factor) + h_inv[name] = damped_psd_power(H, power=-1, damping_factor=damping_factor) # Concatenate train gradients print("Preparing train gradients...") @@ -1150,7 +1150,7 @@ def score_majority_style_eval( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -1206,7 +1206,7 @@ def score_majority_style_eval( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H) + h_inv[name] = damped_psd_power(H, power=-1) # Concatenate train gradients print("Preparing train gradients...") @@ -1321,7 +1321,7 @@ def score_summed_eval( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -1393,7 +1393,7 @@ def score_summed_eval( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H) + h_inv[name] = damped_psd_power(H, power=-1) # Concatenate train gradients print("Preparing train gradients...") @@ -1975,7 +1975,7 @@ def score_with_inner_product( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - 
from bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -2013,7 +2013,7 @@ def score_with_inner_product( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H) + h_inv[name] = damped_psd_power(H, power=-1) # Concatenate train gradients - NO NORMALIZATION print("Preparing train gradients (no normalization)...") @@ -2367,7 +2367,7 @@ def score_summed_rewrites( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -2432,7 +2432,7 @@ def score_summed_rewrites( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H) + h_inv[name] = damped_psd_power(H, power=-1) # Concatenate train gradients print("Preparing train gradients...") @@ -2542,7 +2542,7 @@ def score_original_style_eval( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -2593,7 +2593,7 @@ def score_original_style_eval( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H) + h_inv[name] = damped_psd_power(H, power=-1) # Concatenate train gradients print("Preparing train gradients...") @@ -2721,7 +2721,7 @@ def _score_single_style_eval( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from 
bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -2774,7 +2774,7 @@ def _score_single_style_eval( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H) + h_inv[name] = damped_psd_power(H, power=-1) # Prepare train gradients train_grad_list = [] diff --git a/examples/semantic/attribute_preservation.py b/examples/semantic/attribute_preservation.py index 9cfd936c..7118a1d2 100644 --- a/examples/semantic/attribute_preservation.py +++ b/examples/semantic/attribute_preservation.py @@ -729,7 +729,7 @@ def score_attribute_eval( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -792,7 +792,7 @@ def score_attribute_eval( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H, damping_factor=damping_factor) + h_inv[name] = damped_psd_power(H, power=-1, damping_factor=damping_factor) def load_grad_as_float(grads: np.memmap, name: str) -> np.ndarray: g = grads[name] @@ -1765,7 +1765,7 @@ def score_majority_style_eval( from bergson.data import load_gradients from bergson.gradients import GradientProcessor - from bergson.utils.math import compute_damped_inverse + from bergson.utils.math import damped_psd_power base_path = Path(base_path) index_path = base_path / "index" @@ -1818,7 +1818,7 @@ def score_majority_style_eval( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H, 
damping_factor=damping_factor) + h_inv[name] = damped_psd_power(H, power=-1, damping_factor=damping_factor) def load_grad_as_float(grads: np.memmap, name: str) -> np.ndarray: g = grads[name] diff --git a/examples/semantic/metrics.py b/examples/semantic/metrics.py index c89481af..d6c792e0 100644 --- a/examples/semantic/metrics.py +++ b/examples/semantic/metrics.py @@ -184,18 +184,23 @@ def compute_metrics( exclude_llama: bool = False, query_preconditioner_path: str | None = None, index_preconditioner_path: str | None = None, + mixing_coefficient: float = 0.99, ) -> dict[str, float]: """Compute intra/inter similarities for subject (identifier) and style. Uses bergson score_dataset to compute pairwise similarities instead of custom gradient inner product implementation. + If both query_preconditioner_path and index_preconditioner_path are given, + they are mixed internally using mixing_coefficient before scoring. + Args: index_path: Path to the gradient index. scores_path: Optional path to precomputed scores. exclude_llama: Whether to exclude Llama-generated samples. query_preconditioner_path: Optional path to query preconditioner. index_preconditioner_path: Optional path to index preconditioner. + mixing_coefficient: Weight for the query preconditioner when mixing. Returns: Dictionary of similarity statistics. 
@@ -214,6 +219,7 @@ def compute_metrics( scores_path, query_preconditioner_path=query_preconditioner_path, index_preconditioner_path=index_preconditioner_path, + mixing_coefficient=mixing_coefficient, ) # Load metadata from HF dataset (fast) diff --git a/examples/semantic/scoring.py b/examples/semantic/scoring.py index 34da210e..2a1bb683 100644 --- a/examples/semantic/scoring.py +++ b/examples/semantic/scoring.py @@ -10,7 +10,8 @@ from bergson.data import load_gradients from bergson.gradients import GradientProcessor -from bergson.utils.math import compute_damped_inverse +from bergson.process_grads import mix_preconditioners +from bergson.utils.math import damped_psd_power def load_scores_matrix(scores_path: Path | str) -> np.ndarray: @@ -106,7 +107,7 @@ def compute_scores_fast( device = torch.device("cuda:0") for name in tqdm(module_names, desc="Computing H^(-1)"): H = proc.preconditioners[name].to(device=device) - h_inv[name] = compute_damped_inverse(H) + h_inv[name] = damped_psd_power(H, power=-1) # Bergson's approach (from score.py): # 1. Query: precondition with H^(-1), then unit normalize @@ -211,8 +212,9 @@ def compute_scores_fast( def compute_scores_with_bergson( index_path: Path | str, output_path: Path | str, - query_preconditioner_path: str | None = None, - index_preconditioner_path: str | None = None, + query_preconditioner_path: str | Path | None = None, + index_preconditioner_path: str | Path | None = None, + mixing_coefficient: float = 0.99, unit_normalize: bool = True, ) -> None: """Run bergson score to compute pairwise similarities. @@ -220,20 +222,41 @@ def compute_scores_with_bergson( NOTE: This recomputes gradients, which is slow. For index-vs-index scoring, use compute_scores_fast() instead. + If both query_preconditioner_path and index_preconditioner_path are given, + they are mixed internally using mixing_coefficient before scoring. + Args: index_path: Path to the gradient index. output_path: Path to save scores. 
query_preconditioner_path: Optional path to query preconditioner. index_preconditioner_path: Optional path to index preconditioner. + mixing_coefficient: Weight for the query preconditioner when mixing. unit_normalize: Whether to unit normalize gradients. """ output_path = Path(output_path) index_path = Path(index_path) - if output_path.exists(): + if (output_path / "info.json").exists(): print(f"Scores already exist at {output_path}, skipping...") return + # Mix preconditioners if both paths are given, otherwise use whichever is provided + preconditioner_path = None + if query_preconditioner_path and index_preconditioner_path: + mixed_path = output_path / "mixed_preconditioner" + output_path.mkdir(parents=True, exist_ok=True) + mix_preconditioners( + query_preconditioner_path, + index_preconditioner_path, + mixed_path, + mixing_coefficient=mixing_coefficient, + ) + preconditioner_path = str(mixed_path) + elif query_preconditioner_path: + preconditioner_path = str(query_preconditioner_path) + elif index_preconditioner_path: + preconditioner_path = str(index_preconditioner_path) + # Load index config to get model and dataset info with open(index_path / "index_config.json") as f: index_cfg = json.load(f) @@ -269,11 +292,8 @@ def compute_scores_with_bergson( if unit_normalize: cmd.append("--unit_normalize") - if query_preconditioner_path: - cmd.extend(["--query_preconditioner_path", query_preconditioner_path]) - - if index_preconditioner_path: - cmd.extend(["--index_preconditioner_path", index_preconditioner_path]) + if preconditioner_path: + cmd.extend(["--preconditioner_path", preconditioner_path]) print("Running:", " ".join(cmd)) result = subprocess.run(cmd, capture_output=True, text=True) diff --git a/tests/test_attribute_tokens.py b/tests/test_attribute_tokens.py index 527bcfa1..5ad91f4b 100644 --- a/tests/test_attribute_tokens.py +++ b/tests/test_attribute_tokens.py @@ -16,9 +16,10 @@ fit_normalizers, load_token_gradients, ) +from bergson.builders import 
TokenBuilder from bergson.collector.gradient_collectors import GradientCollector from bergson.config import IndexConfig -from bergson.data import TokenBuilder, compute_num_token_grads, create_token_index +from bergson.data import compute_num_token_grads, create_token_index from bergson.score.score_writer import MemmapTokenScoreWriter from bergson.score.scorer import Scorer from bergson.utils.utils import convert_dtype_to_np, get_gradient_dtype diff --git a/tests/test_attributor.py b/tests/test_attributor.py index 9c6052bd..7c125359 100644 --- a/tests/test_attributor.py +++ b/tests/test_attributor.py @@ -82,6 +82,58 @@ def test_faiss(tmp_path: Path, model, dataset): assert result.scores[0, 1].item() < 0.50 # Different item +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_attributor_precondition_split(tmp_path: Path, model, dataset): + """Test split preconditioning (unit_norm=True): H^(-1/2) on both query and index.""" + cfg = IndexConfig(run_path=str(tmp_path), token_batch_size=1024) + + collect_gradients( + model=model, + data=dataset, + processor=GradientProcessor(), + cfg=cfg, + ) + + attr = Attributor( + cfg.partial_run_path, device="cpu", unit_norm=True, precondition=True + ) + + x = torch.tensor(dataset[0]["input_ids"]).unsqueeze(0) + + with attr.trace(model.base_model, 5) as result: + model(x, labels=x).loss.backward() + model.zero_grad() + + assert result.scores[0, 0].item() > 0.90 # Same item, top match + assert result.indices[0, 0].item() == 0 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_attributor_precondition_one_sided(tmp_path: Path, model, dataset): + """Test one-sided preconditioning (unit_norm=False): H^(-1) on query only.""" + cfg = IndexConfig(run_path=str(tmp_path), token_batch_size=1024) + + collect_gradients( + model=model, + data=dataset, + processor=GradientProcessor(), + cfg=cfg, + ) + + attr = Attributor( + cfg.partial_run_path, device="cpu", 
unit_norm=False, precondition=True + ) + + x = torch.tensor(dataset[0]["input_ids"]).unsqueeze(0) + + with attr.trace(model.base_model, 5) as result: + model(x, labels=x).loss.backward() + model.zero_grad() + + # Same item should still be the top match + assert result.indices[0, 0].item() == 0 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_attributor_reverse(tmp_path: Path, model, dataset): """Test that reverse mode returns lowest influence examples.""" diff --git a/tests/test_reduce.py b/tests/test_reduce.py index a3e8a5b2..14ddff8d 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -1,4 +1,3 @@ -import shutil import subprocess from pathlib import Path @@ -8,17 +7,15 @@ from bergson import ( CollectorComputer, DataConfig, + GradientProcessor, IndexConfig, InMemoryCollector, + PreprocessConfig, ReduceConfig, + collect_gradients, ) -from bergson.data import allocate_batches, load_gradient_dataset +from bergson.data import load_gradient_dataset from bergson.reduce import reduce -from bergson.utils.worker_utils import ( - create_processor, - setup_data_pipeline, - setup_model_and_peft, -) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -73,8 +70,9 @@ def test_programmatic_reduce(tmp_path: Path): token_batch_size=1024, ) reduce_cfg = ReduceConfig() + preprocess_cfg = PreprocessConfig() - reduce(index_cfg, reduce_cfg) + reduce(index_cfg, reduce_cfg, preprocess_cfg) # Assert 1-row reduction exists at the tmp_path ds = load_gradient_dataset(Path(index_cfg.run_path), structured=False) @@ -82,43 +80,67 @@ def test_programmatic_reduce(tmp_path: Path): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_in_memory_reduce(tmp_path: Path): - index_cfg = IndexConfig( +def test_reduce_with_preconditioning(tmp_path: Path, model, dataset): + # Step 1: build an index WITH preconditioners + build_cfg = IndexConfig(run_path=str(tmp_path / "build"), 
token_batch_size=1024) + + collect_gradients( + model=model, + data=dataset, + processor=GradientProcessor(), + cfg=build_cfg, + ) + + # Step 2: reduce with preconditioning pointing at the built index + reduce_cfg = ReduceConfig() + preprocess_cfg = PreprocessConfig( + preconditioner_path=str(build_cfg.partial_run_path) + ) + reduce_index_cfg = IndexConfig( + run_path=str(tmp_path / "reduce_precond"), + token_batch_size=1024, + skip_preconditioners=True, + ) + + collect_gradients( + model=model, + data=dataset, + processor=GradientProcessor(), + cfg=reduce_index_cfg, + reduce_cfg=reduce_cfg, + preprocess_cfg=preprocess_cfg, + ) + + ds_out = load_gradient_dataset(reduce_index_cfg.partial_run_path, structured=False) + assert len(ds_out) == 1 + grads = torch.tensor(ds_out["gradients"][:]) + assert not torch.isnan(grads).any() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_in_memory_reduce(tmp_path: Path, model, dataset): + model.cuda() + cfg = IndexConfig( run_path=str(tmp_path / "reduction"), - data=DataConfig(truncation=True, split="train[:100]"), - model="EleutherAI/pythia-14m", skip_preconditioners=True, token_batch_size=1024, ) - reduce_cfg = ReduceConfig() - - ds = setup_data_pipeline(index_cfg) - model, target_modules = setup_model_and_peft(index_cfg) - processor = create_processor(model, ds, index_cfg, target_modules) - batches = allocate_batches(ds["length"][:], index_cfg.token_batch_size) + cfg.partial_run_path.mkdir(parents=True, exist_ok=True) collector = InMemoryCollector( - model=model.base_model, # type: ignore - cfg=index_cfg, - processor=processor, - target_modules=target_modules, - data=ds, - scorer=None, - reduce_cfg=reduce_cfg, + model=model.base_model, + cfg=cfg, + processor=GradientProcessor(), + data=dataset, + reduce_cfg=ReduceConfig(), attention_cfgs={}, ) - computer = CollectorComputer( - model=model, # type: ignore - data=ds, + CollectorComputer( + model=model, + data=dataset, 
collector=collector, - batches=batches, - cfg=index_cfg, - ) - computer.run_with_collector_hooks(desc="New worker - Collecting gradients") - - shutil.move(index_cfg.partial_run_path, index_cfg.run_path) - - results = collector.gradients + cfg=cfg, + ).run_with_collector_hooks(desc="In-memory reduce") - assert all(len(results[name]) == 1 for name in results.keys()) + assert all(len(collector.gradients[name]) == 1 for name in collector.gradients) diff --git a/tests/test_score.py b/tests/test_score.py index c197a151..b5450358 100644 --- a/tests/test_score.py +++ b/tests/test_score.py @@ -10,18 +10,19 @@ from ml_dtypes import bfloat16 from transformers import AutoConfig, AutoModelForCausalLM -from bergson import ( - GradientProcessor, - collect_gradients, -) +from bergson.collector.collector import CollectorComputer from bergson.collector.gradient_collectors import GradientCollector -from bergson.config import IndexConfig, ScoreConfig -from bergson.data import create_index, load_scores -from bergson.score.score import precondition_grads -from bergson.score.score_writer import MemmapSequenceScoreWriter +from bergson.collector.in_memory_collector import InMemoryCollector +from bergson.config import IndexConfig, ReduceConfig +from bergson.data import create_index +from bergson.gradients import GradientProcessor +from bergson.process_grads import get_trackstar_preconditioner +from bergson.score.score_writer import ( + InMemorySequenceScoreWriter, + MemmapSequenceScoreWriter, +) from bergson.score.scorer import Scorer from bergson.utils.utils import ( - convert_precision_to_torch, get_gradient_dtype, tensor_to_numpy, ) @@ -84,98 +85,97 @@ def test_large_gradients_query(tmp_path: Path, dataset): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_score(tmp_path: Path, model, dataset): + model = model.cuda() processor = GradientProcessor(projection_dim=16) - collector = GradientCollector( - model.base_model, + + # Step 1: Reduce query 
gradients using InMemoryCollector + reduce_cfg = ReduceConfig(method="mean") + reduce_index_cfg = IndexConfig( + run_path=str(tmp_path / "reduce"), token_batch_size=1024 + ) + reduce_index_cfg.partial_run_path.mkdir(parents=True, exist_ok=True) + + query_collector = InMemoryCollector( + model=model.base_model, data=dataset, - cfg=IndexConfig(run_path=str(tmp_path)), + cfg=reduce_index_cfg, processor=processor, + reduce_cfg=reduce_cfg, ) - shapes = collector.shapes() - cfg = IndexConfig(run_path=str(tmp_path), token_batch_size=1024) - score_cfg = ScoreConfig( - query_path=str(tmp_path / "query_gradient_ds"), - modules=list(shapes.keys()), - score="mean", + computer = CollectorComputer( + model=model, + data=dataset, + collector=query_collector, + cfg=reduce_index_cfg, ) + computer.run_with_collector_hooks(desc="Reducing query gradients") - query_grads = { - module: torch.randn(1, shape.numel()) for module, shape in shapes.items() - } - - score_dtype = ( - convert_precision_to_torch(score_cfg.precision) - if score_cfg.precision != "auto" - else get_gradient_dtype(model) - ) + query_grads = query_collector.gradients + modules = list(query_collector.shapes().keys()) - score_writer = MemmapSequenceScoreWriter( - tmp_path, len(dataset), 1, dtype=score_dtype - ) + # Step 2: Score using InMemoryCollector with scorer + score_dtype = get_gradient_dtype(model) + score_writer = InMemorySequenceScoreWriter(len(dataset), 1, dtype=score_dtype) scorer = Scorer( query_grads=query_grads, - modules=list(shapes.keys()), + modules=modules, writer=score_writer, - device=torch.device("cpu"), + device=torch.device("cuda:0"), dtype=score_dtype, ) - collect_gradients( - model=model, + index_processor = GradientProcessor(projection_dim=16) + index_cfg = IndexConfig(run_path=str(tmp_path / "index"), token_batch_size=1024) + index_cfg.partial_run_path.mkdir(parents=True, exist_ok=True) + + index_collector = InMemoryCollector( + model=model.base_model, data=dataset, - processor=processor, - 
cfg=cfg, + cfg=index_cfg, + processor=index_processor, scorer=scorer, ) - assert (tmp_path / "info.json").exists() - assert (tmp_path / "scores.bin").exists() - - with open(tmp_path / "info.json", "r") as f: - info = json.load(f) - - scores = load_scores(tmp_path) - - assert len(scores) == len(dataset) - - assert info["num_scores"] == 1 - - assert np.allclose( - scores[:], - np.array( - [ - [ - 1.8334405, - ], - [ - 0.3371131, - ], - ] - ), + computer = CollectorComputer( + model=model, + data=dataset, + collector=index_collector, + cfg=index_cfg, ) + computer.run_with_collector_hooks(desc="Scoring") + + scores = index_collector.scores + assert scores is not None + assert scores.shape == (len(dataset), 1) + assert torch.isfinite(scores).all() + assert not torch.allclose(scores, torch.zeros_like(scores)) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_precondition_ds(tmp_path: Path, model, dataset): - cfg = IndexConfig(run_path=str(tmp_path), token_batch_size=1024) - + model = model.cuda() preprocess_device = torch.device("cuda:0") - # Populate and save preconditioners + # Collect gradients and build preconditioners using InMemoryCollector processor = GradientProcessor(projection_dim=16) - collector = GradientCollector( - model.base_model, + build_cfg = IndexConfig(run_path=str(tmp_path / "build"), token_batch_size=1024) + build_cfg.partial_run_path.mkdir(parents=True, exist_ok=True) + + collector = InMemoryCollector( + model=model.base_model, data=dataset, - cfg=cfg, + cfg=build_cfg, processor=processor, ) - collect_gradients( + + computer = CollectorComputer( model=model, data=dataset, - processor=processor, - cfg=cfg, + collector=collector, + cfg=build_cfg, ) + computer.run_with_collector_hooks(desc="Building preconditioners") processor.save(tmp_path) # Produce query gradients dict @@ -184,28 +184,21 @@ def test_precondition_ds(tmp_path: Path, model, dataset): for module, shape in collector.shapes().items() } - # Produce 
preconditioned query gradients - score_cfg = ScoreConfig( - query_path=str(tmp_path / "query_gradient_ds"), - modules=list(collector.shapes().keys()), - score="mean", - query_preconditioner_path=str(tmp_path), - ) - - preconditioned = precondition_grads( - query_grads, score_cfg, score_cfg.modules, preprocess_device - ) - - # Produce query gradients without preconditioning - score_cfg.query_preconditioner_path = None + target_modules = list(collector.shapes().keys()) - vanilla = precondition_grads( - query_grads, score_cfg, score_cfg.modules, preprocess_device + # Produce preconditioned query gradients + h_inv = get_trackstar_preconditioner( + str(tmp_path), device=preprocess_device, power=-1 ) + preconditioned = { + name: (query_grads[name].to(preprocess_device) @ h_inv[name]).cpu() + for name in target_modules + } - # Compare the two - for name in score_cfg.modules: - assert not torch.allclose(preconditioned[name], vanilla[name]) + # Compare against unpreconditioned — should differ + for name in target_modules: + vanilla = query_grads[name].to(preprocess_device).cpu() + assert not torch.allclose(preconditioned[name], vanilla) def test_memmap_score_writer_bfloat16(tmp_path: Path): @@ -290,3 +283,120 @@ def test_memmap_score_writer_float32(tmp_path: Path): np.testing.assert_array_almost_equal( writer.scores["score_1"][[0, 1]], np.array([2.5, 4.5], dtype=np.float32) ) + + +def test_compute_preconditioner_h_inv(): + """Test that get_trackstar_preconditioner returns empty dict for None path.""" + + # No path → empty dict + result = get_trackstar_preconditioner(None, device=torch.device("cpu"), power=-1) + assert result == {} + + +def test_scorer_preconditioners(tmp_path: Path): + """Test that Scorer applies preconditioners to index grads.""" + + modules = ["mod_a"] + query_grads = {"mod_a": torch.randn(1, 4)} + + # Save a processor with a non-identity preconditioner + proc = GradientProcessor(preconditioners={"mod_a": torch.eye(4) * 2.0}) + precond_path = tmp_path / 
"preconditioner" + proc.save(precond_path) + + writer = MemmapSequenceScoreWriter( + tmp_path / "scores_with", 2, 1, dtype=torch.float32 + ) + scorer = Scorer( + query_grads=query_grads, + modules=modules, + writer=writer, + device=torch.device("cpu"), + dtype=torch.float32, + preconditioner_path=str(precond_path), + ) + + # Score with preconditioners + mod_grads = {"mod_a": torch.randn(2, 4)} + scores_with = scorer.score(mod_grads) + + # Score without preconditioners + writer_no = MemmapSequenceScoreWriter( + tmp_path / "scores_without", 2, 1, dtype=torch.float32 + ) + scorer_no_precond = Scorer( + query_grads=query_grads, + modules=modules, + writer=writer_no, + device=torch.device("cpu"), + dtype=torch.float32, + ) + scores_without = scorer_no_precond.score(mod_grads) + + # Preconditioner is 2*I, so scores should differ + assert not torch.allclose(scores_with, scores_without) + + +def test_scorer_split_preconditioners(tmp_path: Path): + """Split preconditioning applies H^(-1/2) to both query and index grads, + then unit normalizes.""" + torch.manual_seed(0) + modules = ["mod_a"] + query_grads = {"mod_a": torch.randn(1, 4)} + index_grads = {"mod_a": torch.randn(2, 4)} + + # Save a processor with H = 2*I + proc = GradientProcessor(preconditioners={"mod_a": torch.eye(4) * 2.0}) + precond_path = tmp_path / "preconditioner" + proc.save(precond_path) + + # Score with split preconditioning (unit_normalize=True) + scorer_precond_norm = Scorer( + query_grads=query_grads, + modules=modules, + writer=InMemorySequenceScoreWriter(2, 1, dtype=torch.float32), + device=torch.device("cpu"), + dtype=torch.float32, + unit_normalize=True, + preconditioner_path=str(precond_path), + ) + scores_precond_norm = scorer_precond_norm.score(index_grads) + + # Score with unit_normalize=True but no preconditioner + scorer_norm = Scorer( + query_grads=query_grads, + modules=modules, + writer=InMemorySequenceScoreWriter(2, 1, dtype=torch.float32), + device=torch.device("cpu"), + 
dtype=torch.float32, + unit_normalize=True, + ) + scores_norm = scorer_norm.score(index_grads) + + # Score with preconditioner but unit_normalize=False (one-sided) + scorer_inner_products = Scorer( + query_grads=query_grads, + modules=modules, + writer=InMemorySequenceScoreWriter(2, 1, dtype=torch.float32), + device=torch.device("cpu"), + dtype=torch.float32, + unit_normalize=False, + preconditioner_path=str(precond_path), + ) + scores_inner_products = scorer_inner_products.score(index_grads) + + # Split preconditioning should differ from both: + # - unit_normalize without preconditioner (preconditioner changes the space) + # - one-sided preconditioning (different power and normalization) + assert not torch.allclose(scores_precond_norm, scores_norm) + assert not torch.allclose(scores_precond_norm, scores_inner_products) + + # Verify split math: H^(-1/2) applied to both sides + unit normalize + mod_idx = scorer_precond_norm.modules.index("mod_a") + assert scorer_precond_norm.preconditioners is not None + h_inv_sqrt = scorer_precond_norm.preconditioners[mod_idx] + q = query_grads["mod_a"] @ h_inv_sqrt # preconditioned query + g = index_grads["mod_a"] @ h_inv_sqrt # preconditioned index + g = g / g.norm(dim=1, keepdim=True) # unit normalize + expected = g @ q.T + assert torch.allclose(scores_precond_norm, expected, atol=1e-6)