diff --git a/bergson/__main__.py b/bergson/__main__.py index 79ec970a..b02ac2a9 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -8,6 +8,7 @@ from .build import build from .config import ( + DistributedConfig, HessianConfig, IndexConfig, QueryConfig, @@ -15,6 +16,7 @@ ScoreConfig, TrackstarConfig, ) +from .double_backward import DoubleBackwardConfig, double_backward from .hessians.hessian_approximations import approximate_hessians from .query.query_index import query from .reduce import reduce @@ -194,11 +196,25 @@ def execute(self): score_dataset(score_index_cfg, self.score_cfg) +@dataclass +class Magic: + """Run MAGIC attribution.""" + + run_cfg: DoubleBackwardConfig + dist_cfg: DistributedConfig + + def execute(self): + """Run MAGIC attribution.""" + double_backward(self.run_cfg, self.dist_cfg) + + @dataclass class Main: """Routes to the subcommands.""" - command: Union[Build, Query, Preconditioners, Reduce, Score, Hessian, Trackstar] + command: Union[ + Build, Query, Preconditioners, Reduce, Score, Hessian, Trackstar, Magic + ] def execute(self): """Run the script.""" diff --git a/bergson/double_backward.py b/bergson/double_backward.py new file mode 100644 index 00000000..392947ee --- /dev/null +++ b/bergson/double_backward.py @@ -0,0 +1,325 @@ +import json +import os +from dataclasses import asdict, dataclass +from datetime import timedelta +from pathlib import Path +from typing import Literal + +import torch +import torch.distributed as dist +import torchopt +from scipy.stats import spearmanr +from simple_parsing import ArgumentParser, field +from torch.distributed.tensor import init_device_mesh +from torchopt.pytree import tree_iter +from torchopt.typing import Numeric +from transformers import AutoModelForCausalLM, AutoTokenizer + +from bergson.config import DataConfig, DistributedConfig +from bergson.data import load_data_string +from bergson.distributed import grad_tree, launch_distributed_run, simple_fsdp +from bergson.magic_patch import 
apply_dtensor_patch +from bergson.trainer import BackwardState, DataStream, Trainer, TrainerState +from bergson.utils.math import weighted_causal_lm_ce + + +@dataclass +class DoubleBackwardConfig: + run_path: str = field(positional=True) + """Directory to save checkpoints and results.""" + + model: str = "EleutherAI/pythia-160m" + """HuggingFace model name.""" + + revision: str | None = None + """Model revision (branch, tag, or commit hash).""" + + data: DataConfig = field(default_factory=DataConfig) + """Training dataset.""" + + query: DataConfig = field(default_factory=lambda: DataConfig()) + """Query/eval dataset for computing attribution target gradients. + If not specified, defaults to the training dataset.""" + + query_method: Literal["mean", "sum"] = "mean" + """Method for reducing query gradients across batches.""" + + query_batches: int = 1 + """Number of query batches to use for computing eval gradients.""" + + fsdp: bool = False + """Whether to use FSDP for multi-GPU training.""" + + grad_checkpointing: bool = False + """Whether to use gradient checkpointing during the forward pass.""" + + lr: float = 1e-5 + """Base learning rate after warmup.""" + + warmup_steps: int = 10 + """Number of warmup steps before applying base lr.""" + + batch_size: int = 8 + """Per-device batch size.""" + + num_batches: int = 25 + """Number of training batches.""" + + max_length: int = 256 + """Maximum token sequence length.""" + + num_subsets: int = 100 + """Number of leave-one-out subsets for Spearman correlation.""" + + seed: int = 42 + """Random seed for subset permutation.""" + + +def compute_query_gradients( + trainer: Trainer, + fwd_state: TrainerState, + model: torch.nn.Module, + query_stream: DataStream, + method: str = "mean", +) -> dict[str, torch.Tensor]: + """Compute reduced query gradients over the query dataset. + + Iterates over the query stream, computing per-batch parameter gradients + and reducing them (mean or sum) into a single gradient dict. 
+ """ + grad_accum: dict[str, torch.Tensor] | None = None + n_batches = 0 + + with fwd_state.activate(model) as params: + for batch in query_stream: + del batch["example_weight"] + loss = model(**batch).loss + grads = grad_tree(loss, params) + + if grad_accum is None: + grad_accum = {k: g.detach().clone() for k, g in grads.items()} + else: + for k, g in grads.items(): + grad_accum[k] += g.detach() + n_batches += 1 + + assert grad_accum is not None, "Query stream was empty" + + if method == "mean": + for k in grad_accum: + grad_accum[k] /= n_batches + + return grad_accum + + +def worker( + global_rank: int, + rank: int, + world_size: int, + train_dataset, + query_dataset, + run_cfg: DoubleBackwardConfig, +): + torch.cuda.set_device(rank) + + model = AutoModelForCausalLM.from_pretrained( + run_cfg.model, + revision=run_cfg.revision, + torch_dtype=torch.float32, + attn_implementation="eager", + ) + model.loss_function = weighted_causal_lm_ce + model.to(f"cuda:{rank}") + if run_cfg.grad_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=dict(use_reentrant=False), + ) + + processor = AutoTokenizer.from_pretrained(run_cfg.model) + processor.pad_token = processor.eos_token + + if world_size > 1: + addr = os.environ.get("MASTER_ADDR", "localhost") + port = os.environ.get("MASTER_PORT", "29500") + + dist.init_process_group( + "cpu:gloo,cuda:nccl", + init_method=f"tcp://{addr}:{port}", + device_id=torch.device(f"cuda:{rank}"), + rank=rank, + timeout=timedelta(hours=1), + world_size=world_size, + ) + + if run_cfg.fsdp and world_size > 1: + apply_dtensor_patch() + mesh = init_device_mesh("cuda", (world_size,)) + with mesh: + model = simple_fsdp(model) + + def schedule(step: Numeric) -> Numeric: + if step < run_cfg.warmup_steps: + return 0.0 + return run_cfg.lr + + opt = torchopt.adamw( + schedule, + betas=(0.95, 0.975), + eps_root=1e-8, + ) + trainer, fwd_state = Trainer.initialize(model, opt) + + ckpts_path = os.path.join(run_cfg.run_path, 
"checkpoints") + path0 = os.path.join(ckpts_path, "state0.pt") + save_fut = fwd_state.save(path0) + + stream = DataStream( + train_dataset, + processor, + batch_size=run_cfg.batch_size, + num_batches=run_cfg.num_batches, + device=f"cuda:{rank}", + max_length=run_cfg.max_length, + input_key=run_cfg.data.prompt_column, + ) + fwd_state = trainer.train( + fwd_state, + stream, + inplace=True, + save_dir=ckpts_path, + ) + + # Compute query gradients + query_stream = DataStream( + query_dataset, + processor, + batch_size=run_cfg.batch_size, + num_batches=run_cfg.query_batches, + device=f"cuda:{rank}", + max_length=run_cfg.max_length, + input_key=run_cfg.query.prompt_column, + ) + + query_grads = compute_query_gradients( + trainer, fwd_state, model, query_stream, run_cfg.query_method + ) + + if world_size > 1: + reduce_op = ( + dist.ReduceOp.AVG if run_cfg.query_method == "mean" else dist.ReduceOp.SUM + ) + for v in query_grads.values(): + dist.all_reduce(v, op=reduce_op) + + stream.requires_grad = True + opt_grads = [ + torch.zeros_like(buf) + for buf in tree_iter(fwd_state.opt_state) + if isinstance(buf, torch.Tensor) and buf.is_floating_point() + ] + bwd_state = BackwardState(query_grads, opt_grads, torch.zeros_like(stream.weights)) + + # Compute baseline eval loss for validation + with fwd_state.activate(model): + baseline_batch = query_stream[0] + del baseline_batch["example_weight"] + baseline_loss = model(**baseline_batch).loss + + if world_size > 1: + dist.all_reduce(baseline_loss, op=dist.ReduceOp.AVG) + + bwd_state = trainer.backward( + ckpts_path, + stream, + bwd_state, + fwd_state, + inplace=True, + ) + if world_size > 1: + dist.all_reduce(bwd_state.weight_grads, op=dist.ReduceOp.AVG) + + baseline = baseline_loss.item() + if global_rank == 0: + print(f"Scores: {bwd_state.weight_grads.tolist()}") + print(f"Baseline: {baseline}") + print(f"Grad sum: {bwd_state.weight_grads.sum()}") + + stream.requires_grad = False + + # Validate attribution scores via 
leave-subset-out retraining + diffs = [] + score_sums = [] + + gen = torch.Generator().manual_seed(run_cfg.seed) + perm = torch.randperm(len(stream.weights), generator=gen) + subsets = perm.chunk(run_cfg.num_subsets) + + save_fut.result() # ensure state0 is saved before loading in loop + fwd_state.load(path0) + + for subset in subsets: + stream.weights.fill_(1.0) + stream.weights[subset] = 0.0 + + for x in stream: + fwd_state = trainer.step(fwd_state, x) + + with fwd_state.activate(model): + eval_batch = query_stream[0] + del eval_batch["example_weight"] + loss = model(**eval_batch).loss + + if world_size > 1: + dist.all_reduce(loss, op=dist.ReduceOp.AVG) + + diffs.append(baseline - loss.item()) + score_sums.append(bwd_state.weight_grads[subset].sum().item()) + + corr = spearmanr(diffs, score_sums) + if global_rank == 0: + print(f"Loss diff: {diffs[-1]}") + print(f"Score: {score_sums[-1]}") + print(f"Spearman correlation: {corr}") + + +def double_backward(run_cfg: DoubleBackwardConfig, dist_cfg: DistributedConfig): + run_path = Path(run_cfg.run_path) + run_path.mkdir(parents=True, exist_ok=True) + with (run_path / "run_config.json").open("w") as f: + json.dump(asdict(run_cfg), f, indent=2) + with (run_path / "dist_config.json").open("w") as f: + json.dump(asdict(dist_cfg), f, indent=2) + + train_ds = load_data_string( + run_cfg.data.dataset, + run_cfg.data.split, + run_cfg.data.subset, + run_cfg.data.data_args, + ) + + query_ds = load_data_string( + run_cfg.query.dataset, + run_cfg.query.split, + run_cfg.query.subset, + run_cfg.query.data_args, + ) + + launch_distributed_run( + "double_backward", worker, [train_ds, query_ds, run_cfg], dist_cfg + ) + + +def main(): + parser = ArgumentParser() + parser.add_arguments(DoubleBackwardConfig, dest="run_cfg") + parser.add_arguments(DistributedConfig, dest="dist_cfg") + args = parser.parse_args() + + run_cfg: DoubleBackwardConfig = args.run_cfg + dist_cfg: DistributedConfig = args.dist_cfg + + double_backward(run_cfg, 
dist_cfg) + + +if __name__ == "__main__": + main() diff --git a/bergson/magic_patch.py b/bergson/magic_patch.py new file mode 100644 index 00000000..b22b9295 --- /dev/null +++ b/bergson/magic_patch.py @@ -0,0 +1,202 @@ +"""Runtime monkey-patch for twice-differentiable DTensor redistribution. + +Implements pytorch/pytorch#160509 at runtime, avoiding the need to modify +torch source files on disk. Call `apply_dtensor_patch()` before any DTensor +operations that require double backward (e.g. MAGIC attribution with FSDP). + +Safe to call multiple times (idempotent). +""" + +import torch + +_PATCHED = False + + +def apply_dtensor_patch(): + """Patch DTensor redistribution to support double backward. + + Monkey-patches `Redistribute.backward` and `_ToTorchTensor.backward` + in the installed torch package so that FSDP redistribution is + twice-differentiable. + """ + global _PATCHED + if _PATCHED: + return + + _patch_redistribute() + _patch_to_torch_tensor() + _PATCHED = True + + +def _patch_redistribute(): + import torch.distributed.tensor._api as dtensor + from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta + from torch.distributed.tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, + ) + from torch.distributed.tensor.placement_types import Replicate + + def _redistribute_backward( + grad_output, + previous_spec, + original_dtype=None, + backward_dtype=None, + async_op=False, + ): + if backward_dtype != grad_output._local_tensor.dtype: + local_tensor = grad_output._local_tensor.to(dtype=backward_dtype) + current_spec = DTensorSpec( + mesh=grad_output._spec.device_mesh, + placements=grad_output._spec.placements, + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=backward_dtype, + ), + ) + previous_spec = DTensorSpec( + mesh=previous_spec.device_mesh, + placements=previous_spec.placements, + tensor_meta=current_spec.tensor_meta, + ) + else: + local_tensor = grad_output._local_tensor + 
current_spec = grad_output._spec + + normalized_placements = [] + for current, target in zip(current_spec.placements, previous_spec.placements): + if (current.is_shard() or current.is_replicate()) and target.is_partial(): + normalized_placements.append(Replicate()) + else: + normalized_placements.append(target) + + previous_spec = DTensorSpec( + previous_spec.device_mesh, + placements=tuple(normalized_placements), + tensor_meta=previous_spec.tensor_meta, + ) + + output = redistribute_local_tensor( + local_tensor, + current_spec, + previous_spec, + async_op=async_op, + is_backward=True, + ) + + if output.dtype != original_dtype: + output = output.to(original_dtype) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=output.dtype, + ), + ) + return output, spec + + class NestedRedistribute(torch.autograd.Function): + """Makes DTensor redistribution twice-differentiable. + + Called during Redistribute.backward (first backward pass). + NestedRedistribute.backward handles the second backward pass. 
+ """ + + @staticmethod + def forward( + ctx, + grad_output, + previous_spec, + async_op=False, + backward_dtype=None, + original_dtype=None, + ): + ctx.async_op = async_op + ctx.backward_dtype = backward_dtype or original_dtype + ctx.original_dtype = grad_output._local_tensor.dtype + + output, spec = _redistribute_backward( + grad_output, + previous_spec, + ctx.original_dtype, + backward_dtype, + async_op, + ) + + ctx.current_spec = spec + + return dtensor.DTensor( + output, + spec, + requires_grad=grad_output.requires_grad, + ) + + @staticmethod + def backward(ctx, grad2_output): + output_dtensor = NestedRedistribute.apply( + grad2_output, + ctx.current_spec, + ctx.async_op, + ctx.backward_dtype, + ctx.original_dtype, + ) + + return (output_dtensor, None, None, None, None) + + @staticmethod + def _new_redistribute_backward(ctx, grad_output): + previous_spec = ctx.current_spec + output_dtensor = NestedRedistribute.apply( + grad_output, + previous_spec, + ctx.async_op, + ctx.backward_dtype, + ctx.original_dtype, + ) + return (output_dtensor, None, None, None, None, None) + + Redistribute.backward = _new_redistribute_backward + + +def _patch_to_torch_tensor(): + import torch.distributed.tensor._api as dtensor_api + from torch.distributed.tensor._api import _ToTorchTensor + from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta + from torch.distributed.tensor._utils import compute_global_tensor_info + + @staticmethod + def _new_backward(ctx, grad_output): + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + 
dtype=dtensor_meta.dtype, + ), + ) + + return ( + dtensor_api.DTensor.from_local( + grad_output, + grad_spec.device_mesh, + grad_spec.placements, + ), + None, + ) + + _ToTorchTensor.backward = _new_backward diff --git a/bergson/trainer.py b/bergson/trainer.py index af38f465..c4e9f0e3 100644 --- a/bergson/trainer.py +++ b/bergson/trainer.py @@ -13,7 +13,7 @@ import torchopt from datasets import Dataset from torch import nn -from torchopt.pytree import tree_iter +from torchopt.pytree import tree_iter, tree_map from torchopt.typing import GradientTransformation, OptState from tqdm.auto import tqdm @@ -78,6 +78,13 @@ def __init__( f"{self.world_size}" ) + needed = self.batch_size * self.num_batches + assert len(self.dataset) >= needed, ( + f"Dataset has {len(self.dataset)} examples but {self.num_batches} " + f"batches of size {self.batch_size} require {needed}. " + f"Pass a larger split or reduce --num_batches." + ) + n = self.batch_size * self.num_batches self.weights = nn.Parameter(torch.ones(n, device=device)) @@ -171,6 +178,10 @@ def save(self, path: str) -> Future: else: grp = None + def _done_callback(fut, g=grp): + if g is not None: + dist.destroy_process_group(g) + fut = dcp.async_save( self.state_dict(), checkpoint_id=path, @@ -178,19 +189,19 @@ def save(self, path: str) -> Future: process_group=grp, ) assert isinstance(fut, Future) - - fut.add_done_callback( - lambda _, g=grp: dist.destroy_process_group(g) if g else None - ) + fut.add_done_callback(_done_callback) return fut def detach_(self): - for p in self.params.values(): - p.detach_() + for k, p in self.params.items(): + self.params[k] = p.detach() - for t in tree_iter(self.opt_state): + def _detach_leaf(t): if isinstance(t, torch.Tensor) and t.is_floating_point(): - t.detach_() + return t.detach() + return t + + self.opt_state = tree_map(_detach_leaf, self.opt_state) @property def requires_grad(self) -> bool: @@ -335,7 +346,7 @@ def train( chunk_size = math.isqrt(len(data)) if save_mode == "sqrt" else 1 
last_start = len(data) - chunk_size - save_futures: list[Future] = [] + pending_fut: Future | None = None main = not dist.is_initialized() or dist.get_rank() != 0 pbar = tqdm(data, desc="Training", disable=main) @@ -344,15 +355,19 @@ def train( # Save checkpoint BEFORE each step. Step 0 is the initial state prior to # any updates, step 1 is the state after the first update, etc. if save_dir and (i % chunk_size == 0 or i >= last_start): - p = os.path.join(save_dir, f"step_{i}.ckpt") + # Wait for the previous save before starting a new one to avoid + # multiple concurrent DCP saves with separate Gloo groups, which can + # deadlock when background threads call distributed operations. + if pending_fut is not None: + pending_fut.result() - fut = state.save(p) - save_futures.append(fut) + p = os.path.join(save_dir, f"step_{i}.ckpt") + pending_fut = state.save(p) state = self.step(state, x, inplace=inplace, trace=trace) - for fut in save_futures: - fut.result() # wait for all checkpoints to finish saving + if pending_fut is not None: + pending_fut.result() return state @@ -379,6 +394,10 @@ def backward( fwd_state.batch_index = idx fwd_state.load(path) + # Detach after loading so that replay steps can use in-place ops + # (loaded tensors may retain requires_grad from the previous traced step) + fwd_state.detach_() + # Only delete this checkpoint if it's the one we expected to load. 
If it's # not, we need to keep it around, and step forward through training if idx == expected_idx: diff --git a/docs/index.rst b/docs/index.rst index c3b833be..29e2a41a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -68,6 +68,7 @@ Experiments :maxdepth: 2 experiments + magic Content Index diff --git a/docs/magic.rst b/docs/magic.rst new file mode 100644 index 00000000..a9baf1ec --- /dev/null +++ b/docs/magic.rst @@ -0,0 +1,71 @@ +MAGIC Attribution +================= + +`MAGIC `_ (Model-Agnostic Generation-time Influence via Checkpointing) attributes evaluation loss to individual training examples by backpropagating through the entire training process. Unlike influence functions which use a local approximation, MAGIC computes exact counterfactual attribution by differentiating through checkpointed training steps. + +We provide a `Trainer` class that takes differentiable training steps and handles all three phases of MAGIC attribution. We support FSDP training using the `bergson.magic_patch` runtime patch, which makes PyTorch's DTensor redistribution twice-differentiable (`pytorch/pytorch#160509 `_). The patch is applied in memory, so no torch source files are modified. + +How it works +------------ + +MAGIC attribution has three phases: + +1. **Forward training with checkpoints**: Fine-tune the model, saving intermediate checkpoints at each step. +2. **Evaluate**: Compute the evaluation loss and its gradients with respect to the final model parameters. +3. **Backward through training**: Backpropagate the evaluation gradients through the checkpointed training steps using reverse-mode autodiff, accumulating attribution scores for each training example (or token). + +The ``Trainer`` class handles all three phases. It uses `torchopt `_ for functional (stateless) differentiable optimization. + +Usage +----- + +.. 
code-block:: bash + + CUDA_VISIBLE_DEVICES="0" bergson magic runs/magic-ckpts \ + --data.dataset NeelNanda/pile-10k \ + --query.dataset NeelNanda/pile-10k \ + --query.split "train[:8]" \ + --model EleutherAI/pythia-14m + +Core components +^^^^^^^^^^^^^^^ + +**Trainer**: Functional trainer that supports forward training with checkpoints and backward-through-training. + +.. code-block:: python + + from bergson.trainer import Trainer, DataStream, BackwardState, TrainerState + import torchopt + + # Initialize + opt = torchopt.adam(lr=1e-4) + trainer, state = Trainer.initialize(model, opt) + + # Forward training with checkpoints + stream = DataStream(dataset, tokenizer, batch_size=4, num_batches=250, device="cuda") + state = trainer.train(state, stream, save_dir="checkpoints/") + + # Compute eval gradients, then backward through training + bwd_state = trainer.backward("checkpoints/", stream, bwd_state) + scores = bwd_state.weight_grads # attribution scores + +**DataStream**: Wraps a dataset with differentiable per-example weights that receive gradients during the backward pass. + +.. code-block:: python + + stream = DataStream(dataset, tokenizer, batch_size=4, num_batches=250, device="cuda") + +**DTensor patch**: For multi-GPU runs with FSDP, apply the DTensor patch before any distributed operations: + +.. code-block:: python + + from bergson.magic_patch import apply_dtensor_patch + apply_dtensor_patch() + +Key implementation details +-------------------------- + +- **Functional optimization**: ``torchopt.adam`` (or similar) provides a pure-function optimizer whose state is a pytree of tensors. This allows ``torch.autograd.grad`` to differentiate through optimizer updates. +- **Checkpoint strategy**: By default, checkpoints are saved at ``sqrt(N)`` intervals, giving ``O(sqrt(N))`` memory and ``O(N * sqrt(N))`` recomputation cost. 
+- **FSDP compatibility**: The DTensor runtime patch adds a ``NestedRedistribute`` autograd function that makes the FSDP all-gather/reduce-scatter differentiable through second-order backward passes. +- **Loss weighting**: ``weighted_causal_lm_ce`` multiplies per-token cross-entropy by the DataStream weights before averaging. During backward-through-training, autograd accumulates gradients into these weights, yielding the attribution scores. diff --git a/examples/double_backward.py b/examples/double_backward.py deleted file mode 100644 index 3f079ba8..00000000 --- a/examples/double_backward.py +++ /dev/null @@ -1,209 +0,0 @@ -import os -from dataclasses import dataclass -from datetime import timedelta - -import torch -import torch.distributed as dist -import torchopt -from datasets import load_dataset -from scipy.stats import spearmanr -from simple_parsing import ArgumentParser -from torch.distributed.tensor import init_device_mesh -from torchopt.pytree import tree_iter -from torchopt.typing import Numeric -from transformers import AutoTokenizer, GPTNeoXConfig, GPTNeoXForCausalLM - -from bergson.config import DistributedConfig -from bergson.distributed import grad_tree, launch_distributed_run, simple_fsdp -from bergson.trainer import BackwardState, DataStream, Trainer -from bergson.utils.math import weighted_causal_lm_ce - - -@dataclass -class RunConfig: - model_name: str = "EleutherAI/pythia-160m" - """HuggingFace model name.""" - - dataset_name: str = "EleutherAI/SmolLM2-135M-10B" - """HuggingFace dataset name.""" - - dataset_split: str = "train" - """Dataset split to use.""" - - grad_checkpointing: bool = False - """Whether to use gradient checkpointing during the forward pass.""" - - lr: float = 1e-5 - """Base learning rate after warmup.""" - - warmup_steps: int = 10 - """Number of warmup steps before applying base lr.""" - - batch_size: int = 8 - """Per-device batch size.""" - - num_batches: int = 25 - """Number of training batches.""" - - max_length: int = 
256 - """Maximum token sequence length.""" - - save_dir: str = "/mnt/ssd-3/nora/magic-ckpts" - """Directory to save forward pass checkpoints.""" - - num_subsets: int = 100 - """Number of leave-one-out subsets for Spearman correlation.""" - - seed: int = 42 - """Random seed for subset permutation.""" - - -def worker(global_rank: int, rank: int, world_size: int, dataset, run_cfg: RunConfig): - torch.cuda.set_device(rank) - - cfg = GPTNeoXConfig.from_pretrained(run_cfg.model_name, revision="step0") - model = GPTNeoXForCausalLM(cfg) - model.set_attn_implementation("eager") - model.loss_function = weighted_causal_lm_ce - model.to(f"cuda:{rank}") - if run_cfg.grad_checkpointing: - model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs=dict(use_reentrant=False), - ) - - processor = AutoTokenizer.from_pretrained(run_cfg.model_name) - processor.pad_token = processor.eos_token - - if world_size > 1: - addr = os.environ.get("MASTER_ADDR", "localhost") - port = os.environ.get("MASTER_PORT", "29500") - - dist.init_process_group( - "cpu:gloo,cuda:nccl", - init_method=f"tcp://{addr}:{port}", - device_id=torch.device(f"cuda:{rank}"), - rank=rank, - timeout=timedelta(hours=1), - world_size=world_size, - ) - mesh = init_device_mesh("cuda", (world_size,)) - with mesh: - model = simple_fsdp(model) - - def schedule(step: Numeric) -> Numeric: - if step < run_cfg.warmup_steps: - return 0.0 - return run_cfg.lr - - opt = torchopt.adamw( - schedule, - betas=(0.95, 0.975), - eps_root=1e-8, - ) - trainer, fwd_state = Trainer.initialize(model, opt) - - # save state0 - path0 = os.path.join(run_cfg.save_dir, "state0.pt") - save_fut = fwd_state.save(path0) - - stream = DataStream( - dataset, - processor, - batch_size=run_cfg.batch_size, - num_batches=run_cfg.num_batches, - device=f"cuda:{rank}", - max_length=run_cfg.max_length, - ) - fwd_state = trainer.train( - fwd_state, - stream, - inplace=True, - save_dir=run_cfg.save_dir, - ) - - with fwd_state.activate(model) as params: - 
stream.requires_grad = True - - ex = stream[0] - del ex["example_weight"] - - loss = model(**ex).loss - - grads = grad_tree(loss, params, create_graph=True) - opt_grads = [ - torch.zeros_like(buf) - for buf in tree_iter(fwd_state.opt_state) - if isinstance(buf, torch.Tensor) and buf.is_floating_point() - ] - bwd_state = BackwardState(grads, opt_grads, torch.zeros_like(stream.weights)) - - if world_size > 1: - dist.all_reduce(loss, op=dist.ReduceOp.AVG) - - bwd_state = trainer.backward( - run_cfg.save_dir, - stream, - bwd_state, - fwd_state, - inplace=True, - ) - if world_size > 1: - dist.all_reduce(bwd_state.weight_grads, op=dist.ReduceOp.AVG) - if global_rank == 0: - print(f"Scores 2: {bwd_state.weight_grads.tolist()}") - - baseline = loss.item() - if global_rank == 0: - print(f"Baseline: {baseline}") - print("Grad:", bwd_state.weight_grads.sum()) - - stream.requires_grad = False - - diffs = [] - score_sums = [] - - gen = torch.Generator().manual_seed(run_cfg.seed) - perm = torch.randperm(len(stream.weights), generator=gen) - subsets = perm.chunk(run_cfg.num_subsets) - - save_fut.result() # ensure state0 is saved before loading in loop - fwd_state.load(path0) - - for subset in subsets: - stream.weights.fill_(1.0) - stream.weights[subset] = 0.0 - - for x in stream: - fwd_state = trainer.step(fwd_state, x) - - with fwd_state.activate(model): - loss = model(**stream[0]).loss - - if world_size > 1: - dist.all_reduce(loss, op=dist.ReduceOp.AVG) - - diffs.append(baseline - loss.item()) - score_sums.append(bwd_state.weight_grads[subset].sum().item()) - - corr = spearmanr(diffs, score_sums) - if global_rank == 0: - print(f"Loss diff: {diffs[-1]}") - print(f"Score: {score_sums[-1]}") - print(f"Spearman correlation: {corr}") - - -def main(): - parser = ArgumentParser() - parser.add_arguments(RunConfig, dest="run_cfg") - parser.add_arguments(DistributedConfig, dest="dist_cfg") - args = parser.parse_args() - - run_cfg: RunConfig = args.run_cfg - dist_cfg: DistributedConfig = 
args.dist_cfg
-
-    ds = load_dataset(run_cfg.dataset_name, split=run_cfg.dataset_split)
-    launch_distributed_run("double_backward", worker, [ds, run_cfg], dist_cfg)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/double_backward_pretrain.py b/examples/double_backward_pretrain.py
new file mode 100644
index 00000000..6b4f1999
--- /dev/null
+++ b/examples/double_backward_pretrain.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+"""MAGIC attribution on a pretrained model.
+
+Trains from a random init checkpoint (Pythia step0), attributes eval loss
+to training examples, and validates via leave-subset-out retraining.
+
+Usage:
+    python examples/double_backward_pretrain.py runs/magic_pretrain
+
+    # Or via the CLI:
+    bergson magic runs/magic_pretrain \
+        --model EleutherAI/pythia-160m \
+        --revision step0 \
+        --data.dataset EleutherAI/SmolLM2-135M-10B \
+        --query.dataset EleutherAI/SmolLM2-135M-10B \
+        --query.split "train[:1]"
+"""
+
+from bergson.config import DataConfig, DistributedConfig
+from bergson.double_backward import DoubleBackwardConfig, double_backward
+
+
+def main():
+    run_cfg = DoubleBackwardConfig(
+        run_path="runs/magic_pretrain",
+        model="EleutherAI/pythia-160m",
+        revision="step0",
+        data=DataConfig(
+            dataset="EleutherAI/SmolLM2-135M-10B",
+            split="train",
+        ),
+        query=DataConfig(
+            dataset="EleutherAI/SmolLM2-135M-10B",
+            split="train[:1]",
+        ),
+        query_batches=1,
+        lr=1e-5,
+        warmup_steps=10,
+        batch_size=8,
+        num_batches=25,
+        max_length=256,
+        num_subsets=100,
+        seed=42,
+    )
+    dist_cfg = DistributedConfig()
+    double_backward(run_cfg, dist_cfg)
+
+
+if __name__ == "__main__":
+    main()