diff --git a/bergson/__init__.py b/bergson/__init__.py index d05df6fa..aec17ea4 100644 --- a/bergson/__init__.py +++ b/bergson/__init__.py @@ -6,6 +6,7 @@ from .faiss_index import FaissConfig from .gradcheck import FiniteDiff from .gradients import GradientCollector, GradientProcessor +from .score_writer import MemmapScoreWriter __all__ = [ "collect_gradients", @@ -18,4 +19,5 @@ "IndexConfig", "DataConfig", "AttentionConfig", + "MemmapScoreWriter", ] diff --git a/bergson/__main__.py b/bergson/__main__.py index 27cf7a88..ac10ffa8 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -35,12 +35,14 @@ class Query: def execute(self): """Query the gradient dataset.""" + assert self.query_cfg.scores_path + assert self.query_cfg.query_path if os.path.exists(self.index_cfg.run_path) and self.index_cfg.save_index: raise ValueError( "Index path already exists and save_index is True - " "running this query will overwrite the existing gradients. " - "If you meant to query the existing gradients, use " + "If you meant to query the existing gradients use " "Attributor instead." 
) diff --git a/bergson/attributor.py b/bergson/attributor.py index 1035f06b..d37ccacc 100644 --- a/bergson/attributor.py +++ b/bergson/attributor.py @@ -39,7 +39,7 @@ def scores(self) -> Tensor: class Attributor: def __init__( self, - index_path: str, + index_path: Path, device: str = "cpu", dtype: torch.dtype = torch.float32, unit_norm: bool = False, @@ -59,7 +59,7 @@ def __init__( f"faiss_{faiss_cfg.index_factory.replace(',', '_')}" f"{'_cosine' if unit_norm else ''}" ) - faiss_path = Path(index_path) / faiss_index_name + faiss_path = index_path / faiss_index_name if not (faiss_path / "config.json").exists(): FaissIndex.create_index( diff --git a/bergson/build.py b/bergson/build.py index 8a032ab9..42120554 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -1,7 +1,10 @@ import json import os +import shutil import socket +from dataclasses import is_dataclass from datetime import timedelta +from pathlib import Path from typing import cast import pandas as pd @@ -133,7 +136,7 @@ def worker( print(f"Loading processor from '{cfg.processor_path}'") processor = GradientProcessor.load( - cfg.processor_path, + Path(cfg.processor_path), map_location=f"cuda:{rank}", ) else: @@ -170,6 +173,8 @@ def worker( save_index=cfg.save_index, save_processor=cfg.save_processor, drop_columns=cfg.drop_columns, + token_batch_size=cfg.token_batch_size, + module_wise=cfg.module_wise, ) else: # Convert each shard to a Dataset then map over its gradients @@ -185,7 +190,7 @@ def flush(): model, ds_shard, processor, - os.path.join(cfg.partial_run_path, f"shard-{shard_id:05d}"), + cfg.partial_run_path / f"shard-{shard_id:05d}", batches=batches, kl_divergence=cfg.loss_fn == "kl", loss_reduction=cfg.loss_reduction, @@ -196,6 +201,8 @@ def flush(): # Save a processor state checkpoint after each shard save_processor=cfg.save_processor, drop_columns=cfg.drop_columns, + token_batch_size=cfg.token_batch_size, + module_wise=cfg.module_wise, ) buf.clear() shard_id += 1 @@ -254,8 +261,14 @@ def 
build_gradient_dataset(cfg: IndexConfig): ) # Write index config to json + os.makedirs(cfg.partial_run_path, exist_ok=True) with open(os.path.join(cfg.partial_run_path, "index_config.json"), "w") as f: - json.dump(cfg, f) + index_cfg_dict = cfg.__dict__ + for key in index_cfg_dict: + if is_dataclass(index_cfg_dict[key]): + index_cfg_dict[key] = index_cfg_dict[key].__dict__ + + json.dump(index_cfg_dict, f) world_size = torch.cuda.device_count() if world_size <= 1: @@ -288,6 +301,6 @@ def build_gradient_dataset(cfg: IndexConfig): ctx.wait() try: - os.rename(cfg.partial_run_path, cfg.run_path) + shutil.move(cfg.partial_run_path, cfg.run_path) except Exception: pass diff --git a/bergson/collection.py b/bergson/collection.py index 7ef81da7..be2b2313 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -1,5 +1,6 @@ import math -from typing import Callable, Literal +from pathlib import Path +from typing import Literal import numpy as np import torch @@ -12,13 +13,14 @@ from .data import create_index, pad_and_tensor from .gradients import AttentionConfig, GradientCollector, GradientProcessor from .peft import set_peft_enabled +from .score_writer import ScoreWriter def collect_gradients( model: PreTrainedModel, data: Dataset, processor: GradientProcessor, - path: str, + path: Path, *, batches: list[list[int]] | None = None, kl_divergence: bool | None = None, @@ -29,7 +31,9 @@ def collect_gradients( save_index: bool = True, save_processor: bool = True, drop_columns: bool = False, - query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor] | None = None, + score_writer: ScoreWriter | None = None, + token_batch_size: int | None = None, + module_wise: bool = False, ): """ Compute projected gradients using a subset of the dataset. 
@@ -53,7 +57,7 @@ def collect_gradients( lo = torch.finfo(dtype).min hi = torch.finfo(dtype).max - def callback(name: str, g: torch.Tensor): + def callback(name: str, g: torch.Tensor, indices: list[int]): g = g.flatten(1).clamp_(lo, hi) if save_index: # Asynchronously move the gradient to CPU and convert to the final dtype @@ -61,6 +65,9 @@ def callback(name: str, g: torch.Tensor): else: mod_grads[name] = g.to(dtype=dtype) + if score_writer and module_wise: + score_writer(indices, mod_grads, name=name) + # Compute the outer product of the flattened gradient if not skip_preconditioners: g = g.float() @@ -94,12 +101,6 @@ def callback(name: str, g: torch.Tensor): dtype=dtype, fill_value=0.0, ) - per_doc_scores = torch.full( - (len(data),), - device=model.device, - dtype=dtype, - fill_value=0.0, - ) for indices in tqdm(batches, disable=rank != 0, desc="Building index"): batch = data[indices] @@ -118,6 +119,8 @@ def callback(name: str, g: torch.Tensor): set_peft_enabled(model, True) with collector: + collector.indices = indices + ft_lps = torch.log_softmax(model(x).logits[:, :-1], dim=-1) # Compute average KL across all unmasked tokens @@ -129,6 +132,8 @@ def callback(name: str, g: torch.Tensor): losses.mean().backward() else: with collector: + collector.indices = indices + logits = model(x).logits[:, :-1] losses = F.cross_entropy( @@ -156,9 +161,11 @@ def callback(name: str, g: torch.Tensor): for module_name in mod_grads.keys(): grad_buffer[module_name][indices] = mod_grads[module_name].numpy() - if query_callback is not None: - scores = query_callback(mod_grads) - per_doc_scores[indices] = scores.detach().type_as(per_doc_scores) + if score_writer is not None: + if module_wise: + score_writer.finalize_module_wise(indices) + else: + score_writer(indices, mod_grads) mod_grads.clear() per_doc_losses[indices] = losses.detach().type_as(per_doc_losses) @@ -178,13 +185,7 @@ def callback(name: str, g: torch.Tensor): feature=Value("float16" if dtype == torch.float16 else 
"float32"), new_fingerprint="loss", ) - data = data.add_column( - "scores", - per_doc_scores.cpu().numpy(), - feature=Value("float16" if dtype == torch.float16 else "float32"), - new_fingerprint="scores", - ) - data.save_to_disk(path + "/data.hf") + data.save_to_disk(path / "data.hf") if save_processor: processor.save(path) diff --git a/bergson/data.py b/bergson/data.py index 36083d26..8f76ac33 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -33,6 +33,9 @@ class DataConfig: split: str = "train" """Split of the dataset to use for building the index.""" + subset: str | None = None + """Subset of the dataset to use for building the index.""" + prompt_column: str = "text" """Column in the dataset that contains the prompts.""" @@ -70,29 +73,30 @@ class QueryConfig: """Config for querying an index on the fly.""" query_path: str = "" - """Path to the query dataset.""" + """Path to the existing query index.""" - query_method: Literal["mean", "nearest"] = "mean" - """Method to use for computing the query.""" + score: Literal["mean", "nearest", "individual"] = "mean" + """Method for scoring the gradients with the query. If mean + gradients will be scored by their similarity with the mean + query gradients, if max by the most similar query gradient, + if individual by each separate query gradient.""" - save_processor: bool = True - """Whether to write the query dataset gradient processor - to disk.""" + scores_path: str = "" + """Path to the directory where query scores should be written.""" query_preconditioner_path: str | None = None - """Path to a precomputed preconditioner. The precomputed - preconditioner is applied to the query dataset gradients.""" + """Path to a precomputed preconditioner to be applied to + the query dataset gradients.""" index_preconditioner_path: str | None = None - """Path to a precomputed preconditioner. The precomputed - preconditioner is applied to the query dataset gradients. 
- This does not affect the ability to compute a new - preconditioner during gradient collection.""" + """Path to a precomputed preconditioner to be applied to + the query dataset gradients. This does not affect the + ability to compute a new preconditioner during the query.""" - mixing_coefficient: float = 0.5 + mixing_coefficient: float = 0.99 """Coefficient to weight the application of the query preconditioner and the pre-computed index preconditioner. 0.0 means only use the - query preconditioner and 1.0 means only use the index preconditioner.""" + index preconditioner and 1.0 means only use the query preconditioner.""" modules: list[str] = field(default_factory=list) """Modules to use for the query. If empty, all modules will be used.""" @@ -162,6 +166,10 @@ class IndexConfig: loss_reduction: Literal["mean", "sum"] = "mean" """Reduction method for the loss function.""" + # TODO consider renaming this + module_wise: bool = False + """Whether to process the module gradients individually.""" + streaming: bool = False """Whether to use streaming mode for the dataset.""" @@ -179,9 +187,9 @@ class IndexConfig: Used for attention modules specified in `split_attention_modules`.""" @property - def partial_run_path(self) -> str: + def partial_run_path(self) -> Path: """Temporary path used while writing build artifacts.""" - return f"{self.run_path}.part" + return Path(self.run_path + ".part") def ceildiv(a: int, b: int) -> int: @@ -285,7 +293,7 @@ def allocate_batches(doc_lengths: list[int], N: int, seed: int = 42) -> list[lis # Split arbitrary (non-singleton) batches until we reach the target i = 0 - while len(batches) < target_batches: + while len(batches) < target_batches and i < len(batches): batch = batches[i % len(batches)] if len(batch) == 1: i += 1 # try another batch @@ -293,7 +301,11 @@ def allocate_batches(doc_lengths: list[int], N: int, seed: int = 42) -> list[lis batches.append([batch.pop()]) # split off a singleton i += 1 - assert len(batches) == 
target_batches + assert len(batches) == target_batches, ( + "Could not construct a number of batches divisible by the world size." + " If variability of item lengths in your dataset is low " + "consider using a different dataset size or token batch size." + ) assert all( max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches ) @@ -316,11 +328,11 @@ def allocate_batches(doc_lengths: list[int], N: int, seed: int = 42) -> list[lis def create_index( - root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike + root: Path, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike ) -> np.memmap: """Create a memory-mapped file for storing structured gradients and persist metadata.""" - grad_path = os.path.join(root, "gradients.bin") + grad_path = root / "gradients.bin" rank = dist.get_rank() if dist.is_initialized() else 0 # Build a json-serializable structured dtype @@ -333,7 +345,7 @@ def create_index( # ── 1. Rank-0 creates file & metadata exactly once ───────────────────────── if rank == 0: # Ensure the directory exists - os.makedirs(root, exist_ok=True) + root.mkdir(parents=True, exist_ok=True) # Allocate (extends file to right size without writing zeros byte-by-byte) nbytes = np.dtype(struct_dtype).itemsize * num_grads # type: ignore @@ -344,7 +356,7 @@ def create_index( os.fsync(f.fileno()) # Persist metadata for future runs - with open(root + "/info.json", "w") as f: + with (root / "info.json").open("w") as f: json.dump({"num_grads": num_grads, "dtype": struct_dtype}, f, indent=2) # ── 2. 
Everyone blocks until the file is definitely there & sized ───────────── @@ -360,7 +372,10 @@ def create_index( def load_data_string( - data_str: str, split: str = "train", streaming: bool = False + data_str: str, + split: str = "train", + subset: str | None = None, + streaming: bool = False, ) -> Dataset | IterableDataset: """Load a dataset from a string identifier or path.""" if data_str.endswith(".csv"): @@ -369,7 +384,7 @@ def load_data_string( ds = assert_type(Dataset, Dataset.from_json(data_str)) else: try: - ds = load_dataset(data_str, split=split, streaming=streaming) + ds = load_dataset(data_str, subset, split=split, streaming=streaming) if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict): raise NotImplementedError( @@ -385,17 +400,17 @@ def load_data_string( return ds -def load_gradients(root_dir: str) -> np.memmap: +def load_gradients(root_dir: Path) -> np.memmap: """Map the structured gradients stored in `root_dir` into memory.""" - with open(os.path.join(root_dir, "info.json")) as f: + with (root_dir / "info.json").open("r") as f: info = json.load(f) dtype = info["dtype"] num_grads = info["num_grads"] return np.memmap( - os.path.join(root_dir, "gradients.bin"), + root_dir / "gradients.bin", dtype=dtype, mode="r", shape=(num_grads,), @@ -403,13 +418,13 @@ def load_gradients(root_dir: str) -> np.memmap: def load_gradient_dataset( - root_dir: str, concatenate_gradients: bool = False + root_dir: Path, concatenate_gradients: bool = False ) -> Dataset: """Load a dataset of gradients from `root_dir`.""" - def load_shard(dir: str) -> Dataset: + def load_shard(dir: Path) -> Dataset: mmap = load_gradients(dir) - ds = Dataset.load_from_disk(dir + "/data.hf") + ds = Dataset.load_from_disk(dir / "data.hf") # concatenate the extracted module gradients into a single column if concatenate_gradients: @@ -428,14 +443,12 @@ def load_shard(dir: str) -> Dataset: ds = ds.add_column(field_name, col, new_fingerprint=field_name) return ds - root = 
Path(root_dir) - - if (root / "data.hf").exists(): + if (root_dir / "data.hf").exists(): return load_shard(root_dir) # Flatten indices to avoid CPU OOM return concatenate_datasets( - [load_shard(str(path)) for path in sorted(root.iterdir()) if path.is_dir()] + [load_shard(path) for path in sorted(root_dir.iterdir()) if path.is_dir()] ).flatten_indices() diff --git a/bergson/faiss_index.py b/bergson/faiss_index.py index 33de5a15..d77e2a4e 100644 --- a/bergson/faiss_index.py +++ b/bergson/faiss_index.py @@ -1,5 +1,4 @@ import json -import os from dataclasses import dataclass from pathlib import Path from time import perf_counter @@ -93,25 +92,24 @@ def normalize_grads( return normalized_grads -def gradients_loader(root_dir: str): - def load_shard(shard_dir: str) -> np.memmap: - with open(os.path.join(shard_dir, "info.json")) as f: +def gradients_loader(root_dir: Path): + def load_shard(shard_dir: Path) -> np.memmap: + with (shard_dir / "info.json").open("r") as f: info = json.load(f) return np.memmap( - os.path.join(shard_dir, "gradients.bin"), + shard_dir / "gradients.bin", dtype=info["dtype"], mode="r", shape=(info["num_grads"],), ) - root_path = Path(root_dir) - if (root_path / "info.json").exists(): + if (root_dir / "info.json").exists(): yield load_shard(root_dir) else: - for path in sorted(root_path.iterdir()): + for path in sorted(root_dir.iterdir()): if "shard" in path.name: - yield load_shard(str(path)) + yield load_shard(path) def _require_faiss() -> ModuleType: @@ -196,7 +194,7 @@ def __init__(self, path: Path, device: str, mmap_index: bool): @staticmethod def create_index( - gradients_path: str, + gradients_path: Path, faiss_path: Path, faiss_cfg: FaissConfig, device: str, @@ -210,13 +208,12 @@ def create_index( faiss_path.mkdir(exist_ok=True, parents=True) # Write the gradients into an on-disk FAISS index - root_path = Path(gradients_path) - if (root_path / "info.json").exists(): - info_paths = [root_path / "info.json"] + if (gradients_path / 
"info.json").exists(): + info_paths = [gradients_path / "info.json"] else: info_paths = [ shard_path / "info.json" - for shard_path in root_path.iterdir() + for shard_path in gradients_path.iterdir() if (shard_path / "info.json").exists() ] @@ -312,7 +309,7 @@ def build_shard_from_buffer( json.dump( { "faiss_cfg": faiss_cfg.__dict__, - "gradients_path": gradients_path, + "gradients_path": str(gradients_path), "device": device, "unit_norm": unit_norm, "ordered_modules": ordered_modules, diff --git a/bergson/gradients.py b/bergson/gradients.py index 664730ba..e3f74f8e 100644 --- a/bergson/gradients.py +++ b/bergson/gradients.py @@ -1,8 +1,8 @@ import json -import os from abc import ABC, abstractmethod from contextlib import ContextDecorator from dataclasses import asdict, astuple, dataclass, field +from pathlib import Path from typing import Callable, Literal, Mapping import torch @@ -219,20 +219,20 @@ def __post_init__(self): @classmethod def load( cls, - path: str, + path: Path, *, map_location: str | torch.device | None = None, ) -> "GradientProcessor": """ Load the normalizers and preconditioners from a file. """ - cfg_path = os.path.join(path, "processor_config.json") - norm_path = os.path.join(path, "normalizers.pth") - precond_path = os.path.join(path, "preconditioners.pth") - precond_eigen_path = os.path.join(path, "preconditioners_eigen.pth") + cfg_path = path / "processor_config.json" + norm_path = path / "normalizers.pth" + precond_path = path / "preconditioners.pth" + precond_eigen_path = path / "preconditioners_eigen.pth" # Load configuration - with open(cfg_path, "r") as f: + with cfg_path.open("r") as f: cfg = json.load(f) # Backward compatibility @@ -267,23 +267,23 @@ def load( **cfg, ) - def save(self, path: str): + def save(self, path: Path): """ Save the normalizers and preconditioners to a file. 
""" - os.makedirs(path, exist_ok=True) + path.mkdir(parents=True, exist_ok=True) - cfg_path = os.path.join(path, "processor_config.json") - norm_path = os.path.join(path, "normalizers.pth") - precond_path = os.path.join(path, "preconditioners.pth") - precond_eigen_path = os.path.join(path, "preconditioners_eigen.pth") + cfg_path = path / "processor_config.json" + norm_path = path / "normalizers.pth" + precond_path = path / "preconditioners.pth" + precond_eigen_path = path / "preconditioners_eigen.pth" # Save configuration separately cfg = asdict(self) del cfg["normalizers"] del cfg["preconditioners"] del cfg["preconditioners_eigen"] - with open(cfg_path, "w") as f: + with cfg_path.open("w") as f: json.dump(cfg, f, indent=2) # Save normalizers @@ -358,6 +358,9 @@ class GradientCollector(ContextDecorator): Dictionary of head configurations for each module to be split into head matrices. """ + indices: list[int] = field(default_factory=list) + """List of indices for the current batch.""" + def __post_init__(self): self._fwd_hooks: list[RemovableHandle] = [] self._bwd_hooks: list[RemovableHandle] = [] @@ -598,7 +601,10 @@ def _process_grad(self, module: nn.Module, _, grad_out): P = G.mT @ I # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q] - self.closure(name, P) + if self.indices: + self.closure(name, P, self.indices) + else: + self.closure(name, P) # Save memory ASAP del module._inputs diff --git a/bergson/huggingface.py b/bergson/huggingface.py index a9bb40a3..8dbb9693 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -2,6 +2,7 @@ import os from functools import wraps from itertools import chain +from pathlib import Path from typing import Sized import numpy as np @@ -28,7 +29,7 @@ class GradientCollectorCallback(TrainerCallback): def __init__( self, - path: str, + path: Path, attention_cfgs: dict[str, AttentionConfig] = {}, projection_dim: int = 16, include_bias: bool = False, @@ -153,7 +154,7 @@ def on_epoch_begin( raise ValueError("Dataset must be 
sized for gradient collection") self.train_grad_buffer = create_index( - os.path.join(self.path, "train" + epoch_suffix), + self.path / ("train" + epoch_suffix), num_grads=len(ds), grad_sizes=self.grad_sizes, dtype=self.dtype, @@ -170,7 +171,7 @@ def on_epoch_begin( for dataset_name, dataloader in eval_datasets.items(): self.eval_grad_buffers[dataset_name] = create_index( - os.path.join(self.path, dataset_name + epoch_suffix), + self.path / (dataset_name + epoch_suffix), num_grads=len(dataloader), grad_sizes=self.grad_sizes, dtype=self.dtype, @@ -191,7 +192,7 @@ def on_epoch_end( if rank == 0: epoch = int(state.epoch or 0) - 1 epoch_suffix = "" if self.accumulate_grads else f"/epoch_{epoch}" - path = os.path.join(self.path, "train" + epoch_suffix) + path = self.path / ("train" + epoch_suffix) assert self.collector is not None self.collector.processor.save(path) @@ -238,6 +239,7 @@ def on_step_end( **kwargs, ): self.on_substep_end(args, state, control) + print("Step end") # Record training order if enabled if self.order is not None: diff --git a/bergson/query.py b/bergson/query.py index 6c3ca03c..1de340a3 100644 --- a/bergson/query.py +++ b/bergson/query.py @@ -1,7 +1,8 @@ -import json import os +import shutil import socket from datetime import timedelta +from pathlib import Path from typing import cast import torch @@ -26,154 +27,21 @@ QueryConfig, allocate_batches, load_data_string, - load_gradient_dataset, tokenize, ) from .gradients import GradientProcessor from .peft import detect_peft_modules +from .score_writer import MemmapScoreWriter +from .scorer import get_scorer from .utils import assert_type, get_layer_list -def get_query_data(index_cfg: IndexConfig, query_cfg: QueryConfig): - """ - Load and optionally precondition the query dataset. Preconditioners - may be mixed as described in https://arxiv.org/html/2410.17413v1#S3. 
- """ - # Collect the query gradients if they don't exist - if not os.path.exists(query_cfg.query_path): - raise FileNotFoundError( - f"Query dataset not found at {query_cfg.query_path}. " - "Please build a query dataset index first." - ) - - # Load the query dataset - with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f: - target_modules = json.load(f)["dtype"]["names"] - - query_ds = load_gradient_dataset(query_cfg.query_path, concatenate_gradients=False) - query_ds = query_ds.with_format("torch", columns=target_modules) - - use_q = query_cfg.query_preconditioner_path is not None - use_i = query_cfg.index_preconditioner_path is not None - - if use_q or use_i: - q, i = {}, {} - if use_q: - assert query_cfg.query_preconditioner_path is not None - q = GradientProcessor.load( - query_cfg.query_preconditioner_path, - map_location="cuda", - ).preconditioners - if use_i: - assert query_cfg.index_preconditioner_path is not None - i = GradientProcessor.load( - query_cfg.index_preconditioner_path, map_location="cuda" - ).preconditioners - - mixed_preconditioner = ( - { - k: q[k] * query_cfg.mixing_coefficient - + i[k] * (1 - query_cfg.mixing_coefficient) - for k in q - } - if (q and i) - else (q or i) - ) - mixed_preconditioner = {k: v.cuda() for k, v in mixed_preconditioner.items()} - - def precondition(batch): - for name in target_modules: - batch[name] = (batch[name].cuda() @ mixed_preconditioner[name]).cpu() - - return batch - - query_ds = query_ds.map( - precondition, batched=True, batch_size=query_cfg.batch_size - ) - - return query_ds - - -def get_mean_query( - query_ds: Dataset, query_cfg: QueryConfig, device: torch.device, dtype: torch.dtype -): - """ - Compute the mean query and return a callback function that scores gradients - according to their inner products or cosine similarities with the mean query. 
- """ - acc = { - module: torch.zeros_like( - query_ds[0][module], device=device, dtype=torch.float32 - ) - for module in query_cfg.modules - } - - def sum_(*cols): - for module, x in zip(query_cfg.modules, cols): - if query_cfg.unit_normalize: - x = x / (x.norm(dim=1, keepdim=True) + 1e-12) - acc[module] += x.to(device=device, dtype=torch.float32).sum(0) - - query_ds.map( - sum_, - input_columns=query_cfg.modules, - batched=True, - batch_size=query_cfg.batch_size, - ) - - callback_query = torch.cat( - [ - (acc[module] / len(query_ds)).to(device=device, dtype=dtype) - for module in query_cfg.modules - ], - dim=0, - ) - - @torch.inference_mode() - def callback(mod_grads: dict[str, torch.Tensor]): - grads = torch.cat([mod_grads[name] for name in query_cfg.modules], dim=1) - if query_cfg.unit_normalize: - grads /= grads.norm(dim=1, keepdim=True) - return grads @ callback_query - - return callback - - -def get_nearest_query( - query_ds: Dataset, query_cfg: QueryConfig, device: torch.device, dtype: torch.dtype -): - """ - Return a callback function that scores gradients according to their cosine - similarities or inner products with the most similar gradient in the query - dataset. 
- """ - - queries = torch.cat([query_ds[:][name] for name in query_cfg.modules], dim=1).to( - device=device, dtype=dtype - ) - - if query_cfg.unit_normalize: - queries /= queries.norm(dim=1, keepdim=True) - - def callback(mod_grads: dict[str, torch.Tensor]): - grads = torch.cat([mod_grads[name] for name in query_cfg.modules], dim=1) - if query_cfg.unit_normalize: - grads /= grads.norm(dim=1, keepdim=True) - - # Calculate scores as the max of the inner products with the queries - all_scores = grads @ queries.T - return all_scores.max(dim=-1).values - - return callback - - def worker( rank: int, world_size: int, index_cfg: IndexConfig, query_cfg: QueryConfig, ds: Dataset | IterableDataset, - query_ds: Dataset, ): torch.cuda.set_device(rank) @@ -270,8 +138,8 @@ def worker( # Shard the entire model fully_shard(model) - processor_dir = index_cfg.processor_path or index_cfg.run_path - processor_cfg_path = os.path.join(processor_dir, "processor_config.json") + processor_dir = Path(index_cfg.processor_path or index_cfg.run_path) + processor_cfg_path = processor_dir / "processor_config.json" if os.path.exists(processor_cfg_path): if rank == 0: @@ -299,24 +167,24 @@ def worker( else: attention_cfgs = {} - with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f: - query_cfg.modules = json.load(f)["dtype"]["names"] - - query_ds = query_ds.with_format("torch", columns=query_cfg.modules) - - query_device = torch.device(f"cuda:{rank}") - query_dtype = dtype if dtype != "auto" else torch.float16 + score_writer_dtype = dtype if dtype != "auto" else torch.float32 + if isinstance(ds, Dataset): - if query_cfg.query_method == "mean": - query_callback = get_mean_query(query_ds, query_cfg, query_device, query_dtype) - elif query_cfg.query_method == "nearest": - query_callback = get_nearest_query( - query_ds, query_cfg, query_device, query_dtype + scorer = get_scorer( + query_cfg, + index_cfg.module_wise, + torch.device(f"cuda:{rank}"), + score_writer_dtype, + ) + 
score_writer = MemmapScoreWriter( + scorer, + len(ds), + Path(query_cfg.scores_path), + rank=rank, + modules=query_cfg.modules, + module_wise=index_cfg.module_wise, + dtype=score_writer_dtype, ) - else: - raise ValueError(f"Invalid query method: {query_cfg.query_method}") - - if isinstance(ds, Dataset): batches = allocate_batches(ds["length"][:], index_cfg.token_batch_size) collect_gradients( model, @@ -330,9 +198,10 @@ def worker( target_modules=target_modules, attention_cfgs=attention_cfgs, drop_columns=index_cfg.drop_columns, - query_callback=query_callback, + score_writer=score_writer, save_index=index_cfg.save_index, save_processor=index_cfg.save_processor, + module_wise=index_cfg.module_wise, ) else: # Convert each shard to a Dataset then collect its gradients @@ -346,11 +215,25 @@ def flush(): batches = allocate_batches( ds_shard["length"][:], index_cfg.token_batch_size ) + scorer = get_scorer( + query_cfg, + index_cfg.module_wise, + torch.device(f"cuda:{rank}"), + score_writer_dtype, + ) + score_writer = MemmapScoreWriter( + scorer, + len(ds_shard), + Path(query_cfg.scores_path) / f"shard-{shard_id:05d}", + rank=rank, + modules=query_cfg.modules, + module_wise=index_cfg.module_wise, + ) collect_gradients( model, ds_shard, processor, - os.path.join(index_cfg.partial_run_path, f"shard-{shard_id:05d}"), + index_cfg.partial_run_path / f"shard-{shard_id:05d}", batches=batches, kl_divergence=index_cfg.loss_fn == "kl", loss_reduction=index_cfg.loss_reduction, @@ -358,7 +241,7 @@ def flush(): target_modules=target_modules, attention_cfgs=attention_cfgs, drop_columns=index_cfg.drop_columns, - query_callback=query_callback, + score_writer=score_writer, save_index=index_cfg.save_index, save_processor=index_cfg.save_processor, ) @@ -378,10 +261,9 @@ def dist_worker( index_cfg: IndexConfig, query_cfg: QueryConfig, ds: Dataset, - query_ds: Dataset, ): try: - worker(rank, world_size, index_cfg, query_cfg, ds, query_ds) + worker(rank, world_size, index_cfg, query_cfg, ds) 
# === bergson/score_writer.py =================================================
import json
from abc import ABC, abstractmethod
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist

from .scorer import Scorer


class ScoreWriter(ABC):
    """Abstract sink that receives per-example scores and persists them."""

    @abstractmethod
    def __call__(
        self,
        indices: list[int],
        mod_grads: dict[str, torch.Tensor],
        name: str | None = None,
    ):
        """Score `mod_grads` and record the result for the rows in `indices`.

        When `name` is given the writer is operating module-wise: it only
        buffers that module's partial scores, and `finalize_module_wise`
        must be called afterwards to combine and persist them.
        """
        raise NotImplementedError("Subclasses must implement this method")

    @abstractmethod
    def flush(self):
        """Flush any buffered scores to durable storage."""
        raise NotImplementedError("Subclasses must implement this method")

    @abstractmethod
    def finalize_module_wise(self, indices: list[int]):
        """Combine buffered module-wise scores and persist them for `indices`."""
        raise NotImplementedError(
            "Module-wise scoring is not supported by this score writer."
        )


class MemmapScoreWriter(ScoreWriter):
    """Stores scores produced by a `Scorer` in a numpy memmap on disk.

    The memmap holds one structured record per dataset item, containing a
    float32 ``score_i`` and a bool ``written_i`` flag for each of the
    scorer's outputs, so partially-written runs can be detected later.
    """

    def __init__(
        self,
        scorer: Scorer,
        num_items: int,
        scores_path: Path,
        *,
        dtype: torch.dtype = torch.float32,
        rank: int,
        modules: list[str],
        module_wise: bool = False,
        flush_batches_interval: int = 40,
    ):
        self.scorer = scorer
        self.num_scores = scorer.num_scores
        self.scores_path = scores_path
        self.rank = rank
        self.dtype = dtype
        self.module_wise = module_wise

        self.flush_interval = flush_batches_interval
        self.num_batches_since_flush = 0

        self.num_modules = len(modules)

        self.scores_path.mkdir(parents=True, exist_ok=True)
        scores_file_path = self.scores_path / "scores.bin"

        # Build a json-serializable structured dtype: every score occupies a
        # packed 6-byte (float32 score, bool written) pair.
        names: list[str] = []
        formats: list[str] = []
        offsets: list[int] = []
        for i in range(self.num_scores):
            names.append(f"score_{i}")
            formats.append("float32")
            offsets.append(i * 6)

            names.append(f"written_{i}")
            formats.append("bool")
            offsets.append(i * 6 + 4)

        # itemsize must cover the last offset + field size (6n - 1 bytes),
        # not merely the sum of the field sizes (5n bytes): with three or
        # more scores the latter undershoots and numpy rejects the dtype.
        # Round up to the nearest 8 bytes for alignment.
        last_byte = 6 * self.num_scores - 1
        itemsize = ((last_byte + 7) // 8) * 8

        struct_dtype = {
            "names": names,
            "formats": formats,
            "offsets": offsets,
            "itemsize": itemsize,
        }
        np_dtype = np.dtype(struct_dtype)  # type: ignore

        # Only one process may create and zero the file. Opening with
        # mode="w+" truncates, so letting every rank do it concurrently
        # (as well as rewriting info.json) is a data race.
        if self.rank == 0 or not dist.is_initialized():
            if not scores_file_path.exists():
                print(f"Creating new scores file: {scores_file_path}")

            scores = np.memmap(
                str(scores_file_path),
                dtype=np_dtype,
                mode="w+",
                shape=(num_items,),
            )
            for field in names:
                if field.startswith("score"):
                    scores[field][:] = 0.0
                elif field.startswith("written"):
                    scores[field][:] = False
            scores.flush()
            del scores

            # Persist metadata for future runs
            with (scores_path / "info.json").open("w") as f:
                json.dump(
                    {
                        "num_items": num_items,
                        "num_modules": self.num_modules,
                        "dtype": struct_dtype,
                    },
                    f,
                    indent=2,
                )

        if dist.is_initialized():
            # Make every rank wait until rank 0 finished creating the file.
            dist.barrier()

        # All ranks (re-)open read-write without truncating.
        self.scores = np.memmap(
            str(scores_file_path),
            dtype=np_dtype,
            mode="r+",
            shape=(num_items,),
        )
        print(f"Loaded {len(self.scores)} scores from {scores_file_path}")

        # Per-module buffers, used only in module-wise mode.
        self.module_wise_scores: dict[str, torch.Tensor] = {}
        self.module_wise_sum_squares: dict[str, torch.Tensor] = {}

    def _write_to_memmap(self, indices: list[int], scores: torch.Tensor):
        """Write a `[len(indices), num_scores]` tensor into the memmap and
        mark the corresponding `written_i` flags."""
        for i in range(self.num_scores):
            self.scores[f"score_{i}"][indices] = (
                scores[:, i].cpu().numpy().astype(np.float32).flatten()
            )
            self.scores[f"written_{i}"][indices] = True

        self.num_batches_since_flush += 1
        if self.num_batches_since_flush >= self.flush_interval:
            self.flush()

    def __call__(
        self,
        indices: list[int],
        mod_grads: dict[str, torch.Tensor],
        name: str | None = None,
    ):
        if name:
            # Module-wise: buffer this module's partial scores and its
            # contribution to the sum of squares; finalize_module_wise
            # combines and writes them.
            scores, sum_of_squares = self.scorer(mod_grads, name=name)
            self.module_wise_scores[name] = scores.to(device="cpu", dtype=self.dtype)
            self.module_wise_sum_squares[name] = sum_of_squares.to(
                device="cpu", dtype=self.dtype
            )
        else:
            scores = self.scorer(mod_grads)
            scores = scores.to(device="cpu", dtype=self.dtype)

            self._write_to_memmap(indices, scores)

    def finalize_module_wise(self, indices: list[int]):
        """Combine the buffered per-module scores and write them to the memmap.

        The inner product over the full gradient decomposes into the sum of
        per-module inner products, so the partial scores are summed across
        modules (concatenating would leave all but the first module's
        columns unwritten), then normalized by the total sum of squares of
        the index gradients.
        """
        # [num_modules, num_items, num_scores] -> [num_items, num_scores]
        scores = torch.stack(list(self.module_wise_scores.values())).sum(dim=0)

        if self.module_wise_sum_squares:
            # [num_modules, num_items] -> [num_items]
            sum_of_squares = torch.stack(
                list(self.module_wise_sum_squares.values())
            ).sum(dim=0)
            assert scores.shape[0] == sum_of_squares.shape[0]
            # unsqueeze so the [num_items] factor broadcasts over the score
            # dimension of [num_items, num_scores]; without it broadcasting
            # misaligns (or fails) whenever num_scores != num_items.
            scores *= sum_of_squares.rsqrt().unsqueeze(-1)

        # Write accumulated scores
        self._write_to_memmap(indices, scores)

        # Reset the buffers: stale entries from a differently-sized batch
        # would break the stack above on the next call.
        self.module_wise_scores.clear()
        self.module_wise_sum_squares.clear()

    def flush(self):
        """Flush pending memmap pages to disk and reset the flush counter."""
        self.scores.flush()
        self.num_batches_since_flush = 0


# === bergson/scorer.py =======================================================
import json
from pathlib import Path
from typing import Callable, Literal

import torch
from datasets import Dataset

from .data import QueryConfig, load_gradient_dataset
from .gradients import GradientProcessor


class Scorer:
    """Thin wrapper pairing a scoring callback with its output width."""

    # Number of score columns the callback produces per item.
    num_scores: int

    def __init__(self, callback: Callable, num_scores: int):
        self.callback = callback
        self.num_scores = num_scores

    def __call__(self, mod_grads: dict[str, torch.Tensor], **kwargs):
        return self.callback(mod_grads, **kwargs)


def preprocess_grads(
    grad_ds: Dataset,
    grad_column_names: list[str],
    unit_normalize: bool,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
    accumulate_grads: Literal["mean", "sum", "none"] = "none",
    normalize_accumulated_grad: bool = False,
) -> dict[str, torch.Tensor]:
    """Preprocess the gradients in the dataset.

    Returns a dictionary of preprocessed gradients, one tensor per column.
    Preprocessing includes some combination of unit normalization,
    accumulation across rows, accumulated-gradient normalization, and
    dtype/device conversion.
    """

    # Short-circuit if possible; materialize the table once rather than
    # once per column.
    if accumulate_grads == "none" and not unit_normalize:
        table = grad_ds[:]
        return {
            column_name: table[column_name].to(device=device, dtype=dtype)
            for column_name in grad_column_names
        }

    # Get sum and sum of squares of the gradients
    acc = {
        column_name: torch.zeros_like(
            grad_ds[0][column_name], device=device, dtype=torch.float32
        )
        for column_name in grad_column_names
    }
    ss_acc = torch.tensor(0.0, device=device, dtype=torch.float32)
    if not unit_normalize:
        # Neutral divisor when no normalization is requested.
        ss_acc.fill_(1.0)

    def sum_(cols):
        nonlocal ss_acc

        for column_name in grad_column_names:
            x = cols[column_name].to(device=device, dtype=torch.float32)
            acc[column_name].add_(x.sum(0))

            if unit_normalize:
                # To normalize the mean gradient we can divide by the sum of
                # squares of every gradient element in the dataset
                ss_acc += x.pow(2).sum()

    # map() is used purely for its side effects on `acc` / `ss_acc`.
    grad_ds.map(
        sum_,
        batched=True,
        batch_size=batch_size,
    )

    ss_acc = ss_acc.sqrt()
    assert ss_acc > 0, "Sum of squares of entire dataset is zero"

    # Process the gradient dataset
    if accumulate_grads == "mean":
        grads = {
            column_name: (acc[column_name] / ss_acc / len(grad_ds))
            .unsqueeze(0)
            .to(dtype)
            for column_name in grad_column_names
        }
    elif accumulate_grads == "sum":
        grads = {
            column_name: (acc[column_name] / ss_acc).unsqueeze(0).to(dtype)
            for column_name in grad_column_names
        }
    elif accumulate_grads == "none":
        table = grad_ds[:]
        grads = {
            column_name: table[column_name].to(device=device, dtype=dtype) / ss_acc
            for column_name in grad_column_names
        }
    else:
        raise ValueError(f"Invalid accumulate_grads: {accumulate_grads}")

    # Normalize the accumulated gradient across all columns jointly
    if normalize_accumulated_grad:
        grad_norm = torch.cat(
            [grads[column_name].flatten() for column_name in grad_column_names], dim=0
        ).norm()
        for column_name in grad_column_names:
            grads[column_name] /= grad_norm

    return grads


def build_scorer(
    query_ds: Dataset,
    query_cfg: QueryConfig,
    device: torch.device,
    dtype: torch.dtype,
    *,
    accumulate_grads: Literal["mean", "sum", "none"] = "none",
    module_wise: bool = False,
    nearest: bool = False,
    normalize_accumulated_grad: bool = False,
) -> Scorer:
    """Unified scorer builder for all scorer types."""

    query_grads = preprocess_grads(
        query_ds,
        query_cfg.modules,
        query_cfg.unit_normalize,
        query_cfg.batch_size,
        device,
        dtype,
        accumulate_grads=accumulate_grads,
        normalize_accumulated_grad=normalize_accumulated_grad,
    )

    if not module_wise:
        query_tensor = torch.cat(
            [query_grads[m].to(device=device, dtype=dtype) for m in query_cfg.modules],
            dim=1,
        )
    else:
        query_tensor = None

    @torch.inference_mode()
    def callback(mod_grads: dict[str, torch.Tensor], **kwargs):
        # Module-wise path: score one module at a time, returning the partial
        # scores plus this module's contribution to the sum of squares.
        if query_tensor is None:
            name = kwargs["name"]
            module_scores = mod_grads[name] @ query_grads[name].T
            ssq = mod_grads[name].pow(2).sum(dim=1)
            return module_scores, ssq

        grads = torch.cat([mod_grads[m] for m in query_cfg.modules], dim=1)
        if query_cfg.unit_normalize:
            grads /= grads.norm(dim=1, keepdim=True)

        if nearest:
            all_scores = grads @ query_tensor.T
            # Keep a trailing score dimension so the writer's
            # [len(indices), num_scores] contract holds.
            return all_scores.max(dim=-1).values.unsqueeze(-1)

        return grads @ query_tensor.T

    # "nearest" collapses all queries into a single max-similarity score;
    # otherwise there is one score per (possibly accumulated) query row.
    num_scores = 1 if nearest else len(query_grads[query_cfg.modules[0]])

    return Scorer(callback, num_scores)


def get_query_ds(query_cfg: QueryConfig):
    """
    Load and optionally precondition the query dataset. Preconditioners
    may be mixed as described in https://arxiv.org/html/2410.17413v1#S3.
    """
    # Collect the query gradients if they don't exist
    query_path = Path(query_cfg.query_path)
    if not query_path.exists():
        raise FileNotFoundError(
            f"Query dataset not found at {query_cfg.query_path}. "
            "Please build a query dataset index first."
        )

    # Load the query dataset
    with open(query_path / "info.json", "r") as f:
        target_modules = json.load(f)["dtype"]["names"]

    if not query_cfg.modules:
        query_cfg.modules = target_modules

    query_ds = load_gradient_dataset(
        Path(query_cfg.query_path), concatenate_gradients=False
    )
    query_ds = query_ds.with_format("torch", columns=target_modules)

    use_query_pre = query_cfg.query_preconditioner_path is not None
    use_index_pre = query_cfg.index_preconditioner_path is not None

    if use_query_pre or use_index_pre:
        query_pre, index_pre = {}, {}
        if use_query_pre:
            assert query_cfg.query_preconditioner_path is not None
            query_pre = GradientProcessor.load(
                Path(query_cfg.query_preconditioner_path),
                map_location="cuda",
            ).preconditioners
        if use_index_pre:
            assert query_cfg.index_preconditioner_path is not None
            index_pre = GradientProcessor.load(
                Path(query_cfg.index_preconditioner_path), map_location="cuda"
            ).preconditioners

        # Linear mix of the two preconditioners when both are present,
        # otherwise whichever one was provided.
        mixed_preconditioner = (
            {
                k: query_pre[k] * query_cfg.mixing_coefficient
                + index_pre[k] * (1 - query_cfg.mixing_coefficient)
                for k in query_pre
            }
            if (query_pre and index_pre)
            else (query_pre or index_pre)
        )
        mixed_preconditioner = {k: v.cuda() for k, v in mixed_preconditioner.items()}

        def precondition(batch):
            for name in target_modules:
                batch[name] = (batch[name].cuda() @ mixed_preconditioner[name]).cpu()

            return batch

        query_ds = query_ds.map(
            precondition, batched=True, batch_size=query_cfg.batch_size
        )

    return query_ds.with_format("torch", columns=query_cfg.modules)


def get_scorer(
    query_cfg: QueryConfig, module_wise: bool, device: torch.device, dtype: torch.dtype
) -> Scorer:
    """Build the scorer requested by `query_cfg.score`.

    Raises ValueError on an unknown scoring method.
    """
    query_ds = get_query_ds(query_cfg)

    if query_cfg.score == "mean":
        return build_scorer(
            query_ds,
            query_cfg,
            device,
            dtype,
            module_wise=module_wise,
            accumulate_grads="mean",
        )
    elif query_cfg.score == "nearest":
        assert not module_wise, "Module-wise scoring not supported for nearest query"
        return build_scorer(
            query_ds,
            query_cfg,
            device,
            dtype,
            accumulate_grads="none",
            module_wise=False,
            nearest=True,
        )
    elif query_cfg.score == "individual":
        return build_scorer(
            query_ds,
            query_cfg,
            device,
            dtype,
            module_wise=module_wise,
            accumulate_grads="none",
        )
    else:
        raise ValueError(f"Invalid scoring method: {query_cfg.score}")
processor=GradientProcessor(projection_dim=16), - path=str(tmp_path), + path=tmp_path, ) attr = Attributor( - str(tmp_path), + tmp_path, device="cuda", unit_norm=True, faiss_cfg=FaissConfig(), diff --git a/tests/test_build.py b/tests/test_build.py index d1cfe16f..dbe94b6f 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -1,3 +1,4 @@ +import subprocess from pathlib import Path import numpy as np @@ -13,16 +14,39 @@ from bergson.data import load_gradients +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_build_e2e(tmp_path: Path): + result = subprocess.run( + [ + "python", + "-m", + "bergson", + "build", + "test_e2e", + "--model", + "EleutherAI/pythia-14m", + "--dataset", + "NeelNanda/pile-10k", + "--split", + "train[:100]", + "--truncation", + ], + cwd=tmp_path, + ) + + assert result.returncode == 0 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_build_consistency(tmp_path: Path, model, dataset): collect_gradients( model=model, data=dataset, processor=GradientProcessor(), - path=str(tmp_path), + path=tmp_path, skip_preconditioners=True, ) - index = load_gradients(str(tmp_path)) + index = load_gradients(tmp_path) # Regenerate cache cache_path = Path("runs/test_build_cache.npy") @@ -47,7 +71,7 @@ def test_split_attention_build(tmp_path: Path, model, dataset): model=model, data=dataset, processor=GradientProcessor(projection_dim=16), - path=str(tmp_path), + path=tmp_path, attention_cfgs=attention_cfgs, ) @@ -65,8 +89,8 @@ def test_conv1d_build(tmp_path: Path, dataset): collect_gradients( model=model, data=dataset, - processor=GradientProcessor(), - path=str(tmp_path), + processor=GradientProcessor(projection_dim=16), + path=tmp_path, # This build hangs in pytest with preconditioners enabled. # It works when run directly so it may be a pytest issue. 
# === tests/test_query.py =====================================================
from pathlib import Path

import pytest
import torch
from datasets import Dataset

from bergson import (
    GradientCollector,
    GradientProcessor,
    MemmapScoreWriter,
    collect_gradients,
)
from bergson.data import QueryConfig
from bergson.scorer import build_scorer


def _run_query(tmp_path: Path, model, dataset, *, module_wise: bool):
    """Shared driver: build a 2-row random query dataset, score `dataset`
    against it with a mean scorer, and check that score artifacts exist.

    Extracted because the module-wise and flat tests were byte-identical
    except for the `module_wise` flag.
    """
    processor = GradientProcessor(projection_dim=16)
    shapes = GradientCollector(model.base_model, lambda x: x, processor).shapes()

    query_gradient_ds = Dataset.from_list(
        [
            {module: torch.randn(shape.numel()) for module, shape in shapes.items()}
            for _ in range(2)
        ]
    ).with_format("torch", columns=list(shapes.keys()))

    scorer = build_scorer(
        query_gradient_ds,
        QueryConfig(
            query_path=str(tmp_path / "query_gradient_ds"),
            modules=list(shapes.keys()),
            score="mean",
        ),
        device=torch.device("cpu"),
        dtype=torch.float32,
        module_wise=module_wise,
        accumulate_grads="mean",
        normalize_accumulated_grad=True,
    )
    score_writer = MemmapScoreWriter(
        scorer,
        len(dataset),
        tmp_path,
        rank=0,
        modules=list(shapes.keys()),
        module_wise=module_wise,
    )

    collect_gradients(
        model=model,
        data=dataset,
        processor=processor,
        path=tmp_path,
        score_writer=score_writer,
        module_wise=module_wise,
    )

    assert any(tmp_path.iterdir()), "Expected artifacts in the temp run_path"
    # tmp_path is already a Path; no need to re-wrap it before globbing.
    assert any(tmp_path.glob("scores.bin")), "Expected scores file"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_query(tmp_path: Path, model, dataset):
    _run_query(tmp_path, model, dataset, module_wise=False)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_module_wise_query(tmp_path: Path, model, dataset):
    _run_query(tmp_path, model, dataset, module_wise=True)
a/tests/test_trainer_callback.py +++ b/tests/test_trainer_callback.py @@ -1,26 +1,10 @@ import os -import pytest - -try: - import torch - - HAS_CUDA = torch.cuda.is_available() -except Exception: - HAS_CUDA = False - -if not HAS_CUDA: - pytest.skip( - "Skipping GPU-only tests: no CUDA/NVIDIA driver available.", - allow_module_level=True, - ) - - os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["WANDB_MODE"] = "disabled" - import pytest +import torch from datasets import Dataset from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments from trl import SFTConfig, SFTTrainer @@ -59,6 +43,7 @@ def dataset(self): } return Dataset.from_dict(data) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_single_gpu_order_tracking(self, tmp_path, model, dataset): """Test that every step has an associated order record in single-GPU mode.""" # Train the model with the callback @@ -74,7 +59,7 @@ def test_single_gpu_order_tracking(self, tmp_path, model, dataset): ) callback = GradientCollectorCallback( - path=str(tmp_path / "gradients"), + path=tmp_path / "gradients", track_order=True, use_optimizer_state=False, ) @@ -121,6 +106,7 @@ def test_single_gpu_order_tracking(self, tmp_path, model, dataset): for record in callback.order: assert 0 <= record["_idx"] < len(dataset) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_order_tracking_disabled(self, tmp_path, model, dataset): """Test that no order records are created when tracking is disabled.""" # Train the model with the callback @@ -136,7 +122,7 @@ def test_order_tracking_disabled(self, tmp_path, model, dataset): ) callback = GradientCollectorCallback( - path=str(tmp_path / "gradients"), use_optimizer_state=False + path=tmp_path / "gradients", use_optimizer_state=False ) trainer = Trainer( @@ -152,6 +138,7 @@ def test_order_tracking_disabled(self, tmp_path, model, dataset): # Verify no order records were created assert 
callback.order is None + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_order_save_and_load(self, tmp_path, model, dataset): """Test that order records are properly saved and can be loaded.""" # Train the model with the callback @@ -167,7 +154,7 @@ def test_order_save_and_load(self, tmp_path, model, dataset): ) callback = GradientCollectorCallback( - path=str(tmp_path / "gradients"), + path=tmp_path / "gradients", track_order=True, use_optimizer_state=False, ) @@ -201,6 +188,7 @@ def test_order_save_and_load(self, tmp_path, model, dataset): assert record["global_step"] == callback.order[i]["global_step"] assert record["epoch"] == callback.order[i]["epoch"] + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_sft_trainer(self, tmp_path, model, dataset): """Test that gradient and order files are created and can be loaded after training with SFTTrainer.""" @@ -217,7 +205,7 @@ def test_sft_trainer(self, tmp_path, model, dataset): ) callback = GradientCollectorCallback( - path=str(tmp_path / "gradients"), + path=tmp_path / "gradients", track_order=True, use_optimizer_state=False, ) @@ -247,7 +235,7 @@ def test_sft_trainer(self, tmp_path, model, dataset): assert (gradient_dir / "order.hf").exists() # Test loading the gradient data directly - gradients = load_gradients(str(train_gradient_dir)) + gradients = load_gradients(train_gradient_dir) assert len(gradients) > 0 # Verify order data was saved and can be loaded