diff --git a/CLAUDE.md b/CLAUDE.md index b59475cd..eaef2862 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,4 +1,4 @@ -Always test your changes by running the appropriate script or CLI command. Never complete a task without testing your changes until the script or CLI command runs without issues for 3 minutes+ (at minimum). If you find an error unrelated to your task, at minimum quote the exact error back to me when you have completed your task and offer to investigate and fix it. +Always test your changes by running the appropriate script or CLI command. Never complete a task without testing your changes until the script or CLI command runs without issues. If it's a long-running script let it run for at least a few iterations of the main loop. If you find an error unrelated to your task, at minimum quote the exact error back to me when you have completed your task and offer to investigate and fix it. ## Project Structure and Conventions @@ -18,17 +18,39 @@ Put imports at the top of the file unless you have a good reason to do otherwise # Development +Never use try/except blocks - fail fast, fail explicitly. + +Never use "fallbacks". + +Do not write lines longer than 88 characters. + +Don't use ALL CAPS unless it's proper English (e.g. an acronym). + +Don't keep default run path values inside low level code - if a module calls another module, the higher level module should always pass through inject a base path. + +Don't save data to a directory that is not in the .gitignore - especially the data/ directory. + +Don't remove large datasets from the HF cache without asking. + You can call CLI commands without prefixing `python -m`, like `bergson build`. Use `pre-commit run --all-files` if you forget to install pre-commit and it doesn't run in the hook. Run bash commands in the dedicated tmux pane named "claude" if it is available. -Don't keep default run path values inside low level code - if a module calls another module, the higher level module should always pass through inject a base path. +Don't betray lineage. An example of betraying lineage is duplicating a file, making changes in the duplicate, then calling it "foo_fixed" rather than "foo". Instead, commit the file and modify it directly. Another example is adding a RoundButton to a module containing a Button but not updating the original Button to be called RectangleButton. This betrays that the rectangular button was written first. -Don't save data to a directory that is not in the gitignore - especially the data/ directory. +If you think some data files (e.g. CSVs) have been invalidated but you're not 100% sure, you can add them to a .gitignore'd archive directory along with an equivalentally named markdown file explaining the context. -Don't remove large datasets from the HF cache without asking. +File names always use snake case - in_memory, not inmemory. + +When writing files to disk python scripts should choose their own filenames but be provided with their file paths. + +### Documentation + +Do not mark documentation for code that has been removed as deprecated - simply remove the documentation. + +No context leakage: do not write code or comments that link features to the specific experiment for which the feature was developed, unless it's only useful for that particular experiment. Be as generic as is correctly possible and not more. ### Tests diff --git a/bergson/build.py b/bergson/build.py index 628a54bf..2f4cd0b5 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -3,6 +3,7 @@ import shutil from dataclasses import asdict from datetime import timedelta +from pathlib import Path import torch import torch.distributed as dist @@ -10,9 +11,13 @@ from tqdm.auto import tqdm from bergson.collection import collect_gradients +from bergson.collector.gradient_collectors import GradientCollector from bergson.config import IndexConfig from bergson.data import allocate_batches from bergson.distributed import launch_distributed_run +from bergson.utils.auto_batch_size import ( + determine_batch_size, +) from bergson.utils.utils import assert_type, setup_reproducibility from bergson.utils.worker_utils import ( create_processor, @@ -63,6 +68,24 @@ def build_worker( model, target_modules = setup_model_and_peft(cfg) processor = create_processor(model, ds, cfg, target_modules) + # Auto batch size determination if enabled + if cfg.autobatchsize: + cfg.token_batch_size = determine_batch_size( + root=Path(".cache"), + cfg=cfg, + model=model, + collector=GradientCollector( + model=model.base_model, + cfg=cfg, + processor=processor, + target_modules=target_modules, + data=ds, + scorer=None, + reduce_cfg=None, + ), + starting_batch_size=cfg.token_batch_size, + ) + attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules} kwargs = { diff --git a/bergson/config.py b/bergson/config.py index 2d51cc7f..304e92a1 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -144,6 +144,9 @@ class IndexConfig: token_batch_size: int = 2048 """Batch size in tokens for building the index.""" + autobatchsize: bool = False + """Whether to automatically determine the optimal batch size.""" + processor_path: str = "" """Path to a precomputed processor.""" diff --git a/bergson/utils/auto_batch_size.py b/bergson/utils/auto_batch_size.py index ad8024ec..84b7ba12 100644 --- a/bergson/utils/auto_batch_size.py +++ b/bergson/utils/auto_batch_size.py @@ -1,345 +1,187 @@ -""" -In-memory auto batch size determination for Bergson benchmarks. - -This module provides utilities to automatically find the optimal token_batch_size -that fits in GPU memory for already-loaded models and datasets. - -Main function: find_optimal_token_batch_size() -- Call this with your loaded model, tokenizer, and dataset -- Returns optimal token_batch_size that fits in memory - -Adapted from HuggingFace Accelerate's find_executable_batch_size utility. -""" - import gc import json from pathlib import Path -from typing import Callable, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional import torch -from datasets import Dataset -from transformers import PreTrainedModel, PreTrainedTokenizer - -from bergson.collector.collector import CollectorComputer -from bergson.collector.in_memory_collector import InMemoryCollector -from bergson.config import DataConfig, IndexConfig -from bergson.data import allocate_batches, tokenize -from bergson.gradients import GradientProcessor - - -def should_reduce_batch_size(exception: Exception) -> bool: - """Check if exception relates to out-of-memory errors or batch size issues.""" - _statements = [ - " out of memory.", - "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", - "DefaultCPUAllocator: can't allocate memory", - "FATAL ERROR :: MODULE:PT_DEVMEM Allocation failed", - # Catches "Token batch size X exceeds model's max sequence length" - "Token batch size", - # Catches "distributed worker error or insufficient documents" - "insufficient documents", - ] - if isinstance(exception, RuntimeError) and len(exception.args) == 1: - return any(err in exception.args[0] for err in _statements) - return False - - -def clear_device_cache(garbage_collection: bool = False) -> None: - """Clear device cache and optionally run garbage collection.""" - if garbage_collection: - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() +from transformers import PreTrainedModel +from bergson.config import IndexConfig -def round_to_power_of_2(n: int) -> int: - """Round down to the nearest power of 2.""" - if n <= 0: - return 1 - power = 1 - while power * 2 <= n: - power *= 2 - return power +if TYPE_CHECKING: + from bergson.collector.gradient_collectors import HookCollectorBase -def save_batch_size_cache( - cache_path: Path, model_name: str, token_batch_size: int, fsdp: bool = False -) -> None: - """Save optimal token_batch_size to cache file.""" - cache_path.parent.mkdir(parents=True, exist_ok=True) - - cache_data = { - "model_name": model_name, - "token_batch_size": token_batch_size, - "fsdp": fsdp, - "gpu_name": ( - torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu" - ), - "gpu_memory_gb": ( - torch.cuda.get_device_properties(0).total_memory / 1e9 - if torch.cuda.is_available() - else None - ), - } +def _clear_cache() -> None: + """Aggressively clear memory.""" + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() - with open(cache_path, "w") as f: - json.dump(cache_data, f, indent=2) - print(f"Saved batch size cache to {cache_path}") +def _get_system_metadata(cfg: IndexConfig) -> Dict[str, Any]: + """Identify the current hardware and model configuration.""" + gpu_name = "cpu" + gpu_mem = 0.0 + + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name(0) + gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9 + + return { + "model": cfg.model, + "fsdp": cfg.fsdp, + "precision": getattr(cfg, "precision", "fp32"), + "projection_dim": getattr(cfg, "projection_dim", 0), + "reshape_to_square": getattr(cfg, "reshape_to_square", False), + "gpu_name": gpu_name, + "gpu_memory_gb": round(gpu_mem, 1), + } -def load_batch_size_cache( - cache_path: Path, model_name: str, fsdp: bool = False -) -> Optional[int]: - """Load optimal token_batch_size from cache if available and valid.""" - if not cache_path.exists(): +def _check_cache(cache_file: Path, current_meta: Dict[str, Any]) -> Optional[int]: + """Read JSONL file and look for a matching configuration.""" + if not cache_file.exists(): return None try: - with open(cache_path, "r") as f: - cache_data = json.load(f) - - # Verify cache is for the same configuration - if cache_data.get("model_name") != model_name: - print( - f"Cache model mismatch: {cache_data.get('model_name')} != {model_name}" - ) - return None - - if cache_data.get("fsdp") != fsdp: - print(f"Cache FSDP mismatch: {cache_data.get('fsdp')} != {fsdp}") - return None - - # Check if GPU matches (optional warning) - if torch.cuda.is_available(): - current_gpu = torch.cuda.get_device_name(0) - cached_gpu = cache_data.get("gpu_name") - if cached_gpu and cached_gpu != current_gpu: - print(f"Warning: GPU changed from {cached_gpu} to {current_gpu}") - print("Cached batch size may not be optimal for this GPU") - - token_batch_size = cache_data.get("token_batch_size") - if token_batch_size and isinstance(token_batch_size, int): - print( - f"Loaded cached token_batch_size={token_batch_size} from {cache_path}" - ) - return token_batch_size - + with open(cache_file, "r") as f: + for line in f: + try: + row = json.loads(line) + if all(row.get(k) == v for k, v in current_meta.items()): + return row.get("token_batch_size") + except json.JSONDecodeError: + continue + except Exception: return None - except Exception as e: - print(f"Failed to load batch size cache: {e}") - return None + return None -def find_optimal_token_batch_size_raw( - test_fn: Callable[[int], None], - starting_batch_size: int = 4096, - round_to_pow2: bool = True, - max_batch_size: Optional[int] = None, -) -> int: - """ - Find optimal token_batch_size by testing with progressively larger/smaller sizes. +def _append_to_cache( + cache_file: Path, current_meta: Dict[str, Any], batch_size: int +) -> None: + """Append a new row to the JSONL cache file.""" + cache_file.parent.mkdir(parents=True, exist_ok=True) - Args: - test_fn: Function that takes token_batch_size and performs a test pass - starting_batch_size: Initial token_batch_size to try - round_to_pow2: Round final batch size down to nearest power of 2 - max_batch_size: Maximum batch size to try (e.g., model's max sequence length) + entry = current_meta.copy() + entry["token_batch_size"] = batch_size - Returns: - Optimal token_batch_size that fits in memory - """ - token_batch_size = starting_batch_size - successful_batch_size = None - iteration = 0 + with open(cache_file, "a") as f: + f.write(json.dumps(entry) + "\n") - clear_device_cache(garbage_collection=True) + print(f"Cached batch size {batch_size} to {cache_file}") - while True: - iteration += 1 - if token_batch_size < 128: - if successful_batch_size is not None: - break - raise RuntimeError( - f"No executable token_batch_size found, reached minimum (128). " - f"Started from {starting_batch_size}." - ) - - try: - print( - f"[Iteration {iteration}] Trying token_batch_size={token_batch_size}..." - ) - test_fn(token_batch_size) - successful_batch_size = token_batch_size - print( - f"✓ [Iteration {iteration}] " - f"token_batch_size={token_batch_size} succeeded" - ) - - # Try larger batch size - next_size = int(token_batch_size * 1.5) - # Cap at model's max sequence length if specified - if max_batch_size is not None: - next_size = min(next_size, max_batch_size) - # Cap at a reasonable max (1M tokens) to avoid infinite growth - if next_size > token_batch_size and token_batch_size < 1_000_000: - token_batch_size = next_size - clear_device_cache(garbage_collection=True) - continue - else: - break - except Exception as e: - if should_reduce_batch_size(e): - print( - f"✗ [Iteration {iteration}] " - f"token_batch_size={token_batch_size} failed (OOM)" - ) - clear_device_cache(garbage_collection=True) - token_batch_size = int(token_batch_size * 0.7) - - if successful_batch_size is not None: - break - else: - raise +def _try_validate( + model: PreTrainedModel, size: int, collector: "HookCollectorBase" +) -> bool: + """ + Returns True if the batch size fits, False otherwise. + Wraps the user-provided logic to handle cleanup and error catching. + """ + _clear_cache() + + # Check model max length constraint immediately + max_seq_len = getattr(model.config, "max_position_embeddings", None) + if max_seq_len is not None and size > max_seq_len: + return False - if successful_batch_size is None: - raise RuntimeError("Could not find a working token_batch_size") + try: + # Create random tokens + random_tokens = torch.randint( + 0, 10, (1, size), device=model.device, dtype=torch.long + ) - final_batch_size = successful_batch_size - if round_to_pow2: - final_batch_size = round_to_power_of_2(successful_batch_size) - print(f"Rounded {successful_batch_size} → {final_batch_size} (power of 2)") + # Run the collector pass + with collector: + loss = model(random_tokens).logits[0, 0, 0].float() + loss.backward() + model.zero_grad() - print(f"\n{'='*60}") - print(f"Optimal token_batch_size found: {final_batch_size}") - print(f"{'='*60}\n") + return True - return final_batch_size + except (RuntimeError, ValueError, torch.cuda.OutOfMemoryError): + return False + finally: + # Ensure gradients/graphs are wiped before next attempt + model.zero_grad(set_to_none=True) + _clear_cache() -def find_optimal_token_batch_size( +def determine_batch_size( + root: Path, + cfg: IndexConfig, model: PreTrainedModel, - tokenizer: PreTrainedTokenizer, - dataset: Dataset, - starting_batch_size: int = 4096, + collector: "HookCollectorBase", + starting_batch_size: int = 8192, ) -> int: """ - Determine optimal token_batch_size for loaded models and data. - - This function assumes the model, tokenizer, and dataset are already loaded - and ready to use. It will test different batch sizes to find the optimal - token_batch_size that fits in available memory. + Finds the largest viable token batch size that fits in memory. - Args: - model: Already loaded and initialized model - tokenizer: Already loaded tokenizer - dataset: Small test dataset (already loaded) - starting_batch_size: Starting batch size to test - - Returns: - Optimal token_batch_size (power of 2) + 1. Checks cache. + 2. Performs binary search for max size. + 3. Saves to cache. + 4. Returns optimal size. """ - print("\n" + "=" * 60) - print("Finding optimal token_batch_size for loaded model...") - print("=" * 60 + "\n") - - # Cap starting batch size to model's max sequence length - max_seq_len = getattr(model.config, "max_position_embeddings", None) - if max_seq_len is not None and starting_batch_size > max_seq_len: - print( - f"Capping starting_batch_size from {starting_batch_size} " - f"to model's max sequence length {max_seq_len}" - ) - starting_batch_size = max_seq_len - - processor = GradientProcessor( - normalizers={}, - projection_dim=None, - reshape_to_square=False, - projection_type="rademacher", - ) - - def test_batch_size(token_batch_size: int) -> None: - """Test function that tries a single forward/backward pass.""" - test_dataset = dataset.select(range(min(5, len(dataset)))) - - test_dataset = test_dataset.map( - tokenize, - batched=True, - fn_kwargs=dict(args=DataConfig(truncation=True), tokenizer=tokenizer), - ) - test_dataset.set_format( - type="torch", columns=["input_ids", "attention_mask", "labels", "length"] - ) + cache_path = root / "batch_size_cache.jsonl" + metadata = _get_system_metadata(cfg) - index_cfg = IndexConfig( - run_path="temp", - model="test", - token_batch_size=token_batch_size, - loss_fn="ce", - loss_reduction="mean", - ) + # Check Cache + cached_size = _check_cache(cache_path, metadata) + if cached_size is not None: + print(f"Loaded optimal token_batch_size from cache: {cached_size}") + return cached_size - test_collector = InMemoryCollector( - model=model.base_model, # type: ignore - processor=processor, - data=test_dataset, - cfg=index_cfg, - ) + print("Determining optimal batch size via binary search...") - batches = allocate_batches(test_dataset["length"], token_batch_size) # type: ignore + # Setup Bounds + max_pos = getattr(model.config, "max_position_embeddings", None) + if max_pos and starting_batch_size > max_pos: + starting_batch_size = max_pos - computer = CollectorComputer( - model=model, - data=test_dataset, - collector=test_collector, - batches=batches, - cfg=index_cfg, - ) - computer.run_with_collector_hooks(desc="batch size test") + current_size = starting_batch_size + last_working_size = None - # Get max sequence length from model config - max_seq_len = getattr(model.config, "max_position_embeddings", None) + while True: + # Safety break + if current_size < 16: + if last_working_size is not None: + current_size = last_working_size + break + raise RuntimeError("Could not fit even token_batch_size=16 in memory.") - return find_optimal_token_batch_size_raw( - test_fn=test_batch_size, - starting_batch_size=starting_batch_size, - round_to_pow2=True, - max_batch_size=max_seq_len, - ) + print(f"Testing batch size: {current_size}...", end=" ", flush=True) + is_success = _try_validate(model, current_size, collector) + if is_success: + print("✓ Fits") + last_working_size = current_size -def get_optimal_batch_size( - cache_path: Path, - model_hf_id: str, - fsdp: bool, - determine_fn: Callable[[], int], -) -> int: - """ - Get optimal batch size from cache or determine it. + # Growth Phase: Try next power of 2 + next_size = current_size * 2 - Args: - cache_path: Path to cache file - model_hf_id: HuggingFace model ID - fsdp: Whether FSDP is enabled - starting_batch_size: Starting batch size for determination - determine_fn: Function to determine batch size if not cached + # Stop if we exceed model max length + if max_pos and next_size > max_pos: + break - Returns: - Optimal token_batch_size - """ - # Try to load from cache - cached_batch_size = load_batch_size_cache(cache_path, model_hf_id, fsdp) + current_size = next_size + else: + print("✗ OOM / Too Large") - if cached_batch_size is not None: - return cached_batch_size + if last_working_size is not None: + # We were growing and just hit the ceiling, + # so the last one was the winner. + current_size = last_working_size + break + else: + # We haven't found a working size yet (Shrink Phase) + current_size = current_size // 2 - # Determine optimal batch size - optimal_batch_size = determine_fn() + print(f"Largest viable batch size found: {current_size}") - # Save to cache - save_batch_size_cache(cache_path, model_hf_id, optimal_batch_size, fsdp) + # Save to Cache + _append_to_cache(cache_path, metadata, current_size) - return optimal_batch_size + return current_size diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 00000000..063f0ce2 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,1280 @@ +# Claude-Generated Bergson Architecture Overview + +This document provides a comprehensive overview of the Bergson architecture, including code structure, key components, design patterns, and data flow. + +## Table of Contents + +1. [High-Level Overview](#high-level-overview) +2. [Code Structure](#code-structure) +3. [Core Abstractions](#core-abstractions) +4. [Data Flow](#data-flow) +5. [Gradient Collection](#gradient-collection) +6. [Indexing and Querying](#indexing-and-querying) +7. [FAISS Integration](#faiss-integration) +8. [Distributed Training](#distributed-training) +9. [Design Patterns](#design-patterns) +10. [File Formats](#file-formats) + +--- + +## High-Level Overview + +Bergson is a library for **gradient-based data attribution** using the TrackStar algorithm. It enables tracing which training examples most influenced a model's predictions through efficient per-sample gradient computation and similarity search. + +### Key Capabilities + +- **Efficient Gradient Collection**: Per-sample gradients computed via PyTorch hooks +- **Memory-Efficient Compression**: Random projection and factored normalization +- **Scalable Indexing**: Memory-mapped storage and FAISS approximate search +- **Distributed Training**: Multi-GPU/multi-node support with FSDP +- **Flexible Querying**: Interactive CLI and programmatic API + +### Core Workflow + +``` +Training Data → Gradient Collection → Index Building → Query → Attribution + ↓ ↓ ↑ + Compression/Normalization Storage FAISS/Search +``` + +--- + +## Code Structure + +The codebase is organized into focused modules with clear responsibilities: + +``` +bergson/ +├── __init__.py # Public API exports +├── __main__.py # CLI entry point with command routing +├── config.py # Dataclass configurations for all components +│ +├── build.py # Index building orchestration +├── collection.py # High-level gradient collection API +├── reduce.py # Gradient reduction (mean/sum) +│ +├── data.py # Data loading, batching, and storage utilities +├── gradients.py # GradientProcessor and normalizer abstractions +├── distributed.py # Multi-GPU/multi-node orchestration +├── process_preconditioners.py # Preconditioner computation and eigendecomposition +│ +├── collector/ # Hook-based gradient collection +│ ├── collector.py # Base classes for hook collectors +│ ├── gradient_collectors.py # GradientCollector, TraceCollector +│ └── in_memory_collector.py # In-memory gradient collection +│ +├── query/ # Index querying and attribution +│ ├── attributor.py # Main attribution interface +│ ├── faiss_index.py # FAISS integration for ANN search +│ └── query_index.py # Interactive query CLI +│ +├── score/ # On-the-fly scoring +│ ├── score.py # Score dataset against query index +│ ├── scorer.py # Scorer class for computing similarities +│ └── score_writer.py # Memory-mapped score storage +│ +├── normalizer/ # Gradient normalization +│ └── fit_normalizers.py # Estimate Adam/Adafactor normalizers +│ +├── utils/ # Utility functions +│ ├── worker_utils.py # Model/data setup for distributed workers +│ ├── logger.py # Logging utilities +│ ├── peft.py # PEFT adapter detection +│ ├── math.py # Mathematical utilities +│ └── auto_batch_size.py # Automatic batch size tuning +``` + +### Module Responsibilities + +| Module | Responsibility | +|--------|---------------| +| **collector/** | How to collect gradients (hooks, computation) | +| **query/** | How to search gradients (attribution, FAISS) | +| **score/** | How to compute similarities on-the-fly | +| **data.py** | Where to store gradients (memory-mapped files) | +| **gradients.py** | How to transform gradients (normalization, compression) | +| **distributed.py** | How to coordinate workers (multi-GPU/node) | + +--- + +## Core Abstractions + +### GradientProcessor + +**Location**: `gradients.py` + +The central configuration object for gradient processing. + +```python +@dataclass +class GradientProcessor: + normalizers: Dict[str, Normalizer] # Per-module normalizers + preconditioners: Dict[str, Tensor] # Preconditioner matrices + preconditioners_eigen: Dict[str, Tuple[Tensor, Tensor]] # Eigendecompositions + projection_dim: int # Target dimension for compression + projection_type: str # "normal" or "rademacher" + include_bias: bool # Whether to include bias gradients +``` + +**Responsibilities**: +- Stores normalization state (Adam/Adafactor) +- Stores preconditioners (Hessian approximations) +- Configures compression (projection dimension) +- Provides serialization (`save()` / `load()`) + +### Normalizer + +**Location**: `gradients.py` + +Abstract base class for gradient normalization strategies. + +```python +class Normalizer(ABC): + @abstractmethod + def normalize_(self, grad: Tensor) -> Tensor: + """Normalize gradient in-place""" +``` + +**Implementations**: + +1. **AdafactorNormalizer**: Factored second moments + - Memory: O(O + I) for layer with shape [O, I] + - Stores `row` (size O) and `col` (size I) statistics + - Approximates full second moment matrix + +2. **AdamNormalizer**: Full second moment + - Memory: O(O × I) for layer with shape [O, I] + - Stores complete second moment matrix + - More accurate but less scalable + +**Usage**: +```python +normalizer.normalize_(grad) # In-place normalization +``` + +### HookCollectorBase + +**Location**: `collector/collector.py` + +Abstract base class for all gradient collectors using PyTorch hooks. + +```python +class HookCollectorBase(ABC): + def __enter__(self): + """Register hooks on model""" + + def __exit__(self, exc_type, exc_val, exc_tb): + """Cleanup hooks and state""" + + @abstractmethod + def forward_hook(self, module, input, output): + """Cache activations during forward pass""" + + @abstractmethod + def backward_hook(self, module, grad_input, grad_output): + """Compute per-sample gradients during backward""" + + @abstractmethod + def process_batch(self): + """Process collected gradients after batch""" + + @abstractmethod + def teardown(self): + """Final processing and save""" +``` + +**Lifecycle**: +```python +with collector: # __enter__: registers hooks + model(input).loss.backward() # Hooks execute + collector.process_batch() # Process gradients +# __exit__: cleanup hooks +collector.teardown() # Final processing +``` + +### GradientCollector + +**Location**: `collector/gradient_collectors.py` + +Main collector for building gradient indexes. + +**Key Features**: +- Per-sample gradient computation via hooks +- Random projection compression +- Adafactor/Adam normalization +- Preconditioner accumulation +- Distributed gradient aggregation + +**Example**: +```python +collector = GradientCollector( + model=model, + processor=processor, + builder=builder, # Writes gradients to disk + attention_config=attention_cfg +) + +with collector: + for batch in batches: + loss = model(**batch).loss + loss.backward() + collector.process_batch() + +collector.teardown() # Process preconditioners, save +``` + +### Attributor + +**Location**: `query/attributor.py` + +High-level interface for querying gradient indexes. + +```python +class Attributor: + def __init__( + self, + index_path: Path, + model: nn.Module, + unit_norm: bool = True, + faiss_config: Optional[FaissConfig] = None + ): + """Load index and prepare for querying""" + + def trace(self, model: nn.Module, k: int = 5): + """Context manager for computing attribution""" +``` + +**Usage**: +```python +attributor = Attributor(index_path="runs/my_index", model=model) + +with attributor.trace(model, k=5) as trace_result: + loss = model(query_input).loss + loss.backward() + +# trace_result contains top-k training example indices and scores +``` + +### Scorer + +**Location**: `score/scorer.py` + +Computes similarity scores on-the-fly without saving gradients. + +**Scoring Modes**: +- `mean`: Score against mean query gradient +- `nearest`: Score against most similar query gradient +- `individual`: Separate score for each query gradient + +**Example**: +```python +scorer = Scorer( + query_path="runs/query_vector", + score_method="mean", + writer=score_writer +) + +with scorer: + for batch in dataset: + loss = model(**batch).loss + loss.backward() + scorer.process_batch() + +scorer.teardown() +``` + +### Builder + +**Location**: `data.py` + +Handles writing gradients to memory-mapped files. + +```python +class Builder: + def __init__( + self, + modules: List[str], + grad_sizes: Dict[str, int], + num_grads: int, + output_path: Path + ): + """Create memory-mapped gradient storage""" + + def write(self, indices: List[int], gradients: Dict[str, Tensor]): + """Write gradients at specified indices""" +``` + +**Storage Format**: +- Structured numpy array with one field per module +- Memory-mapped for efficient out-of-core processing +- Supports concurrent writes in distributed setting + +--- + +## Data Flow + +### Build Command Flow + +``` +CLI Input + ↓ +bergson build --model --dataset + ↓ +__main__.py: Parse args into Build dataclass + ↓ +build.py: build(index_cfg) + ├─ Setup data pipeline + ├─ Validate paths + └─ Launch distributed run + ↓ + distributed.py: launch_distributed_run() + ↓ + build.py: build_worker() [on each GPU/node] + ├─ Initialize process group + ├─ Setup model + ├─ Create GradientProcessor + │ └─ Fit normalizers if needed + ├─ Create GradientCollector + └─ Run collection + ↓ + collection.py: collect_gradients() + ├─ Create CollectorComputer + └─ Run with collector hooks + ↓ + collector/collector.py: run_with_collector_hooks() + ├─ For each batch: + │ ├─ Enter collector (register hooks) + │ ├─ Forward pass → forward_hook caches activations + │ ├─ Backward pass → backward_hook computes gradients + │ ├─ process_batch() writes to disk + │ └─ Exit collector (cleanup hooks) + └─ teardown() processes preconditioners + ↓ + data.py: Builder writes gradients to memory-mapped file + ↓ + process_preconditioners.py: Aggregate and eigendecompose + ↓ + Save: gradients.bin, processor, dataset, metadata +``` + +### Query Command Flow + +``` +CLI Input + ↓ +bergson query --index + ↓ +query/query_index.py: query() + ├─ Load index config and dataset + ├─ Create Attributor (with optional FAISS) + └─ Interactive loop: + ├─ Get query text from user + ├─ Tokenize query + └─ attributor.trace(model, k=5) + ├─ Enter context: Create TraceCollector + ├─ Forward pass with query + ├─ Backward pass → collect query gradients + ├─ Search index for top-k matches + │ ├─ If FAISS: faiss_index.search() + │ └─ Else: In-memory search via matmul + └─ Return TraceResult(indices, scores) + ↓ + Display top-k training examples from dataset +``` + +### Score Command Flow + +``` +CLI Input + ↓ +bergson score --query_path --score mean + ↓ +score/score.py: score() + ├─ Load query index + ├─ Setup data pipeline + ├─ Create Scorer with ScoreWriter + └─ Launch distributed run + ↓ + build_worker() + ├─ Initialize process group + ├─ Setup model + ├─ Create Scorer + └─ For each batch: + ├─ Forward pass + ├─ Backward pass → collect gradients + ├─ Compute similarity to query + ├─ Write scores to memory-mapped file + └─ Continue + ↓ + Save: scores.bin with similarity values +``` + +--- + +## Gradient Collection + +Bergson implements the **TrackStar algorithm** for scalable gradient-based attribution. + +### Per-Sample Gradient Computation + +The core innovation is computing per-sample gradients efficiently without storing full batch gradients. + +#### Forward Hook: Cache Preprocessed Activations + +```python +def forward_hook(self, module, input, output): + """Cache activations with optional preprocessing""" + a = input[0] # [N, S, I] - batch, sequence, input_dim + + # Apply Adafactor column normalization + if self.adafactor_normalizer: + col_norm = self.adafactor_normalizer.col.rsqrt() + a = a * col_norm + + # Apply random projection + if self.projection_dim: + proj = self._get_projection(module) # Cached [I, p] + a = a @ proj # [N, S, p] + + # Cache preprocessed activations + module._cached_inputs = a +``` + +#### Backward Hook: Compute Per-Sample Gradients + +```python +def backward_hook(self, module, grad_input, grad_output): + """Compute per-sample gradients via outer product""" + a = module._cached_inputs # [N, S, I] or [N, S, p] + g = grad_output[0] # [N, S, O] + + # Apply Adafactor row normalization + if self.adafactor_normalizer: + row_norm = self.adafactor_normalizer.row + g = g * (row_norm.mean().sqrt() * row_norm.rsqrt()) + + # Apply gradient projection + if self.projection_dim: + g_proj = self._get_grad_projection(module) # [O, p] + g = g @ g_proj.T # [N, S, p] + + # Compute per-sample gradient as outer product + # P[i] = g[i].T @ a[i] for each sample i + P = g.mT @ a # [N, O/p, I/p] + P = P.flatten(1) # [N, (O/p)*(I/p)] + + # Accumulate preconditioner (gradient covariance) + if self.compute_preconditioners: + self.preconditioner += P.mT @ P + + # Write gradients to disk + self.builder.write(batch_indices, {module_name: P}) +``` + +### Memory Efficiency Techniques + +#### 1. Random Projections + +Compress gradients from `[O, I]` to `[p, p]` where `p << min(O, I)`. + +**Properties**: +- Preserves inner products approximately (Johnson-Lindenstrauss) +- Rademacher matrices: `{-1, +1}` entries (fast, no random generation) +- Gaussian matrices: N(0, 1) entries (better theoretical guarantees) + +**Memory Savings**: +- Original: O × I parameters +- Projected: p × p parameters +- Typical: p=16, O=4096, I=4096 → 99.99% compression + +**Example**: +```python +# Without projection: 4096 × 4096 = 16M parameters +# With projection: 16 × 16 = 256 parameters +compression_ratio = (O * I) / (p * p) # 65,536x +``` + +#### 2. Adafactor Normalization + +Factored representation of second moment matrix. + +**Standard Adam**: +- Second moment: [O, I] matrix +- Memory: O(O × I) + +**Adafactor**: +- Row factors: [O] vector +- Column factors: [I] vector +- Memory: O(O + I) + +**Normalization**: +```python +# Full second moment (conceptual): +M = row[:, None] * col[None, :] # [O, I] + +# Applied factorized: +a_normalized = a * col.rsqrt() # Apply to activations +g_normalized = g * row.rsqrt() # Apply to gradients +``` + +**Memory Savings**: +- For O=4096, I=4096: + - Adam: 16M parameters + - Adafactor: 8K parameters (2000× reduction) + +#### 3. Lazy Materialization + +Gradients never fully materialized in memory: + +1. **Forward pass**: Cache preprocessed activations +2. **Backward pass**: Compute gradient on-the-fly +3. **Immediate write**: Write to disk via memory-mapped file +4. **Discard**: Clear cache for next batch + +**Benefits**: +- Constant memory usage per batch +- Supports datasets larger than RAM +- Enables distributed gradient aggregation + +### TrackStar-Specific Components + +#### Preconditioners (Hessian Approximation) + +Accumulated during gradient collection: + +```python +# Gradient covariance matrix +preconditioner = sum(g_i @ g_i.T for g_i in gradients) +preconditioner /= num_examples + +# Eigendecomposition for efficient inversion +eigval, eigvec = torch.linalg.eigh(preconditioner) + +# Inverse square root (for influence computation) +inv_sqrt = eigvec @ diag(eigval ** -0.5) @ eigvec.T +``` + +**Usage in Attribution**: +```python +# Apply preconditioning to query gradient +q_preconditioned = inv_sqrt @ q + +# Compute influence scores +influences = q_preconditioned @ gradients.T +``` + +#### Distributed Preconditioner Aggregation + +```python +# Each worker computes local preconditioner +local_prec = local_gradients.T @ local_gradients / local_count + +# Reduce to rank 0 (on CPU to save GPU memory) +dist.reduce(local_prec, dst=0, op=dist.ReduceOp.SUM) + +# Rank 0 computes eigendecomposition +if rank == 0: + global_prec = local_prec / world_size + eigval, eigvec = torch.linalg.eigh(global_prec) +``` + +--- + +## Indexing and Querying + +### Index Structure + +Gradients stored in **structured memory-mapped numpy arrays**: + +```python +# Create structured dtype with one field per module +dtype = { + 'names': ['gpt_neox.layers.0.attention.dense', + 'gpt_neox.layers.0.mlp.dense_h_to_4h', ...], + 'formats': ['(256,)float16', '(512,)float16', ...] +} + +# Create memory-mapped array +gradients = np.memmap( + 'gradients.bin', + dtype=dtype, + mode='w+', + shape=(num_examples,) +) + +# Access gradients by module +layer_grads = gradients['gpt_neox.layers.0.attention.dense'] # [num_examples, 256] +``` + +**Benefits**: +- Efficient out-of-core processing +- Named field access +- Supports partial loading (select modules) +- Works with datasets larger than RAM + +### Metadata Format + +**info.json**: +```json +{ + "num_grads": 100000, + "dtype": { + "names": ["layer1", "layer2"], + "formats": ["(256,)float16", "(512,)float16"] + }, + "grad_sizes": { + "layer1": 256, + "layer2": 512 + }, + "base_dtype": "float16" +} +``` + +### Query Methods + +#### In-Memory Search + +Fast exact search when index fits in GPU memory: + +```python +# Load gradients into GPU +grads = {} +for name in module_names: + grads[name] = torch.tensor(mmap[name], device='cuda') + +# Compute scores via batch matrix multiplication +scores = sum( + query_grad[name] @ grads[name].T + for name in module_names +) # [num_examples] + +# Get top-k +topk_values, topk_indices = torch.topk(scores, k) +``` + +**Complexity**: +- Time: O(num_examples × grad_dim) +- Space: O(num_examples × grad_dim) GPU memory + +#### FAISS Search + +Approximate nearest neighbor for large-scale indices: + +```python +# Build index +index = faiss.index_factory( + grad_dim, + "IVF1024,SQfp16", + faiss.METRIC_INNER_PRODUCT +) +index.train(gradients[:train_size]) +index.add(gradients) + +# Search +distances, indices = index.search(query, k) +``` + +**Complexity**: +- Time: O(log(num_examples) × grad_dim) with IVF +- Space: Compressed on disk, partial loading + +--- + +## FAISS Integration + +### Index Creation Workflow + +```python +def create_faiss_index( + gradient_path: Path, + factory_string: str = "IVF1024,SQfp16", + num_shards: int = 1, + mmap_index: bool = False +): + """Create FAISS index from gradients""" + + # 1. Load gradients from memory-mapped files + gradients = load_gradients(gradient_path) + + # 2. Normalize if needed + if unit_norm: + gradients = normalize_on_gpu(gradients) + + # 3. Create sharded indexes + shard_size = len(gradients) // num_shards + + for shard_id in range(num_shards): + start = shard_id * shard_size + end = start + shard_size + shard_grads = gradients[start:end] + + # 4. Build FAISS index + index = faiss.index_factory( + grad_dim, + factory_string, + faiss.METRIC_INNER_PRODUCT + ) + + # 5. Train (for IVF indexes) + if "IVF" in factory_string: + train_size = min(len(shard_grads), 1_000_000) + index.train(shard_grads[:train_size]) + + # 6. Add vectors + index.add(shard_grads) + + # 7. Save to disk + faiss.write_index(index, f"{shard_id}.faiss") +``` + +### Multi-Shard Search + +Enables querying indices larger than memory: + +```python +def search_sharded(query: Tensor, k: int): + """Search across multiple FAISS shards""" + all_distances = [] + all_indices = [] + + # Search each shard independently + for shard_id, shard in enumerate(shards): + # Load shard (optionally mmap) + if mmap_index: + index = faiss.read_index(f"{shard_id}.faiss", faiss.IO_FLAG_MMAP) + else: + index = faiss.read_index(f"{shard_id}.faiss") + + # Search this shard + distances, indices = index.search(query, k) + + # Offset indices by shard position + offset = shard_id * shard_size + indices += offset + + all_distances.append(distances) + all_indices.append(indices) + + # Concatenate results from all shards + combined_distances = np.concatenate(all_distances, axis=1) + combined_indices = np.concatenate(all_indices, axis=1) + + # Rerank to get global top-k + topk_positions = np.argsort(-combined_distances, axis=1)[:, :k] + topk_distances = np.take_along_axis(combined_distances, topk_positions, axis=1) + topk_indices = np.take_along_axis(combined_indices, topk_positions, axis=1) + + return topk_distances, topk_indices +``` + +### FAISS Factory Strings + +Common configurations: + +| Factory String | Description | Speed | Memory | Accuracy | +|---------------|-------------|-------|--------|----------| +| `"Flat"` | Exact search (brute force) | Slow | High | 100% | +| `"IVF1,SQfp16"` | Exact with fp16 quantization | Medium | Medium | 100% | +| `"IVF1024,SQfp16"` | ANN with 1024 clusters, fp16 | Fast | Low | ~95% | +| `"IVF4096,PQ32"` | ANN with product quantization | Very Fast | Very Low | ~90% | +| `"HNSW32"` | Hierarchical graph search | Fast | Medium | ~98% | + +**Parameters**: +- `nprobe`: Number of clusters to search (IVF) + - Higher → more accurate, slower + - Default: 1024 +- `mmap_index`: Query on disk vs load into memory + - `True` → lower memory, slower + - `False` → higher memory, faster + +--- + +## Distributed Training + +### Multi-GPU/Multi-Node Setup + +Bergson supports distributed gradient collection across multiple GPUs and nodes. + +#### Configuration + +```python +@dataclass +class DistributedConfig: + nnode: int = 1 # Number of nodes + nproc_per_node: int = 1 # GPUs per node + node_rank: int = 0 # Current node rank + master_addr: str = "localhost" # Master node address + master_port: str = "29500" # Master node port +``` + +#### Launch Distributed Run + +```python +from bergson.distributed import launch_distributed_run + +launch_distributed_run( + name="build", + worker_fn=build_worker, + const_worker_args=[index_cfg, dataset], + dist_config=cfg.distributed +) +``` + +**What it does**: +1. Spawns `nproc_per_node` processes +2. Each process gets a unique `local_rank` (0 to nproc_per_node-1) +3. Sets environment variables for distributed training +4. Calls `worker_fn` on each process + +#### Worker Initialization + +```python +def build_worker(local_rank, index_cfg, dataset): + # Set CUDA device + torch.cuda.set_device(local_rank) + + # Initialize process group + dist.init_process_group( + backend="nccl", + init_method=f"tcp://{master_addr}:{master_port}", + world_size=world_size, + rank=global_rank + ) + + # Compute work assignment + batches_per_worker = len(batches) // world_size + start_idx = global_rank * batches_per_worker + end_idx = start_idx + batches_per_worker + my_batches = batches[start_idx:end_idx] + + # Run gradient collection + collect_gradients(model, my_batches, ...) +``` + +### Data Distribution + +#### Batch Allocation + +Sophisticated bin-packing algorithm ensures equal work per worker: + +```python +def allocate_batches( + dataset, + token_batch_size: int, + world_size: int +): + """Distribute batches across workers""" + + # Constraint: max_len * batch_size ≤ token_batch_size + # Goal: Equal batches per worker + + # 1. Sort documents by length + sorted_docs = sorted(dataset, key=lambda x: len(x['input_ids'])) + + # 2. Greedy bin packing + batches = [] + current_batch = [] + current_max_len = 0 + + for doc in sorted_docs: + doc_len = len(doc['input_ids']) + + # Check if adding doc exceeds token budget + new_max_len = max(current_max_len, doc_len) + new_size = len(current_batch) + 1 + + if new_max_len * new_size > token_batch_size: + # Start new batch + batches.append(current_batch) + current_batch = [doc] + current_max_len = doc_len + else: + current_batch.append(doc) + current_max_len = new_max_len + + # 3. Round to multiple of world_size + total_batches = len(batches) + batches_per_worker = total_batches // world_size + + # 4. Return worker assignments + return [ + batches[i * batches_per_worker : (i+1) * batches_per_worker] + for i in range(world_size) + ] +``` + +### Gradient Aggregation + +#### Preconditioner Reduction + +```python +def process_preconditioners( + local_preconditioners: Dict[str, Tensor], + num_local_examples: int +): + """Aggregate preconditioners across workers""" + + # Normalize by local dataset size + for name, prec in local_preconditioners.items(): + prec /= num_local_examples + + # Reduce to rank 0 (on CPU to save GPU memory) + for name, prec in local_preconditioners.items(): + prec_cpu = prec.cpu() + dist.reduce(prec_cpu, dst=0, op=dist.ReduceOp.SUM) + + if rank == 0: + # Average across workers + prec_cpu /= world_size + + # Eigendecomposition + eigval, eigvec = torch.linalg.eigh(prec_cpu) + + # Save + save_preconditioner(name, prec_cpu, eigval, eigvec) +``` + +#### Loss Aggregation + +```python +# Each worker tracks local per-document losses +local_losses = torch.zeros(len(dataset), device='cpu') +local_losses[local_indices] = computed_losses + +# Reduce to rank 0 +dist.reduce(local_losses, dst=0, op=dist.ReduceOp.SUM) + +# Rank 0 saves complete loss vector +if rank == 0: + save_losses(local_losses) +``` + +### FSDP Support + +Optional Fully Sharded Data Parallel for models that don't fit on single GPU: + +```python +from torch.distributed.fsdp import fully_shard + +if cfg.fsdp: + # Shard each transformer layer + for layer in model.layers: + fully_shard(layer) + + # Shard root module + fully_shard(model) +``` + +**Benefits**: +- Shard model parameters across GPUs +- Shard gradients and optimizer states +- Enables training models larger than single GPU memory +- Automatic all-gather/reduce-scatter communication + +--- + +## Design Patterns + +### 1. Hook-Based Architecture + +**Pattern**: Template Method + Strategy + +```python +class HookCollectorBase(ABC): + """Template for hook lifecycle""" + + def __enter__(self): + # Template: register hooks + for module in self.modules: + module.register_forward_hook(self.forward_hook) + module.register_full_backward_hook(self.backward_hook) + + @abstractmethod + def forward_hook(self, module, input, output): + """Strategy: how to cache activations""" + + @abstractmethod + def backward_hook(self, module, grad_input, grad_output): + """Strategy: how to compute gradients""" +``` + +**Benefits**: +- Non-intrusive: works with any PyTorch model +- Flexible: different strategies via subclasses +- Efficient: intercepts gradients at computation time + +### 2. Context Manager Protocol + +All collectors use context managers for resource management: + +```python +with GradientCollector(...) as collector: + loss.backward() + collector.process_batch() +# Automatic cleanup: hooks removed, memory freed +``` + +**Benefits**: +- Automatic resource cleanup +- Exception-safe +- Clear API boundaries +- Prevents resource leaks + +### 3. Lazy Evaluation + Streaming + +Gradients never fully materialized in memory: + +``` +Forward → Cache activations → Backward → Compute gradients → Write to disk → Discard + ↓ ↑ + └──────────────────────────── Constant memory ────────────────────────────────┘ +``` + +**Benefits**: +- O(1) memory per batch +- Supports datasets larger than RAM +- Enables distributed processing + +### 4. Composition Over Inheritance + +**GradientProcessor** composes strategies: +```python +processor = GradientProcessor( + normalizers={"layer1": AdafactorNormalizer(...)}, # Strategy + preconditioners={"layer1": torch.tensor(...)}, # Data + projection_dim=16 # Config +) +``` + +**CollectorComputer** composes components: +```python +computer = CollectorComputer( + model=model, # Component + dataset=dataset, # Component + collector=collector, # Strategy + batching=batching_fn # Strategy +) +``` + +### 5. Dataclass-Based Configuration + +All configs use `@dataclass` for type safety and serialization: + +```python +@dataclass +class IndexConfig: + model: str = "EleutherAI/pythia-160m" + dataset: str = "NeelNanda/pile-10k" + projection_dim: int = 16 + normalizer: str = "adafactor" + + def save(self, path: Path): + with open(path, 'w') as f: + json.dump(dataclasses.asdict(self), f) + + @classmethod + def load(cls, path: Path): + with open(path) as f: + return cls(**json.load(f)) +``` + +**Benefits**: +- Type checking +- Default values +- Easy serialization +- CLI parsing via simple_parsing + +### 6. Memory-Mapped Storage + +Uses `numpy.memmap` for out-of-core processing: + +```python +# Create memory-mapped array +gradients = np.memmap( + 'gradients.bin', + dtype=dtype, + mode='w+', + shape=(num_examples,) +) + +# Write gradients (appends to file) +gradients[indices] = new_gradients + +# Read gradients (loads from disk on access) +batch = gradients[start:end] +``` + +**Benefits**: +- Supports datasets larger than RAM +- Efficient random access +- Concurrent read/write +- OS-level caching + +### 7. Separation of Concerns + +Clear module boundaries: + +| Concern | Module | +|---------|--------| +| Gradient computation | `collector/` | +| Gradient storage | `data.py` | +| Gradient transformation | `gradients.py` | +| Similarity search | `query/` | +| Distributed coordination | `distributed.py` | +| Configuration | `config.py` | + +**Benefits**: +- Testable components +- Reusable abstractions +- Clear dependencies +- Easy to extend + +--- + +## File Formats + +### Gradient Index Directory + +``` +runs/my_index/ +├── index_config.json # IndexConfig serialized +├── data.hf/ # HuggingFace Dataset +│ ├── dataset_info.json # Dataset metadata +│ ├── state.json # Dataset state +│ └── data-00000-of-00001.arrow # Arrow format data +├── gradients.bin # Memory-mapped gradients +├── info.json # Gradient metadata +├── processor_config.json # GradientProcessor config +├── normalizers.pth # Normalizer state dicts (PyTorch) +├── preconditioners.pth # Preconditioner matrices (PyTorch) +└── preconditioners_eigen.pth # Eigendecompositions (PyTorch) +``` + +### Gradient Binary Format + +**Structured numpy array** with one field per module: + +```python +dtype = { + 'names': ['layer1', 'layer2', ...], + 'formats': ['(256,)float16', '(512,)float16', ...] +} + +# Shape: (num_examples,) +# Size: num_examples * sum(grad_dims) * dtype_bytes +``` + +**Example**: +```python +# 10,000 examples +# 2 layers: 256-dim and 512-dim +# float16 (2 bytes) +total_size = 10_000 * (256 + 512) * 2 = 15.36 MB +``` + +### info.json Format + +```json +{ + "num_grads": 10000, + "dtype": { + "names": ["layer1", "layer2"], + "formats": ["(256,)float16", "(512,)float16"] + }, + "grad_sizes": { + "layer1": 256, + "layer2": 512 + }, + "base_dtype": "float16" +} +``` + +### FAISS Index Directory + +``` +runs/my_index/faiss_IVF1024_SQfp16_cosine/ +├── config.json # FaissConfig + metadata +├── 0.faiss # Shard 0 +├── 1.faiss # Shard 1 +├── 2.faiss # Shard 2 +└── ... +``` + +**config.json**: +```json +{ + "factory_string": "IVF1024,SQfp16", + "metric": "cosine", + "num_shards": 4, + "shard_size": 250000, + "total_vectors": 1000000, + "dim": 768 +} +``` + +### Score Storage + +``` +runs/scores/ +├── info.json # Metadata +└── scores.bin # Memory-mapped scores +``` + +**Structured array format**: +```python +# For mean/nearest scoring +dtype = [ + ('score_0', 'float32'), # Score value + ('written_0', 'bool') # Whether score has been written +] + +# For individual scoring (multiple queries) +dtype = [ + ('score_0', 'float32'), + ('score_1', 'float32'), + ... + ('written_0', 'bool') +] +``` + +### Training Gradients + +``` +runs/training/ +├── train/ +│ ├── gradients.bin # Accumulated gradients +│ ├── info.json +│ └── processor_config.json +├── train/epoch_0/ # If not accumulating across epochs +│ └── gradients.bin +├── train/epoch_1/ +│ └── gradients.bin +└── order.hf/ # If track_order=True + └── data.arrow # Training order tracking +``` + +--- + +## Summary + +Bergson's architecture demonstrates several key principles: + +1. **Modularity**: Clear separation of concerns with well-defined interfaces +2. **Scalability**: Distributed training, memory-mapped storage, FAISS integration +3. **Efficiency**: Lazy evaluation, streaming, random projections, factored normalizers +4. **Flexibility**: Multiple normalizers, collectors, scoring methods +5. **Usability**: Simple CLI, context managers, sensible defaults +6. **Extensibility**: Hook-based design, composition patterns, strategy pattern + +The codebase is well-structured for both research experimentation and production deployment of gradient-based data attribution at scale. + +### Key Innovations + +- **Hook-based gradient collection**: Non-intrusive per-sample gradients +- **Factored normalization**: 1000× memory reduction for second moments +- **Random projections**: 10,000× compression with preserved similarity +- **Memory-mapped storage**: Process datasets larger than RAM +- **Distributed preconditioners**: Scalable Hessian approximation +- **Sharded FAISS**: Query billion-scale indices + +These architectural choices enable Bergson to scale from small models (14M parameters) to large models (billions of parameters) and from small datasets (10K examples) to massive datasets (millions of examples). diff --git a/docs/CLI_USAGE.md b/docs/CLI_USAGE.md new file mode 100644 index 00000000..a6fed4d9 --- /dev/null +++ b/docs/CLI_USAGE.md @@ -0,0 +1,518 @@ +# Claude-Generated Bergson CLI Usage Guide + +Bergson is a library for tracing the memory of deep neural networks using gradient-based data attribution. This guide covers practical usage of the Bergson CLI with hands-on examples. + +## Quick Start + +The fastest way to get started is to build a gradient index and query it: + +```bash +# Build an index from a small dataset +bergson build runs/quickstart \ + --model EleutherAI/pythia-14m \ + --dataset NeelNanda/pile-10k \ + --truncation \ + --token_batch_size 4096 + +# Query the index interactively +bergson query --index runs/quickstart +``` + +When prompted, enter any text and Bergson will show you the top 5 most influential training examples. + +## CLI Commands Overview + +Bergson provides 4 main commands: + +| Command | Purpose | +|---------|---------| +| `build` | Build a gradient index from training data | +| `query` | Interactively query a pre-built index | +| `reduce` | Aggregate dataset gradients to a query vector | +| `score` | Score a dataset against a query vector | + +--- + +## 1. Building a Gradient Index + +The `build` command collects per-example gradients from your training data and optionally compresses them. + +### Basic Usage + +```bash +bergson build \ + --model \ + --dataset +``` + +### Example: Small Model on Pile + +```bash +bergson build runs/pile_index \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --token_batch_size 2048 \ + --truncation +``` + +### Example: Larger Model with Compression + +```bash +# Compress gradients to dimension 32 using random projection +bergson build runs/compressed_index \ + --model EleutherAI/pythia-410m \ + --dataset NeelNanda/pile-10k \ + --token_batch_size 4096 \ + --projection_dim 32 \ + --normalizer adafactor +``` + +### Example: No Compression (Full Gradients) + +```bash +# Set projection_dim to 0 to disable compression +bergson build runs/full_gradients \ + --model EleutherAI/pythia-70m \ + --dataset NeelNanda/pile-10k \ + --projection_dim 0 \ + --token_batch_size 8192 +``` + +### Example: Custom Dataset Format + +```bash +# For datasets with specific column names +bergson build runs/custom_data \ + --model EleutherAI/pythia-160m \ + --dataset my_org/my_dataset \ + --prompt_column "input" \ + --completion_column "output" \ + --token_batch_size 4096 +``` + +### Example: Distributed Multi-GPU Build + +```bash +# Using FSDP (Fully Sharded Data Parallel) +bergson build runs/distributed_index \ + --model EleutherAI/pythia-1b \ + --dataset NeelNanda/pile-10k \ + --token_batch_size 16384 \ + --fsdp \ + --precision bf16 +``` + +### Key Parameters + +- `--token_batch_size`: Token budget per batch (controls memory usage) +- `--projection_dim`: Compression dimension (default: 16, set to 0 to disable) +- `--normalizer`: Gradient normalization method (`adafactor`, `adam`, `none`) +- `--truncation`: Truncate long documents to fit token budget +- `--precision`: Model dtype (`auto`, `bf16`, `fp16`, `fp32`, `int4`, `int8`) +- `--fsdp`: Enable Fully Sharded Data Parallel for multi-GPU +- `--skip_index`: Only compute preconditioners (don't build index) + +--- + +## 2. Querying an Index + +The `query` command launches an interactive session where you can enter text and find the most influential training examples. + +### Basic Usage + +```bash +bergson query --index +``` + +### Example: Basic Query + +```bash +bergson query --index runs/pile_index + +# Interactive prompt appears: +> Enter your query: The quick brown fox jumps over +# Returns top 5 most similar training examples +``` + +### Example: Query with Model Override + +```bash +# Use a different model than the one that built the index +bergson query \ + --index runs/pile_index \ + --model EleutherAI/pythia-70m +``` + +### Example: FAISS Approximate Search + +```bash +# Use FAISS for faster approximate nearest neighbor search +bergson query \ + --index runs/large_index \ + --faiss +``` + +### Example: Show Least Influential Examples + +```bash +# Reverse the ranking to show lowest influences +bergson query \ + --index runs/pile_index \ + --reverse +``` + +### Example: Custom Text Field + +```bash +# Display a specific column from the dataset +bergson query \ + --index runs/custom_index \ + --text_field "content" +``` + +### Key Parameters + +- `--index`: Path to the pre-built gradient index +- `--model`: Model to use (defaults to the model that built the index) +- `--faiss`: Use FAISS for approximate nearest neighbor search +- `--reverse`: Show lowest influences instead of highest +- `--unit_norm`: Unit normalize query gradient (default: True) +- `--text_field`: Dataset column to display (default: "text") + +--- + +## 3. Reducing a Dataset to a Query Vector + +The `reduce` command aggregates all examples in a dataset into a single gradient vector, useful for creating query vectors. + +### Basic Usage + +```bash +bergson reduce \ + --model \ + --dataset \ + --method mean +``` + +### Example: Mean Query Vector + +```bash +# Create a mean gradient from WikiText +bergson reduce runs/wikitext_query \ + --model EleutherAI/pythia-160m \ + --dataset wikitext \ + --method mean \ + --unit_normalize +``` + +### Example: Sum Aggregation + +```bash +bergson reduce runs/sum_query \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --method sum +``` + +### Key Parameters + +- `--method`: Reduction method (`mean` or `sum`) +- `--unit_normalize`: Unit normalize gradients before reduction +- All standard `IndexConfig` parameters (`model`, `dataset`, `token_batch_size`, etc.) + +--- + +## 4. Scoring a Dataset + +The `score` command computes attribution scores for a dataset against a pre-built query vector, without storing full gradients. + +### Basic Usage + +```bash +bergson score \ + --model \ + --dataset \ + --query_path \ + --score mean +``` + +### Example: Score Against Mean Query + +```bash +# Score training data against WikiText query vector +bergson score runs/pile_scores \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --query_path runs/wikitext_query \ + --score mean \ + --unit_normalize +``` + +### Example: Nearest Neighbor Scoring + +```bash +# For each example, find max similarity across all query examples +bergson score runs/nearest_scores \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --query_path runs/query_index \ + --score nearest +``` + +### Example: Individual Query Scores + +```bash +# Get per-query-example scores (returns a matrix) +bergson score runs/individual_scores \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --query_path runs/query_index \ + --score individual +``` + +### Example: With Preconditioner Mixing + +```bash +bergson score runs/mixed_scores \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --query_path runs/query_vector \ + --query_preconditioner_path runs/query_vector/preconditioner.safetensors \ + --score mean \ + --mixing_coefficient 0.5 +``` + +### Key Parameters + +- `--query_path`: Path to query index (from `reduce` or `build`) +- `--score`: Scoring method (`mean`, `nearest`, `individual`) +- `--unit_normalize`: Unit normalize before scoring +- `--batch_size`: Processing batch size (default: 1024) +- `--query_preconditioner_path`: Path to precomputed query preconditioner +- `--mixing_coefficient`: Weight between query/index preconditioners (0-1) + +--- + +## Common Workflows + +### Workflow 1: Build and Query + +The simplest workflow for data attribution: + +```bash +# Step 1: Build index +bergson build runs/my_index \ + --model EleutherAI/pythia-70m \ + --dataset NeelNanda/pile-10k \ + --token_batch_size 4096 + +# Step 2: Query interactively +bergson query --index runs/my_index +``` + +### Workflow 2: Build → Reduce → Score + +For on-the-fly scoring without storing full gradients: + +```bash +# Step 1: Build index from training data +bergson build runs/training_index \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --token_batch_size 4096 + +# Step 2: Reduce evaluation dataset to query vector +bergson reduce runs/eval_query \ + --model EleutherAI/pythia-160m \ + --dataset wikitext \ + --method mean \ + --unit_normalize + +# Step 3: Score training data against evaluation query +bergson score runs/attribution_scores \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --query_path runs/eval_query \ + --score mean +``` + +### Workflow 3: Distributed Build + +For large-scale distributed training: + +```bash +# Distributed training with appropriate batch size +bergson build runs/large_exp/index \ + --model EleutherAI/pythia-1b \ + --dataset NeelNanda/pile-10k \ + --token_batch_size 8192 \ + --fsdp \ + --precision bf16 +``` + +### Workflow 4: Data Filtering with Attribution + +Filter your dataset based on attribution scores: + +```bash +# Step 1: Build index from high-quality data +bergson build runs/quality_index \ + --model EleutherAI/pythia-160m \ + --dataset high_quality_dataset + +# Step 2: Score candidate data against quality index +bergson score runs/candidate_scores \ + --model EleutherAI/pythia-160m \ + --dataset candidate_dataset \ + --query_path runs/quality_index \ + --score mean + +# Step 3: Filter data using scores (custom script) +python scripts/filter_data.py \ + --scores runs/candidate_scores/scores.npy \ + --dataset candidate_dataset \ + --threshold 0.5 \ + --output runs/filtered_dataset +``` + +--- + +## Advanced Usage + +### Attention Head Gradients + +Split attention modules into per-head gradients for finer-grained attribution: + +```bash +bergson build runs/head_gradients \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --split_attention_modules \ + --token_batch_size 4096 +``` + +### Module Filtering + +Exclude specific layers using glob patterns: + +```bash +bergson build runs/filtered_modules \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --filter_modules "*.embed*" "*.ln_*" \ + --token_batch_size 4096 +``` + +### Custom Precision and Memory Optimization + +```bash +# Use INT8 quantization for memory-constrained environments +bergson build runs/int8_index \ + --model EleutherAI/pythia-410m \ + --dataset NeelNanda/pile-10k \ + --precision int8 \ + --token_batch_size 8192 + +# Use BF16 for better numerical stability on modern GPUs +bergson build runs/bf16_index \ + --model EleutherAI/pythia-410m \ + --dataset NeelNanda/pile-10k \ + --precision bf16 \ + --token_batch_size 16384 +``` + +### GRPO (Policy Gradient) Support + +For RL/preference data with reward columns: + +```bash +bergson build runs/grpo_index \ + --model EleutherAI/pythia-160m \ + --dataset rlhf_dataset \ + --reward_column "reward" \ + --token_batch_size 4096 +``` + +### Skip Index Building (Preconditioners Only) + +Build only the preconditioners without creating the full index: + +```bash +bergson build runs/preconditioners_only \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --skip_index +``` + +--- + +## Tips and Best Practices + +1. **Start Small**: Test with small models (pythia-14m, pythia-70m) and small datasets before scaling up. + +2. **Use Auto Batch Size**: Always run `autobatchsize` first to avoid OOM errors and maximize GPU utilization. + +3. **Enable Truncation**: Use `--truncation` to handle variable-length documents and prevent memory issues. + +4. **Choose Compression Wisely**: + - Use `--projection_dim 16-32` for large-scale builds + - Use `--projection_dim 0` for maximum accuracy on small datasets + +5. **Normalize for Stability**: Use `--normalizer adafactor` or `--normalizer adam` for better gradient stability. + +6. **Use BF16 on Modern GPUs**: `--precision bf16` provides better numerical stability than FP16 on A100/H100 GPUs. + +7. **Multi-GPU Strategy**: + - Use `--fsdp` for models that don't fit on a single GPU + - Run `autobatchsize` once before launching distributed jobs + +8. **FAISS for Large Indices**: Use `--faiss` in `query` command for faster searches on large indices. + +--- + +## Troubleshooting + +### Out of Memory Errors + +```bash +# Reduce token_batch_size +bergson build runs/my_index \ + --model EleutherAI/pythia-160m \ + --dataset NeelNanda/pile-10k \ + --token_batch_size 1024 # Lower value + +# Use a lower token_batch_size value +``` + +### Slow Query Performance + +```bash +# Use FAISS for approximate nearest neighbor search +bergson query --index runs/large_index --faiss +``` + +### Dataset Column Name Issues + +```bash +# Specify custom column names +bergson build runs/my_index \ + --model EleutherAI/pythia-160m \ + --dataset my_dataset \ + --prompt_column "input_text" \ + --completion_column "output_text" +``` + +### Distributed Training Issues + +```bash +# Ensure you set an appropriate token_batch_size for multi-GPU training +bergson build runs/index --fsdp --token_batch_size 8192 +``` + +--- + +## Further Reading + +- For implementation details, see the main codebase in `bergson/` +- For advanced scripting examples, check `scripts/` directory +- For API usage beyond CLI, explore the library imports in your Python code diff --git a/docs/benchmarks/cli_benchmark_1x_NVIDIA_H100_80GB_HBM3.png b/docs/benchmarks/cli_benchmark_1x_NVIDIA_H100_80GB_HBM3.png new file mode 100644 index 00000000..aa3210a4 Binary files /dev/null and b/docs/benchmarks/cli_benchmark_1x_NVIDIA_H100_80GB_HBM3.png differ diff --git a/docs/benchmarks/cli_benchmark_8x_NVIDIA_H100_80GB_HBM3.png b/docs/benchmarks/cli_benchmark_8x_NVIDIA_H100_80GB_HBM3.png new file mode 100644 index 00000000..459abd14 Binary files /dev/null and b/docs/benchmarks/cli_benchmark_8x_NVIDIA_H100_80GB_HBM3.png differ diff --git a/docs/benchmarks/index.rst b/docs/benchmarks/index.rst new file mode 100644 index 00000000..ea5ffadb --- /dev/null +++ b/docs/benchmarks/index.rst @@ -0,0 +1,44 @@ +Benchmarks +========== + +This section provides indicative performance numbers for the Bergson benchmark suite. Performance will vary based on your hardware configuration and choice of hyperparameters. Indicative performance for dattri provided where possible. + +8 GPU Configuration (CLI) +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. image:: cli_benchmark_8x_NVIDIA_H100_80GB_HBM3.png + :alt: 8 GPU Benchmark + :align: center + +1 GPU Configuration (CLI) +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. image:: cli_benchmark_1x_NVIDIA_H100_80GB_HBM3.png + :alt: 1 GPU Benchmark + :align: center + +1 GPU Configuration (Random Projection) +~~~~~~~~~~~~~~~~~~~ + +.. image:: projection_comparison_with_projection.png + :alt: 1 GPU Comparison with Projection (16x16) + :align: center + +1 GPU Configuration (No Random Projection) +~~~~~~~~~~~~~~~~~~~ + +.. image:: projection_comparison_without_projection.png + :alt: 1 GPU Comparison without Projection + :align: center + +1 GPU In-Memory Benchmark +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. image:: inmem_benchmark_1x_NVIDIA_H100_80GB_HBM3.png + :alt: 1 GPU In-Memory Benchmark + :align: center + +Running Your Own Benchmarks +---------------------------- + +To generate benchmarks for your specific setup, you can use the shell scripts in `benchmarks`. diff --git a/docs/benchmarks/inmem_benchmark_1x_NVIDIA_H100_80GB_HBM3.png b/docs/benchmarks/inmem_benchmark_1x_NVIDIA_H100_80GB_HBM3.png new file mode 100644 index 00000000..54947378 Binary files /dev/null and b/docs/benchmarks/inmem_benchmark_1x_NVIDIA_H100_80GB_HBM3.png differ diff --git a/docs/benchmarks/projection_comparison_with_projection.png b/docs/benchmarks/projection_comparison_with_projection.png new file mode 100644 index 00000000..4879e12e Binary files /dev/null and b/docs/benchmarks/projection_comparison_with_projection.png differ diff --git a/docs/benchmarks/projection_comparison_without_projection.png b/docs/benchmarks/projection_comparison_without_projection.png new file mode 100644 index 00000000..42809716 Binary files /dev/null and b/docs/benchmarks/projection_comparison_without_projection.png differ diff --git a/docs/cli.rst b/docs/cli.rst index 797fc392..16c05e99 100644 --- a/docs/cli.rst +++ b/docs/cli.rst @@ -54,7 +54,6 @@ precomputed query gradients. --dataset NeelNanda/pile-10k \ --method mean \ --unit_normalize - --projection_dim 0 .. autoclass:: bergson.__main__.Score :members: @@ -69,4 +68,3 @@ precomputed query gradients. runs/my-scores \ --query_path /runs/my-index \ --dataset EleutherAI/SmolLM2-135M-10B - --projection_dim 0 diff --git a/docs/index.rst b/docs/index.rst index f972e8fe..1e6d5679 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,6 +39,14 @@ Load the gradients: gradients = load_gradients(Path("runs/quickstart")) +Benchmarks +---------- + +.. toctree:: + :maxdepth: 2 + + benchmarks/index + API Reference --------------