From 97fe18f12ecb6c90fe75dd1075be72b6b600205b Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Fri, 6 Mar 2026 18:08:23 +1100 Subject: [PATCH 01/17] Add MAGIC CLI, runtime DTensor double-backward patch and per-token weight support - Add bergson/magic_patch.py: runtime monkey-patch for twice-differentiable DTensor redistribution (pytorch/pytorch#160509), replacing the old magic_wmdp_setup.sh that modified torch source files on disk - Add per_token mode to DataStream for [n_examples, max_length] weight tensors - Support 2D [B, T] per-token weights in weighted_causal_lm_ce - Fix backward weight_grads accumulation when autograd returns None --- bergson/__main__.py | 18 +- bergson/double_backward.py | 314 +++++++++++++++++++++++++++ bergson/magic_patch.py | 202 +++++++++++++++++ bergson/trainer.py | 13 +- bergson/utils/math.py | 13 +- docs/index.rst | 1 + docs/magic.rst | 83 +++++++ examples/double_backward.py | 209 ------------------ examples/double_backward_pretrain.py | 50 +++++ 9 files changed, 687 insertions(+), 216 deletions(-) create mode 100644 bergson/double_backward.py create mode 100644 bergson/magic_patch.py create mode 100644 docs/magic.rst delete mode 100644 examples/double_backward.py create mode 100644 examples/double_backward_pretrain.py diff --git a/bergson/__main__.py b/bergson/__main__.py index 79ec970a..b02ac2a9 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -8,6 +8,7 @@ from .build import build from .config import ( + DistributedConfig, HessianConfig, IndexConfig, QueryConfig, @@ -15,6 +16,7 @@ ScoreConfig, TrackstarConfig, ) +from .double_backward import DoubleBackwardConfig, double_backward from .hessians.hessian_approximations import approximate_hessians from .query.query_index import query from .reduce import reduce @@ -194,11 +196,25 @@ def execute(self): score_dataset(score_index_cfg, self.score_cfg) +@dataclass +class Magic: + """Run MAGIC attribution.""" + + run_cfg: DoubleBackwardConfig + dist_cfg: DistributedConfig + + def execute(self): + """Run MAGIC attribution.""" + double_backward(self.run_cfg, self.dist_cfg) + + @dataclass class Main: """Routes to the subcommands.""" - command: Union[Build, Query, Preconditioners, Reduce, Score, Hessian, Trackstar] + command: Union[ + Build, Query, Preconditioners, Reduce, Score, Hessian, Trackstar, Magic + ] def execute(self): """Run the script.""" diff --git a/bergson/double_backward.py b/bergson/double_backward.py new file mode 100644 index 00000000..e4d4dd94 --- /dev/null +++ b/bergson/double_backward.py @@ -0,0 +1,314 @@ +import os +from dataclasses import dataclass +from datetime import timedelta +from typing import Literal + +import torch +import torch.distributed as dist +import torchopt +from scipy.stats import spearmanr +from simple_parsing import ArgumentParser, field +from torch.distributed.tensor import init_device_mesh +from torchopt.pytree import tree_iter +from torchopt.typing import Numeric +from transformers import AutoModelForCausalLM, AutoTokenizer + +from bergson.config import DataConfig, DistributedConfig +from bergson.data import load_data_string +from bergson.distributed import grad_tree, launch_distributed_run, simple_fsdp +from bergson.magic_patch import apply_dtensor_patch +from bergson.trainer import BackwardState, DataStream, Trainer, TrainerState +from bergson.utils.math import weighted_causal_lm_ce + + +@dataclass +class DoubleBackwardConfig: + run_path: str = field(positional=True) + """Directory to save checkpoints and results.""" + + model: str = "EleutherAI/pythia-160m" + """HuggingFace model name.""" + + revision: str | None = None + """Model revision (branch, tag, or commit hash).""" + + data: DataConfig = field(default_factory=DataConfig) + """Training dataset.""" + + query: DataConfig = field(default_factory=lambda: DataConfig()) + """Query/eval dataset for computing attribution target gradients. + If not specified, defaults to the training dataset.""" + + query_method: Literal["mean", "sum"] = "mean" + """Method for reducing query gradients across batches.""" + + query_batches: int = 1 + """Number of query batches to use for computing eval gradients.""" + + fsdp: bool = False + """Whether to use FSDP for multi-GPU training.""" + + grad_checkpointing: bool = False + """Whether to use gradient checkpointing during the forward pass.""" + + lr: float = 1e-5 + """Base learning rate after warmup.""" + + warmup_steps: int = 10 + """Number of warmup steps before applying base lr.""" + + batch_size: int = 8 + """Per-device batch size.""" + + num_batches: int = 25 + """Number of training batches.""" + + max_length: int = 256 + """Maximum token sequence length.""" + + num_subsets: int = 100 + """Number of leave-one-out subsets for Spearman correlation.""" + + seed: int = 42 + """Random seed for subset permutation.""" + + +def compute_query_gradients( + trainer: Trainer, + fwd_state: TrainerState, + model: torch.nn.Module, + query_stream: DataStream, + method: str = "mean", +) -> dict[str, torch.Tensor]: + """Compute reduced query gradients over the query dataset. + + Iterates over the query stream, computing per-batch parameter gradients + and reducing them (mean or sum) into a single gradient dict. + """ + grad_accum: dict[str, torch.Tensor] | None = None + n_batches = 0 + + with fwd_state.activate(model) as params: + for batch in query_stream: + del batch["example_weight"] + loss = model(**batch).loss + grads = grad_tree(loss, params) + + if grad_accum is None: + grad_accum = {k: g.detach().clone() for k, g in grads.items()} + else: + for k, g in grads.items(): + grad_accum[k] += g.detach() + n_batches += 1 + + assert grad_accum is not None, "Query stream was empty" + + if method == "mean": + for k in grad_accum: + grad_accum[k] /= n_batches + + return grad_accum + + +def worker( + global_rank: int, + rank: int, + world_size: int, + train_dataset, + query_dataset, + run_cfg: DoubleBackwardConfig, +): + torch.cuda.set_device(rank) + + model = AutoModelForCausalLM.from_pretrained( + run_cfg.model, + revision=run_cfg.revision, + torch_dtype=torch.float32, + attn_implementation="eager", + ) + model.loss_function = weighted_causal_lm_ce + model.to(f"cuda:{rank}") + if run_cfg.grad_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=dict(use_reentrant=False), + ) + + processor = AutoTokenizer.from_pretrained(run_cfg.model) + processor.pad_token = processor.eos_token + + if world_size > 1: + addr = os.environ.get("MASTER_ADDR", "localhost") + port = os.environ.get("MASTER_PORT", "29500") + + dist.init_process_group( + "cpu:gloo,cuda:nccl", + init_method=f"tcp://{addr}:{port}", + device_id=torch.device(f"cuda:{rank}"), + rank=rank, + timeout=timedelta(hours=1), + world_size=world_size, + ) + + if run_cfg.fsdp and world_size > 1: + apply_dtensor_patch() + mesh = init_device_mesh("cuda", (world_size,)) + with mesh: + model = simple_fsdp(model) + + def schedule(step: Numeric) -> Numeric: + if step < run_cfg.warmup_steps: + return 0.0 + return run_cfg.lr + + opt = torchopt.adamw( + schedule, + betas=(0.95, 0.975), + eps_root=1e-8, + ) + trainer, fwd_state = Trainer.initialize(model, opt) + + ckpts_path = os.path.join(run_cfg.run_path, "checkpoints") + path0 = os.path.join(ckpts_path, "state0.pt") + save_fut = fwd_state.save(path0) + + stream = DataStream( + train_dataset, + processor, + batch_size=run_cfg.batch_size, + num_batches=run_cfg.num_batches, + device=f"cuda:{rank}", + max_length=run_cfg.max_length, + input_key=run_cfg.data.prompt_column, + ) + fwd_state = trainer.train( + fwd_state, + stream, + inplace=True, + save_dir=ckpts_path, + ) + + # Compute query gradients + query_stream = DataStream( + query_dataset, + processor, + batch_size=run_cfg.batch_size, + num_batches=run_cfg.query_batches, + device=f"cuda:{rank}", + max_length=run_cfg.max_length, + input_key=run_cfg.query.prompt_column, + ) + + query_grads = compute_query_gradients( + trainer, fwd_state, model, query_stream, run_cfg.query_method + ) + + if world_size > 1: + reduce_op = dist.ReduceOp.AVG if run_cfg.query_method == "mean" else dist.ReduceOp.SUM + for v in query_grads.values(): + dist.all_reduce(v, op=reduce_op) + + stream.requires_grad = True + opt_grads = [ + torch.zeros_like(buf) + for buf in tree_iter(fwd_state.opt_state) + if isinstance(buf, torch.Tensor) and buf.is_floating_point() + ] + bwd_state = BackwardState(query_grads, opt_grads, torch.zeros_like(stream.weights)) + + # Compute baseline eval loss for validation + with fwd_state.activate(model): + baseline_batch = query_stream[0] + del baseline_batch["example_weight"] + baseline_loss = model(**baseline_batch).loss + + if world_size > 1: + dist.all_reduce(baseline_loss, op=dist.ReduceOp.AVG) + + bwd_state = trainer.backward( + ckpts_path, + stream, + bwd_state, + fwd_state, + inplace=True, + ) + if world_size > 1: + dist.all_reduce(bwd_state.weight_grads, op=dist.ReduceOp.AVG) + + baseline = baseline_loss.item() + if global_rank == 0: + print(f"Scores: {bwd_state.weight_grads.tolist()}") + print(f"Baseline: {baseline}") + print(f"Grad sum: {bwd_state.weight_grads.sum()}") + + stream.requires_grad = False + + # Validate attribution scores via leave-subset-out retraining + diffs = [] + score_sums = [] + + gen = torch.Generator().manual_seed(run_cfg.seed) + perm = torch.randperm(len(stream.weights), generator=gen) + subsets = perm.chunk(run_cfg.num_subsets) + + save_fut.result() # ensure state0 is saved before loading in loop + fwd_state.load(path0) + + for subset in subsets: + stream.weights.fill_(1.0) + stream.weights[subset] = 0.0 + + for x in stream: + fwd_state = trainer.step(fwd_state, x) + + with fwd_state.activate(model): + eval_batch = query_stream[0] + del eval_batch["example_weight"] + loss = model(**eval_batch).loss + + if world_size > 1: + dist.all_reduce(loss, op=dist.ReduceOp.AVG) + + diffs.append(baseline - loss.item()) + score_sums.append(bwd_state.weight_grads[subset].sum().item()) + + corr = spearmanr(diffs, score_sums) + if global_rank == 0: + print(f"Loss diff: {diffs[-1]}") + print(f"Score: {score_sums[-1]}") + print(f"Spearman correlation: {corr}") + + +def double_backward(run_cfg: DoubleBackwardConfig, dist_cfg: DistributedConfig): + train_ds = load_data_string( + run_cfg.data.dataset, + run_cfg.data.split, + run_cfg.data.subset, + run_cfg.data.data_args, + ) + + query_ds = load_data_string( + run_cfg.query.dataset, + run_cfg.query.split, + run_cfg.query.subset, + run_cfg.query.data_args, + ) + + launch_distributed_run( + "double_backward", worker, [train_ds, query_ds, run_cfg], dist_cfg + ) + + +def main(): + parser = ArgumentParser() + parser.add_arguments(DoubleBackwardConfig, dest="run_cfg") + parser.add_arguments(DistributedConfig, dest="dist_cfg") + args = parser.parse_args() + + run_cfg: DoubleBackwardConfig = args.run_cfg + dist_cfg: DistributedConfig = args.dist_cfg + + double_backward(run_cfg, dist_cfg) + + +if __name__ == "__main__": + main() diff --git a/bergson/magic_patch.py b/bergson/magic_patch.py new file mode 100644 index 00000000..b22b9295 --- /dev/null +++ b/bergson/magic_patch.py @@ -0,0 +1,202 @@ +"""Runtime monkey-patch for twice-differentiable DTensor redistribution. + +Implements pytorch/pytorch#160509 at runtime, avoiding the need to modify +torch source files on disk. Call `apply_dtensor_patch()` before any DTensor +operations that require double backward (e.g. MAGIC attribution with FSDP). + +Safe to call multiple times (idempotent). +""" + +import torch + +_PATCHED = False + + +def apply_dtensor_patch(): + """Patch DTensor redistribution to support double backward. + + Monkey-patches `Redistribute.backward` and `_ToTorchTensor.backward` + in the installed torch package so that FSDP redistribution is + twice-differentiable. + """ + global _PATCHED + if _PATCHED: + return + + _patch_redistribute() + _patch_to_torch_tensor() + _PATCHED = True + + +def _patch_redistribute(): + import torch.distributed.tensor._api as dtensor + from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta + from torch.distributed.tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, + ) + from torch.distributed.tensor.placement_types import Replicate + + def _redistribute_backward( + grad_output, + previous_spec, + original_dtype=None, + backward_dtype=None, + async_op=False, + ): + if backward_dtype != grad_output._local_tensor.dtype: + local_tensor = grad_output._local_tensor.to(dtype=backward_dtype) + current_spec = DTensorSpec( + mesh=grad_output._spec.device_mesh, + placements=grad_output._spec.placements, + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=backward_dtype, + ), + ) + previous_spec = DTensorSpec( + mesh=previous_spec.device_mesh, + placements=previous_spec.placements, + tensor_meta=current_spec.tensor_meta, + ) + else: + local_tensor = grad_output._local_tensor + current_spec = grad_output._spec + + normalized_placements = [] + for current, target in zip(current_spec.placements, previous_spec.placements): + if (current.is_shard() or current.is_replicate()) and target.is_partial(): + normalized_placements.append(Replicate()) + else: + normalized_placements.append(target) + + previous_spec = DTensorSpec( + previous_spec.device_mesh, + placements=tuple(normalized_placements), + tensor_meta=previous_spec.tensor_meta, + ) + + output = redistribute_local_tensor( + local_tensor, + current_spec, + previous_spec, + async_op=async_op, + is_backward=True, + ) + + if output.dtype != original_dtype: + output = output.to(original_dtype) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=output.dtype, + ), + ) + return output, spec + + class NestedRedistribute(torch.autograd.Function): + """Makes DTensor redistribution twice-differentiable. + + Called during Redistribute.backward (first backward pass). + NestedRedistribute.backward handles the second backward pass. + """ + + @staticmethod + def forward( + ctx, + grad_output, + previous_spec, + async_op=False, + backward_dtype=None, + original_dtype=None, + ): + ctx.async_op = async_op + ctx.backward_dtype = backward_dtype or original_dtype + ctx.original_dtype = grad_output._local_tensor.dtype + + output, spec = _redistribute_backward( + grad_output, + previous_spec, + ctx.original_dtype, + backward_dtype, + async_op, + ) + + ctx.current_spec = spec + + return dtensor.DTensor( + output, + spec, + requires_grad=grad_output.requires_grad, + ) + + @staticmethod + def backward(ctx, grad2_output): + output_dtensor = NestedRedistribute.apply( + grad2_output, + ctx.current_spec, + ctx.async_op, + ctx.backward_dtype, + ctx.original_dtype, + ) + + return (output_dtensor, None, None, None, None) + + @staticmethod + def _new_redistribute_backward(ctx, grad_output): + previous_spec = ctx.current_spec + output_dtensor = NestedRedistribute.apply( + grad_output, + previous_spec, + ctx.async_op, + ctx.backward_dtype, + ctx.original_dtype, + ) + return (output_dtensor, None, None, None, None, None) + + Redistribute.backward = _new_redistribute_backward + + +def _patch_to_torch_tensor(): + import torch.distributed.tensor._api as dtensor_api + from torch.distributed.tensor._api import _ToTorchTensor + from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta + from torch.distributed.tensor._utils import compute_global_tensor_info + + @staticmethod + def _new_backward(ctx, grad_output): + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + + return ( + dtensor_api.DTensor.from_local( + grad_output, + grad_spec.device_mesh, + grad_spec.placements, + ), + None, + ) + + _ToTorchTensor.backward = _new_backward diff --git a/bergson/trainer.py b/bergson/trainer.py index af38f465..77ab11cc 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -60,11 +60,13 @@ def __init__( device: torch.device | str = "cpu", input_key: str = "text", max_length: int = 256, + per_token: bool = False, ): self.dataset = dataset self.processor = processor self.input_key = input_key self.max_length = max_length + self.per_token = per_token self.batch_size = batch_size self.device = device @@ -79,7 +81,8 @@ def __init__( ) n = self.batch_size * self.num_batches - self.weights = nn.Parameter(torch.ones(n, device=device)) + shape = (n, max_length) if per_token else (n,) + self.weights = nn.Parameter(torch.ones(*shape, device=device)) @property def requires_grad(self) -> bool: @@ -114,10 +117,14 @@ def __getitem__(self, i: int) -> dict: truncation=True, ) x["labels"] = x["input_ids"] - x["example_weight"] = self.weights[ + w = self.weights[ i * self.batch_size + self.rank : (i + 1) * self.batch_size : self.world_size ] + if self.per_token and w.ndim == 2: + T_batch = x["input_ids"].shape[1] + w = w[:, :T_batch] + x["example_weight"] = w return {k: v.to(self.device) for k, v in x.items()} def __iter__(self): @@ -448,7 +455,7 @@ def backward( param_grads = {k: result[i] for i, k in enumerate(p_keys)} del result[: len(p_keys)] - weight_grads = result[-1] + w_grads + weight_grads = result[-1] + w_grads if result[-1] is not None else w_grads bwd_state = BackwardState(param_grads, result[:-1], weight_grads) for fut in save_futures: diff --git a/bergson/utils/math.py b/bergson/utils/math.py index 0abeffed..622006a0 100644 --- a/bergson/utils/math.py +++ b/bergson/utils/math.py @@ -19,7 +19,7 @@ def weighted_causal_lm_ce( Args: logits : [B, T, V] float tensor of prediction scores labels : [B, T] long tensor of target token ids, or ignore_index - example_weight : [B] float tensor of per-example weights + example_weight : [B] or [B, T] float tensor of weights ignore_index : int, label value to ignore in loss computation vocab_size : optional int, vocabulary size (for validation) """ @@ -27,7 +27,10 @@ def weighted_causal_lm_ce( B, T, V = logits.shape assert labels.shape == (B, T) if example_weight is not None: - assert example_weight.shape == (B,) + assert example_weight.ndim in (1, 2), ( + f"example_weight must be 1D [B] or 2D [B, T], " + f"got shape {example_weight.shape}" + ) # HuggingFace always passes a vocab_size kwarg if vocab_size is not None: @@ -51,7 +54,11 @@ def weighted_causal_lm_ce( if example_weight is None: return tok_loss.mean() - w = example_weight.to(tok_loss.dtype).view(B, 1) # [B,1] + if example_weight.ndim == 1: + w = example_weight.to(tok_loss.dtype).view(B, 1) # [B,1] + else: + # Per-token weights: shift to align with shifted labels + w = example_weight[:, 1:].to(tok_loss.dtype) # [B, T-1] return (tok_loss * w).mean() diff --git a/docs/index.rst b/docs/index.rst index c3b833be..29e2a41a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -68,6 +68,7 @@ Experiments :maxdepth: 2 experiments + magic Content Index diff --git a/docs/magic.rst b/docs/magic.rst new file mode 100644 index 00000000..6f48c1cb --- /dev/null +++ b/docs/magic.rst @@ -0,0 +1,83 @@ +MAGIC Attribution +================= + +`MAGIC `_ (Model-Agnostic Generation-time Influence via Checkpointing) attributes evaluation loss to individual training examples by backpropagating through the entire training process. Unlike influence functions which use a local approximation, MAGIC computes exact counterfactual attribution by differentiating through checkpointed training steps. + +We provide a `Trainer` class that takes differentiable training steps and handles all three phases of MAGIC attribution. We support FSDP training using the `bergson.magic_patch` runtime patch, which makes PyTorch's DTensor redistribution twice-differentiable (`pytorch/pytorch#160509 `_). The patch is applied in memory, so no torch source files are modified. + +How it works +------------ + +MAGIC attribution has three phases: + +1. **Forward training with checkpoints**: Fine-tune the model, saving intermediate checkpoints at each step. +2. **Evaluate**: Compute the evaluation loss and its gradients with respect to the final model parameters. +3. **Backward through training**: Backpropagate the evaluation gradients through the checkpointed training steps using reverse-mode autodiff, accumulating attribution scores for each training example (or token). + +The ``Trainer`` class handles all three phases. It uses `torchopt `_ for functional (stateless) differentiable optimization. + +Usage +----- + +bergson magic runs/magic-ckpts --dataset NeelNanda/pile-10k --model EleutherAI/pythia-14m + +Core components +^^^^^^^^^^^^^^^ + +**Trainer**: Functional trainer that supports forward training with checkpoints and backward-through-training. + +.. code-block:: python + + from bergson.trainer import Trainer, DataStream, BackwardState, TrainerState + import torchopt + + # Initialize + opt = torchopt.adam(lr=1e-4) + trainer, state = Trainer.initialize(model, opt) + + # Forward training with checkpoints + stream = DataStream(dataset, tokenizer, batch_size=4, num_batches=250, device="cuda") + state = trainer.train(state, stream, save_dir="checkpoints/") + + # Compute eval gradients, then backward through training + bwd_state = trainer.backward("checkpoints/", stream, bwd_state) + scores = bwd_state.weight_grads # attribution scores + +**DataStream**: Wraps a dataset with differentiable per-example (or per-token) weights that receive gradients during the backward pass. + +.. code-block:: python + + # Per-example attribution + stream = DataStream(dataset, tokenizer, batch_size=4, num_batches=250, device="cuda") + + # Per-token attribution + stream = DataStream(dataset, tokenizer, batch_size=4, num_batches=250, device="cuda", per_token=True) + +**DTensor patch**: For multi-GPU runs with FSDP, apply the DTensor patch before any distributed operations: + +.. code-block:: python + + from bergson.magic_patch import apply_dtensor_patch + apply_dtensor_patch() + + # Your MAGIC worker call here + +Per-token vs per-example attribution +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, ``DataStream`` creates a 1D weight tensor ``[n_examples]`` for per-example attribution. With ``per_token=True``, it creates a 2D tensor ``[n_examples, max_length]`` so that each token receives its own attribution score. The ``weighted_causal_lm_ce`` loss function supports both shapes. + +To use per-token attribution, set ``model.loss_function = weighted_causal_lm_ce`` so the model uses the weighted loss during training. + +.. code-block:: python + + from bergson.utils.math import weighted_causal_lm_ce + model.loss_function = weighted_causal_lm_ce + +Key implementation details +-------------------------- + +- **Functional optimization**: ``torchopt.adam`` (or similar) provides a pure-function optimizer whose state is a pytree of tensors. This allows ``torch.autograd.grad`` to differentiate through optimizer updates. +- **Checkpoint strategy**: By default, checkpoints are saved at ``sqrt(N)`` intervals, giving ``O(sqrt(N))`` memory and ``O(N * sqrt(N))`` recomputation cost. +- **FSDP compatibility**: The DTensor runtime patch adds a ``NestedRedistribute`` autograd function that makes the FSDP all-gather/reduce-scatter differentiable through second-order backward passes. +- **Loss weighting**: ``weighted_causal_lm_ce`` multiplies per-token cross-entropy by the DataStream weights before averaging. During backward-through-training, autograd accumulates gradients into these weights, yielding the attribution scores. diff --git a/examples/double_backward.py b/examples/double_backward.py deleted file mode 100644 index 3f079ba8..00000000 --- a/examples/double_backward.py +++ /dev/null @@ -1,209 +0,0 @@ -import os -from dataclasses import dataclass -from datetime import timedelta - -import torch -import torch.distributed as dist -import torchopt -from datasets import load_dataset -from scipy.stats import spearmanr -from simple_parsing import ArgumentParser -from torch.distributed.tensor import init_device_mesh -from torchopt.pytree import tree_iter -from torchopt.typing import Numeric -from transformers import AutoTokenizer, GPTNeoXConfig, GPTNeoXForCausalLM - -from bergson.config import DistributedConfig -from bergson.distributed import grad_tree, launch_distributed_run, simple_fsdp -from bergson.trainer import BackwardState, DataStream, Trainer -from bergson.utils.math import weighted_causal_lm_ce - - -@dataclass -class RunConfig: - model_name: str = "EleutherAI/pythia-160m" - """HuggingFace model name.""" - - dataset_name: str = "EleutherAI/SmolLM2-135M-10B" - """HuggingFace dataset name.""" - - dataset_split: str = "train" - """Dataset split to use.""" - - grad_checkpointing: bool = False - """Whether to use gradient checkpointing during the forward pass.""" - - lr: float = 1e-5 - """Base learning rate after warmup.""" - - warmup_steps: int = 10 - """Number of warmup steps before applying base lr.""" - - batch_size: int = 8 - """Per-device batch size.""" - - num_batches: int = 25 - """Number of training batches.""" - - max_length: int = 256 - """Maximum token sequence length.""" - - save_dir: str = "/mnt/ssd-3/nora/magic-ckpts" - """Directory to save forward pass checkpoints.""" - - num_subsets: int = 100 - """Number of leave-one-out subsets for Spearman correlation.""" - - seed: int = 42 - """Random seed for subset permutation.""" - - -def worker(global_rank: int, rank: int, world_size: int, dataset, run_cfg: RunConfig): - torch.cuda.set_device(rank) - - cfg = GPTNeoXConfig.from_pretrained(run_cfg.model_name, revision="step0") - model = GPTNeoXForCausalLM(cfg) - model.set_attn_implementation("eager") - model.loss_function = weighted_causal_lm_ce - model.to(f"cuda:{rank}") - if run_cfg.grad_checkpointing: - model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs=dict(use_reentrant=False), - ) - - processor = AutoTokenizer.from_pretrained(run_cfg.model_name) - processor.pad_token = processor.eos_token - - if world_size > 1: - addr = os.environ.get("MASTER_ADDR", "localhost") - port = os.environ.get("MASTER_PORT", "29500") - - dist.init_process_group( - "cpu:gloo,cuda:nccl", - init_method=f"tcp://{addr}:{port}", - device_id=torch.device(f"cuda:{rank}"), - rank=rank, - timeout=timedelta(hours=1), - world_size=world_size, - ) - mesh = init_device_mesh("cuda", (world_size,)) - with mesh: - model = simple_fsdp(model) - - def schedule(step: Numeric) -> Numeric: - if step < run_cfg.warmup_steps: - return 0.0 - return run_cfg.lr - - opt = torchopt.adamw( - schedule, - betas=(0.95, 0.975), - eps_root=1e-8, - ) - trainer, fwd_state = Trainer.initialize(model, opt) - - # save state0 - path0 = os.path.join(run_cfg.save_dir, "state0.pt") - save_fut = fwd_state.save(path0) - - stream = DataStream( - dataset, - processor, - batch_size=run_cfg.batch_size, - num_batches=run_cfg.num_batches, - device=f"cuda:{rank}", - max_length=run_cfg.max_length, - ) - fwd_state = trainer.train( - fwd_state, - stream, - inplace=True, - save_dir=run_cfg.save_dir, - ) - - with fwd_state.activate(model) as params: - stream.requires_grad = True - - ex = stream[0] - del ex["example_weight"] - - loss = model(**ex).loss - - grads = grad_tree(loss, params, create_graph=True) - opt_grads = [ - torch.zeros_like(buf) - for buf in tree_iter(fwd_state.opt_state) - if isinstance(buf, torch.Tensor) and buf.is_floating_point() - ] - bwd_state = BackwardState(grads, opt_grads, torch.zeros_like(stream.weights)) - - if world_size > 1: - dist.all_reduce(loss, op=dist.ReduceOp.AVG) - - bwd_state = trainer.backward( - run_cfg.save_dir, - stream, - bwd_state, - fwd_state, - inplace=True, - ) - if world_size > 1: - dist.all_reduce(bwd_state.weight_grads, op=dist.ReduceOp.AVG) - if global_rank == 0: - print(f"Scores 2: {bwd_state.weight_grads.tolist()}") - - baseline = loss.item() - if global_rank == 0: - print(f"Baseline: {baseline}") - print("Grad:", bwd_state.weight_grads.sum()) - - stream.requires_grad = False - - diffs = [] - score_sums = [] - - gen = torch.Generator().manual_seed(run_cfg.seed) - perm = torch.randperm(len(stream.weights), generator=gen) - subsets = perm.chunk(run_cfg.num_subsets) - - save_fut.result() # ensure state0 is saved before loading in loop - fwd_state.load(path0) - - for subset in subsets: - stream.weights.fill_(1.0) - stream.weights[subset] = 0.0 - - for x in stream: - fwd_state = trainer.step(fwd_state, x) - - with fwd_state.activate(model): - loss = model(**stream[0]).loss - - if world_size > 1: - dist.all_reduce(loss, op=dist.ReduceOp.AVG) - - diffs.append(baseline - loss.item()) - score_sums.append(bwd_state.weight_grads[subset].sum().item()) - - corr = spearmanr(diffs, score_sums) - if global_rank == 0: - print(f"Loss diff: {diffs[-1]}") - print(f"Score: {score_sums[-1]}") - print(f"Spearman correlation: {corr}") - - -def main(): - parser = ArgumentParser() - parser.add_arguments(RunConfig, dest="run_cfg") - parser.add_arguments(DistributedConfig, dest="dist_cfg") - args = parser.parse_args() - - run_cfg: RunConfig = args.run_cfg - dist_cfg: DistributedConfig = args.dist_cfg - - ds = load_dataset(run_cfg.dataset_name, split=run_cfg.dataset_split) - launch_distributed_run("double_backward", worker, [ds, run_cfg], dist_cfg) - - -if __name__ == "__main__": - main() diff --git a/examples/double_backward_pretrain.py b/examples/double_backward_pretrain.py new file mode 100644 index 00000000..6b4f1999 --- /dev/null +++ b/examples/double_backward_pretrain.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +"""MAGIC attribution on a pretrained model. + +Trains from a random init checkpoint (Pythia step0), attributes eval loss +to training examples, and validates via leave-subset-out retraining. + +Usage: + python examples/magic_pretrain.py runs/magic_pretrain + + # Or via the CLI: + bergson magic runs/magic_pretrain \ + --model EleutherAI/pythia-160m \ + --revision step0 \ + --data.dataset EleutherAI/SmolLM2-135M-10B \ + --query.dataset EleutherAI/SmolLM2-135M-10B \ + --query.split "train[:1]" +""" + +from bergson.config import DataConfig, DistributedConfig +from bergson.double_backward import DoubleBackwardConfig, double_backward + + +def main(): + run_cfg = DoubleBackwardConfig( + run_path="runs/magic_pretrain", + model="EleutherAI/pythia-160m", + revision="step0", + data=DataConfig( + dataset="EleutherAI/SmolLM2-135M-10B", + split="train", + ), + query=DataConfig( + dataset="EleutherAI/SmolLM2-135M-10B", + split="train[:1]", + ), + query_batches=1, + lr=1e-5, + warmup_steps=10, + batch_size=8, + num_batches=25, + max_length=256, + num_subsets=100, + seed=42, + ) + dist_cfg = DistributedConfig() + double_backward(run_cfg, dist_cfg) + + +if __name__ == "__main__": + main() From 9e3a62a91d914d9c4aa1ad4e23d889ec1f7655cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Mar 2026 00:29:40 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- bergson/double_backward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bergson/double_backward.py b/bergson/double_backward.py index e4d4dd94..523e1bb1 100644 --- a/bergson/double_backward.py +++ b/bergson/double_backward.py @@ -203,7 +203,9 @@ def schedule(step: Numeric) -> Numeric: ) if world_size > 1: - reduce_op = dist.ReduceOp.AVG if run_cfg.query_method == "mean" else dist.ReduceOp.SUM + reduce_op = ( + dist.ReduceOp.AVG if run_cfg.query_method == "mean" else dist.ReduceOp.SUM + ) for v in query_grads.values(): dist.all_reduce(v, op=reduce_op) From dca41d94633b2cc56d6d2c3407b6cd7278960723 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 12:03:30 +1100 Subject: [PATCH 03/17] Drop no_dist from dcp.async_save and dcp.load for torch 2.8 compat Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bergson/trainer.py b/bergson/trainer.py index 77ab11cc..ee6daacc 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -167,7 +167,6 @@ def load(self, path: str): dcp.load( self.state_dict(), checkpoint_id=path, - no_dist=not dist.is_initialized(), ) def save(self, path: str) -> Future: @@ -181,7 +180,6 @@ def save(self, path: str) -> Future: fut = dcp.async_save( self.state_dict(), checkpoint_id=path, - no_dist=grp is None, process_group=grp, ) assert isinstance(fut, Future) From b83d3f6d0ca48a29ade7c519dd253bbb39d13350 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 12:04:35 +1100 Subject: [PATCH 04/17] Revert: restore no_dist in dcp.async_save and dcp.load Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bergson/trainer.py b/bergson/trainer.py index ee6daacc..77ab11cc 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -167,6 +167,7 @@ def load(self, path: str): dcp.load( self.state_dict(), checkpoint_id=path, + no_dist=not dist.is_initialized(), ) def save(self, path: str) -> Future: @@ -180,6 +181,7 @@ def save(self, path: str) -> Future: fut = dcp.async_save( self.state_dict(), checkpoint_id=path, + no_dist=grp is None, process_group=grp, ) assert isinstance(fut, Future) From 491bf063e87bea9447f3ac32120992428505ecaa Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 12:11:35 +1100 Subject: [PATCH 05/17] Remove unnecessary None guard on weight_grads accumulation The weight gradient from autograd.grad should always be a tensor since data.weights participates in the computation graph via weighted_causal_lm_ce. Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bergson/trainer.py b/bergson/trainer.py index 77ab11cc..68ceda26 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -455,7 +455,7 @@ def backward( param_grads = {k: result[i] for i, k in enumerate(p_keys)} del result[: len(p_keys)] - weight_grads = result[-1] + w_grads if result[-1] is not None else w_grads + weight_grads = result[-1] + w_grads bwd_state = BackwardState(param_grads, result[:-1], weight_grads) for fut in save_futures: From da4414f8bfa9324288590fa662b14d0ebac9d00a Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 16:33:34 +1100 Subject: [PATCH 06/17] Serialize DCP async saves in trainer.train() to fix post-training hang Multiple concurrent DCP async_save calls each create their own Gloo process group. With consecutive saves at steps 20-24 (last_start logic), up to 5 saves were in-flight simultaneously. Background threads from these saves may call distributed operations that conflict, causing all ranks to deadlock in fut.result() until the NCCL watchdog times out. Limit to one concurrent save at a time: wait for the previous save to complete before starting the next one. Each save still overlaps with at least one training step, so async I/O benefit is preserved. Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/bergson/trainer.py b/bergson/trainer.py index 68ceda26..fc9bdff8 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -342,7 +342,7 @@ def train( chunk_size = math.isqrt(len(data)) if save_mode == "sqrt" else 1 last_start = len(data) - chunk_size - save_futures: list[Future] = [] + pending_fut: Future | None = None main = not dist.is_initialized() or dist.get_rank() != 0 pbar = tqdm(data, desc="Training", disable=main) @@ -351,15 +351,19 @@ def train( # Save checkpoint BEFORE each step. Step 0 is the initial state prior to # any updates, step 1 is the state after the first update, etc. if save_dir and (i % chunk_size == 0 or i >= last_start): - p = os.path.join(save_dir, f"step_{i}.ckpt") + # Wait for the previous save before starting a new one to avoid + # multiple concurrent DCP saves with separate Gloo groups, which can + # deadlock when background threads call distributed operations. + if pending_fut is not None: + pending_fut.result() - fut = state.save(p) - save_futures.append(fut) + p = os.path.join(save_dir, f"step_{i}.ckpt") + pending_fut = state.save(p) state = self.step(state, x, inplace=inplace, trace=trace) - for fut in save_futures: - fut.result() # wait for all checkpoints to finish saving + if pending_fut is not None: + pending_fut.result() return state From 46ea739ade5d3c58bdf4c4c2749c734b507d24e2 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 16:40:08 +1100 Subject: [PATCH 07/17] Add upfront dataset size validation in DataStream Raises a clear ValueError at init time when the dataset doesn't have enough examples for the requested number of batches, instead of crashing with an IndexError mid-training. Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bergson/trainer.py b/bergson/trainer.py index fc9bdff8..f1f61496 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -80,6 +80,14 @@ def __init__( f"{self.world_size}" ) + needed = self.batch_size * self.num_batches + if len(self.dataset) < needed: + raise ValueError( + f"Dataset has {len(self.dataset)} examples but {self.num_batches} " + f"batches of size {self.batch_size} require {needed}. " + f"Pass a larger split or reduce --num_batches." + ) + n = self.batch_size * self.num_batches shape = (n, max_length) if per_token else (n,) self.weights = nn.Parameter(torch.ones(*shape, device=device)) From 4d2df7e01f18b7f08d782b63f29c579e14ac17ac Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 16:51:09 +1100 Subject: [PATCH 08/17] Remove destroy_process_group done callback that caused fut.result() hang PyTorch's Future.result() waits for done callbacks to complete before returning. The destroy_process_group callback was invoked from DCP's background thread after each save, but destroy_process_group may do a barrier on the Gloo group. Since ranks complete their I/O at different times, the fast rank would deadlock waiting for the slow rank to also call destroy_process_group, while the slow rank was still in fut.result(). DCP holds its own reference to the process group, keeping it alive for the duration of the background I/O. GC will clean it up afterwards. Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bergson/trainer.py b/bergson/trainer.py index f1f61496..b2d54cc1 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -193,10 +193,6 @@ def save(self, path: str) -> Future: process_group=grp, ) assert isinstance(fut, Future) - - fut.add_done_callback( - lambda _, g=grp: dist.destroy_process_group(g) if g else None - ) return fut def detach_(self): From d73cb3b4be8340716026d2b3741829dc085dae9d Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 16:52:22 +1100 Subject: [PATCH 09/17] DEBUG: add logging to trace async save hang --- bergson/trainer.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/bergson/trainer.py b/bergson/trainer.py index b2d54cc1..4b76f299 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -179,6 +179,9 @@ def load(self, path: str): ) def save(self, path: str) -> Future: + rank = dist.get_rank() if dist.is_initialized() else 0 + print(f"[rank {rank}] save: starting async_save for {path}", flush=True) + # Create a new process group so that we can overlap saves if dist.is_initialized(): grp = dist.new_group(backend="gloo", group_desc=path) @@ -186,6 +189,13 @@ def save(self, path: str) -> Future: else: grp = None + def _done_callback(fut, g=grp, p=path): + print(f"[rank {rank}] save: callback fired for {p}", flush=True) + if g is not None: + print(f"[rank {rank}] save: calling destroy_process_group for {p}", flush=True) + dist.destroy_process_group(g) + print(f"[rank {rank}] save: destroy_process_group done for {p}", flush=True) + fut = dcp.async_save( self.state_dict(), checkpoint_id=path, @@ -193,6 +203,7 @@ def save(self, path: str) -> Future: process_group=grp, ) assert isinstance(fut, Future) + fut.add_done_callback(_done_callback) return fut def detach_(self): @@ -359,7 +370,10 @@ def train( # multiple concurrent DCP saves with separate Gloo groups, which can # deadlock when background threads call distributed operations. if pending_fut is not None: + _rank = dist.get_rank() if dist.is_initialized() else 0 + print(f"[rank {_rank}] train: waiting for pending save before step {i}", flush=True) pending_fut.result() + print(f"[rank {_rank}] train: pending save done before step {i}", flush=True) p = os.path.join(save_dir, f"step_{i}.ckpt") pending_fut = state.save(p) @@ -367,7 +381,10 @@ def train( state = self.step(state, x, inplace=inplace, trace=trace) if pending_fut is not None: + _rank = dist.get_rank() if dist.is_initialized() else 0 + print(f"[rank {_rank}] train: waiting for final save", flush=True) pending_fut.result() + print(f"[rank {_rank}] train: final save done", flush=True) return state From b8eed22a019202057f78de779ecc103c922e1776 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Mar 2026 05:52:34 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- bergson/trainer.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/bergson/trainer.py b/bergson/trainer.py index 4b76f299..5a5a037b 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -192,9 +192,15 @@ def save(self, path: str) -> Future: def _done_callback(fut, g=grp, p=path): print(f"[rank {rank}] save: callback fired for {p}", flush=True) if g is not None: - print(f"[rank {rank}] save: calling destroy_process_group for {p}", flush=True) + print( + f"[rank {rank}] save: calling destroy_process_group for {p}", + flush=True, + ) dist.destroy_process_group(g) - print(f"[rank {rank}] save: destroy_process_group done for {p}", flush=True) + print( + f"[rank {rank}] save: destroy_process_group done for {p}", + flush=True, + ) fut = dcp.async_save( self.state_dict(), @@ -371,9 +377,15 @@ def train( # deadlock when background threads call distributed operations. if pending_fut is not None: _rank = dist.get_rank() if dist.is_initialized() else 0 - print(f"[rank {_rank}] train: waiting for pending save before step {i}", flush=True) + print( + f"[rank {_rank}] train: waiting for pending save before step {i}", + flush=True, + ) pending_fut.result() - print(f"[rank {_rank}] train: pending save done before step {i}", flush=True) + print( + f"[rank {_rank}] train: pending save done before step {i}", + flush=True, + ) p = os.path.join(save_dir, f"step_{i}.ckpt") pending_fut = state.save(p) From b37a80cb393313fb341d022c93bce7776d12608f Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 07:47:25 +0000 Subject: [PATCH 11/17] Fix E501 line too long in trainer.py debug print Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/bergson/trainer.py b/bergson/trainer.py index 5a5a037b..421844db 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -13,7 +13,7 @@ import torchopt from datasets import Dataset from torch import nn -from torchopt.pytree import tree_iter +from torchopt.pytree import tree_iter, tree_map from torchopt.typing import GradientTransformation, OptState from tqdm.auto import tqdm @@ -213,12 +213,15 @@ def _done_callback(fut, g=grp, p=path): return fut def detach_(self): - for p in self.params.values(): - p.detach_() + for k, p in self.params.items(): + self.params[k] = p.detach() - for t in tree_iter(self.opt_state): + def _detach_leaf(t): if isinstance(t, torch.Tensor) and t.is_floating_point(): - t.detach_() + return t.detach() + return t + + self.opt_state = tree_map(_detach_leaf, self.opt_state) @property def requires_grad(self) -> bool: @@ -378,7 +381,8 @@ def train( if pending_fut is not None: _rank = dist.get_rank() if dist.is_initialized() else 0 print( - f"[rank {_rank}] train: waiting for pending save before step {i}", + f"[rank {_rank}] train: waiting for pending save" + f" before step {i}", flush=True, ) pending_fut.result() @@ -423,6 +427,10 @@ def backward( fwd_state.batch_index = idx fwd_state.load(path) + # Detach after loading so that replay steps can use in-place ops + # (loaded tensors may retain requires_grad from the previous traced step) + fwd_state.detach_() + # Only delete this checkpoint if it's the one we expected to load. If it's # not, we need to keep it around, and step forward through training if idx == expected_idx: From c1e7f612b6bfe3bdd291123c3cfd44e318b7cdae Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 07:55:55 +0000 Subject: [PATCH 12/17] Save run and dist configs to run path in magic CLI Co-Authored-By: Claude Opus 4.6 --- bergson/double_backward.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/bergson/double_backward.py b/bergson/double_backward.py index 523e1bb1..392947ee 100644 --- a/bergson/double_backward.py +++ b/bergson/double_backward.py @@ -1,6 +1,8 @@ +import json import os -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import timedelta +from pathlib import Path from typing import Literal import torch @@ -281,6 +283,13 @@ def schedule(step: Numeric) -> Numeric: def double_backward(run_cfg: DoubleBackwardConfig, dist_cfg: DistributedConfig): + run_path = Path(run_cfg.run_path) + run_path.mkdir(parents=True, exist_ok=True) + with (run_path / "run_config.json").open("w") as f: + json.dump(asdict(run_cfg), f, indent=2) + with (run_path / "dist_config.json").open("w") as f: + json.dump(asdict(dist_cfg), f, indent=2) + train_ds = load_data_string( run_cfg.data.dataset, run_cfg.data.split, From 49e501f2007bb3fb55eb0b6311a3ff77d94c5be3 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 10:35:21 +0000 Subject: [PATCH 13/17] Remove per-token weight support (deferred to magic-per-token branch) Strip per_token parameter from DataStream and 2D weight path from weighted_causal_lm_ce to keep the merge scope minimal. The per-token code is preserved on the magic-per-token branch. Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 8 +------- bergson/utils/math.py | 15 +++++---------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/bergson/trainer.py b/bergson/trainer.py index 421844db..aca12449 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -60,13 +60,11 @@ def __init__( device: torch.device | str = "cpu", input_key: str = "text", max_length: int = 256, - per_token: bool = False, ): self.dataset = dataset self.processor = processor self.input_key = input_key self.max_length = max_length - self.per_token = per_token self.batch_size = batch_size self.device = device @@ -89,8 +87,7 @@ def __init__( ) n = self.batch_size * self.num_batches - shape = (n, max_length) if per_token else (n,) - self.weights = nn.Parameter(torch.ones(*shape, device=device)) + self.weights = nn.Parameter(torch.ones(n, device=device)) @property def requires_grad(self) -> bool: @@ -129,9 +126,6 @@ def __getitem__(self, i: int) -> dict: i * self.batch_size + self.rank : (i + 1) * self.batch_size : self.world_size ] - if self.per_token and w.ndim == 2: - T_batch = x["input_ids"].shape[1] - w = w[:, :T_batch] x["example_weight"] = w return {k: v.to(self.device) for k, v in x.items()} diff --git a/bergson/utils/math.py b/bergson/utils/math.py index 622006a0..24b63a35 100644 --- a/bergson/utils/math.py +++ b/bergson/utils/math.py @@ -19,7 +19,7 @@ def weighted_causal_lm_ce( Args: logits : [B, T, V] float tensor of prediction scores labels : [B, T] long tensor of target token ids, or ignore_index - example_weight : [B] or [B, T] float tensor of weights + example_weight : [B] float tensor of weights ignore_index : int, label value to ignore in loss computation vocab_size : optional int, vocabulary size (for validation) """ @@ -27,10 +27,9 @@ def weighted_causal_lm_ce( B, T, V = logits.shape assert labels.shape == (B, T) if example_weight is not None: - assert example_weight.ndim in (1, 2), ( - f"example_weight must be 1D [B] or 2D [B, T], " - f"got shape {example_weight.shape}" - ) + assert ( + example_weight.ndim == 1 + ), f"example_weight must be 1D [B], got shape {example_weight.shape}" # HuggingFace always passes a vocab_size kwarg if vocab_size is not None: @@ -54,11 +53,7 @@ def weighted_causal_lm_ce( if example_weight is None: return tok_loss.mean() - if example_weight.ndim == 1: - w = example_weight.to(tok_loss.dtype).view(B, 1) # [B,1] - else: - # Per-token weights: shift to align with shifted labels - w = example_weight[:, 1:].to(tok_loss.dtype) # [B, T-1] + w = example_weight.to(tok_loss.dtype).view(B, 1) # [B,1] return (tok_loss * w).mean() From ef891df0bc001fa2fd9305d9ce9704c536b0fa67 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 10:45:20 +0000 Subject: [PATCH 14/17] Use exact shape assertion for example_weight in weighted_causal_lm_ce Co-Authored-By: Claude Opus 4.6 --- bergson/utils/math.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bergson/utils/math.py b/bergson/utils/math.py index 24b63a35..0abeffed 100644 --- a/bergson/utils/math.py +++ b/bergson/utils/math.py @@ -19,7 +19,7 @@ def weighted_causal_lm_ce( Args: logits : [B, T, V] float tensor of prediction scores labels : [B, T] long tensor of target token ids, or ignore_index - example_weight : [B] float tensor of weights + example_weight : [B] float tensor of per-example weights ignore_index : int, label value to ignore in loss computation vocab_size : optional int, vocabulary size (for validation) """ @@ -27,9 +27,7 @@ def weighted_causal_lm_ce( B, T, V = logits.shape assert labels.shape == (B, T) if example_weight is not None: - assert ( - example_weight.ndim == 1 - ), f"example_weight must be 1D [B], got shape {example_weight.shape}" + assert example_weight.shape == (B,) # HuggingFace always passes a vocab_size kwarg if vocab_size is not None: From 0c1f9aedc039edc709fab8c59da62292ecc03048 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 11:04:09 +0000 Subject: [PATCH 15/17] Clean up trainer.py: use assert for validation, remove debug prints Co-Authored-By: Claude Opus 4.6 --- bergson/trainer.py | 41 +++++++---------------------------------- 1 file changed, 7 insertions(+), 34 deletions(-) diff --git a/bergson/trainer.py b/bergson/trainer.py index aca12449..c4e9f0e3 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -79,12 +79,11 @@ def __init__( ) needed = self.batch_size * self.num_batches - if len(self.dataset) < needed: - raise ValueError( - f"Dataset has {len(self.dataset)} examples but {self.num_batches} " - f"batches of size {self.batch_size} require {needed}. " - f"Pass a larger split or reduce --num_batches." - ) + assert len(self.dataset) >= needed, ( + f"Dataset has {len(self.dataset)} examples but {self.num_batches} " + f"batches of size {self.batch_size} require {needed}. " + f"Pass a larger split or reduce --num_batches." + ) n = self.batch_size * self.num_batches self.weights = nn.Parameter(torch.ones(n, device=device)) @@ -122,11 +121,10 @@ def __getitem__(self, i: int) -> dict: truncation=True, ) x["labels"] = x["input_ids"] - w = self.weights[ + x["example_weight"] = self.weights[ i * self.batch_size + self.rank : (i + 1) * self.batch_size : self.world_size ] - x["example_weight"] = w return {k: v.to(self.device) for k, v in x.items()} def __iter__(self): @@ -173,9 +171,6 @@ def load(self, path: str): ) def save(self, path: str) -> Future: - rank = dist.get_rank() if dist.is_initialized() else 0 - print(f"[rank {rank}] save: starting async_save for {path}", flush=True) - # Create a new process group so that we can overlap saves if dist.is_initialized(): grp = dist.new_group(backend="gloo", group_desc=path) @@ -183,18 +178,9 @@ def save(self, path: str) -> Future: else: grp = None - def _done_callback(fut, g=grp, p=path): - print(f"[rank {rank}] save: callback fired for {p}", flush=True) + def _done_callback(fut, g=grp): if g is not None: - print( - f"[rank {rank}] save: calling destroy_process_group for {p}", - flush=True, - ) dist.destroy_process_group(g) - print( - f"[rank {rank}] save: destroy_process_group done for {p}", - flush=True, - ) fut = dcp.async_save( self.state_dict(), @@ -373,17 +359,7 @@ def train( # multiple concurrent DCP saves with separate Gloo groups, which can # deadlock when background threads call distributed operations. if pending_fut is not None: - _rank = dist.get_rank() if dist.is_initialized() else 0 - print( - f"[rank {_rank}] train: waiting for pending save" - f" before step {i}", - flush=True, - ) pending_fut.result() - print( - f"[rank {_rank}] train: pending save done before step {i}", - flush=True, - ) p = os.path.join(save_dir, f"step_{i}.ckpt") pending_fut = state.save(p) @@ -391,10 +367,7 @@ def train( state = self.step(state, x, inplace=inplace, trace=trace) if pending_fut is not None: - _rank = dist.get_rank() if dist.is_initialized() else 0 - print(f"[rank {_rank}] train: waiting for final save", flush=True) pending_fut.result() - print(f"[rank {_rank}] train: final save done", flush=True) return state From bc589cb595b60efbdc2a7b872107a9f957c235fd Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 7 Mar 2026 11:08:07 +0000 Subject: [PATCH 16/17] Update magic docs: fix CLI usage flags, remove per-token references Co-Authored-By: Claude Opus 4.6 --- docs/magic.rst | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/docs/magic.rst b/docs/magic.rst index 6f48c1cb..53baaaa4 100644 --- a/docs/magic.rst +++ b/docs/magic.rst @@ -19,7 +19,13 @@ The ``Trainer`` class handles all three phases. It uses `torchopt Date: Sat, 7 Mar 2026 12:18:48 +0000 Subject: [PATCH 17/17] update docs --- docs/magic.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/magic.rst b/docs/magic.rst index 53baaaa4..a9baf1ec 100644 --- a/docs/magic.rst +++ b/docs/magic.rst @@ -21,10 +21,10 @@ Usage .. code-block:: bash - bergson magic runs/magic-ckpts \ + CUDA_VISIBLE_DEVICES="0" bergson magic runs/magic-ckpts \ --data.dataset NeelNanda/pile-10k \ --query.dataset NeelNanda/pile-10k \ - --query.split "train[:1]" \ + --query.split "train[:8]" \ --model EleutherAI/pythia-14m Core components