From 3aac179c4abc728da705ab6745470c208bc9a72d Mon Sep 17 00:00:00 2001 From: David Johnston Date: Tue, 27 Jan 2026 05:45:42 +0000 Subject: [PATCH] Add preconditioner comparison and semantic experiments --- .gitignore | 6 + bergson/build.py | 2 +- bergson/collector/collector.py | 1 + bergson/config.py | 5 + bergson/score/score.py | 15 +- bergson/utils/math.py | 34 + bergson/utils/worker_utils.py | 1 + data/generate_facts.py | 17 + examples/semantic.py | 56 + examples/semantic/__init__.py | 98 + examples/semantic/asymmetric.py | 3350 +++++++++++++++++++ examples/semantic/attribute_preservation.py | 1785 ++++++++++ examples/semantic/data.py | 241 ++ examples/semantic/experiment.py | 237 ++ examples/semantic/metrics.py | 366 ++ examples/semantic/preconditioners.py | 737 ++++ examples/semantic/scoring.py | 284 ++ examples/semantics_experiment.py | 68 + examples/slurm/data_parallel_score.sh | 2 +- examples/train_lora.py | 311 ++ pyproject.toml | 3 + skills/asymmetric-style.md | 230 ++ skills/attribute-preservation.md | 168 + skills/gradient-debug.md | 92 + skills/preconditioner-analysis.md | 103 + skills/semantic-metrics.md | 74 + 26 files changed, 8274 insertions(+), 12 deletions(-) create mode 100644 data/generate_facts.py create mode 100644 examples/semantic.py create mode 100644 examples/semantic/__init__.py create mode 100644 examples/semantic/asymmetric.py create mode 100644 examples/semantic/attribute_preservation.py create mode 100644 examples/semantic/data.py create mode 100644 examples/semantic/experiment.py create mode 100644 examples/semantic/metrics.py create mode 100644 examples/semantic/preconditioners.py create mode 100644 examples/semantic/scoring.py create mode 100644 examples/semantics_experiment.py create mode 100644 examples/train_lora.py create mode 100644 skills/asymmetric-style.md create mode 100644 skills/attribute-preservation.md create mode 100644 skills/gradient-debug.md create mode 100644 skills/preconditioner-analysis.md create mode 100644 
skills/semantic-metrics.md diff --git a/.gitignore b/.gitignore index 870019dd..5e156ef7 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,9 @@ influence_results/ tmp/ .idea/ uv.lock +data/*.hf +zeki_requirements.txt +.python-version +*package-lock.json +*package.json +david_wips/ diff --git a/bergson/build.py b/bergson/build.py index 628a54bf..e454ed98 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -56,7 +56,7 @@ def build_worker( init_method=f"tcp://{addr}:{port}", device_id=torch.device(f"cuda:{local_rank}"), rank=rank, - timeout=timedelta(hours=1), + timeout=timedelta(minutes=2), world_size=world_size, ) diff --git a/bergson/collector/collector.py b/bergson/collector/collector.py index 1601e4ca..e548d4df 100644 --- a/bergson/collector/collector.py +++ b/bergson/collector/collector.py @@ -597,6 +597,7 @@ def fwd_bwd(model, x: Tensor, y: Tensor, batch: dict): logits.reshape(-1, logits.size(-1)), y[:, 1:].flatten(), reduction="none", + label_smoothing=cfg.label_smoothing, ).reshape_as(y[:, 1:]) losses = losses.sum(1) / denoms if "advantage" in batch: diff --git a/bergson/config.py b/bergson/config.py index c32db1f0..adfe496c 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -169,6 +169,11 @@ class IndexConfig: loss_reduction: Literal["mean", "sum"] = "mean" """Reduction method for the loss function.""" + label_smoothing: float = 0.0 + """Label smoothing coefficient for cross-entropy loss. When > 0, prevents + near-zero gradients for high-confidence predictions that can cause numerical + instability. 
Recommended value: 0.005-0.01.""" + stream_shard_size: int = 400_000 """Shard size for streaming the dataset into Dataset objects.""" diff --git a/bergson/score/score.py b/bergson/score/score.py index 2fac6c2b..63edcdc8 100644 --- a/bergson/score/score.py +++ b/bergson/score/score.py @@ -18,6 +18,7 @@ from bergson.gradients import GradientProcessor from bergson.score.score_writer import MemmapScoreWriter from bergson.score.scorer import Scorer +from bergson.utils.math import compute_damped_inverse from bergson.utils.utils import ( assert_type, convert_precision_to_torch, @@ -177,16 +178,10 @@ def precondition_ds( ) # Compute H^(-1) via eigendecomposition and apply to query gradients - h_inv = {} - for name, H in mixed_preconditioner.items(): - H = H.to(device=device, dtype=torch.float64) - damping_val = 0.1 * H.abs().mean() - H = H + damping_val * torch.eye(H.shape[0], device=H.device, dtype=H.dtype) - - eigval, eigvec = torch.linalg.eigh(H) - h_inv[name] = (eigvec * (1.0 / eigval) @ eigvec.mT).to( - mixed_preconditioner[name].dtype - ) + h_inv = { + name: compute_damped_inverse(H.to(device=device)) + for name, H in mixed_preconditioner.items() + } def precondition(batch): for name in target_modules: diff --git a/bergson/utils/math.py b/bergson/utils/math.py index fe5b2856..b0280ce1 100644 --- a/bergson/utils/math.py +++ b/bergson/utils/math.py @@ -67,6 +67,40 @@ def psd_rsqrt(A: Tensor) -> Tensor: return rsqrt +def compute_damped_inverse( + H: Tensor, + damping_factor: float = 0.1, + dtype: torch.dtype = torch.float64, + regularizer: Tensor | None = None, +) -> Tensor: + """Compute H^(-1) with damping for numerical stability. + + Uses eigendecomposition to compute the inverse of a positive semi-definite + matrix with adaptive damping based on the matrix's mean absolute value. + + Args: + H: Positive semi-definite matrix to invert. + damping_factor: Multiplier for the damping term (default: 0.1). 
+ dtype: Dtype for intermediate computation (default: float64 for stability). + regularizer: Optional matrix to use as regularizer instead of identity. + If provided, computes inv(H + damping_factor * regularizer). + If None (default), uses scaled identity: inv(H + damping_factor * |H|_mean * I). + + Returns: + The damped inverse H^(-1) in the original dtype of H. + """ + original_dtype = H.dtype + H = H.to(dtype=dtype) + if regularizer is not None: + regularizer = regularizer.to(dtype=dtype, device=H.device) + H = H + damping_factor * regularizer + else: + damping_val = damping_factor * H.abs().mean() + H = H + damping_val * torch.eye(H.shape[0], device=H.device, dtype=H.dtype) + eigval, eigvec = torch.linalg.eigh(H) + return (eigvec * (1.0 / eigval) @ eigvec.mT).to(original_dtype) + + def trace(matrices: Tensor) -> Tensor: """Version of `torch.trace` that works for batches of matrices.""" diag = torch.linalg.diagonal(matrices) diff --git a/bergson/utils/worker_utils.py b/bergson/utils/worker_utils.py index f5469a87..a38c0cfe 100644 --- a/bergson/utils/worker_utils.py +++ b/bergson/utils/worker_utils.py @@ -138,6 +138,7 @@ def setup_model_and_peft( try: peft_config = PeftConfig.from_pretrained(cfg.model) except ValueError: + print(f"PEFT config not found for model {cfg.model}") peft_config = None if peft_config is None: diff --git a/data/generate_facts.py b/data/generate_facts.py new file mode 100644 index 00000000..32e0e3a9 --- /dev/null +++ b/data/generate_facts.py @@ -0,0 +1,17 @@ +from argparse import ArgumentParser + +from datasets import Dataset + +from .dataset import fact_generator + +if __name__ == "__main__": + from argparse import ArgumentParser + + from datasets import Dataset + + parser = ArgumentParser() + parser.add_argument("--num_facts", type=int, default=1000) + args = parser.parse_args() + + dataset = fact_generator(args.num_facts) + Dataset.from_list(list(dataset)).save_to_disk("data/facts_dataset.hf") diff --git a/examples/semantic.py 
b/examples/semantic.py new file mode 100644 index 00000000..ee13122a --- /dev/null +++ b/examples/semantic.py @@ -0,0 +1,56 @@ +"""Backward-compatible wrapper for semantic experiments. + +This module re-exports all functions from the semantic package for backward +compatibility. New code should import directly from examples.semantic instead. +""" + +# Re-export everything from the semantic package +from semantic import ( + build_style_indices, + build_style_lookup, + compute_between_preconditioner, + compute_between_preconditioner_covariance, + compute_between_preconditioner_means, + compute_metrics, + compute_metrics_groupwise, + compute_mixed_preconditioner, + compute_scores_fast, + compute_scores_with_bergson, + create_data, + create_index, + create_qwen_only_dataset, + finetune, + load_scores_matrix, + main, + reword, + run_preconditioner_comparison, +) + +__all__ = [ + # Data creation + "reword", + "create_data", + "create_qwen_only_dataset", + # Scoring + "load_scores_matrix", + "compute_scores_fast", + "compute_scores_with_bergson", + # Metrics + "build_style_lookup", + "compute_metrics_groupwise", + "compute_metrics", + # Preconditioners + "build_style_indices", + "compute_between_preconditioner_covariance", + "compute_between_preconditioner_means", + "compute_between_preconditioner", + "compute_mixed_preconditioner", + # Experiment + "create_index", + "finetune", + "run_preconditioner_comparison", + "main", +] + +if __name__ == "__main__": + main() diff --git a/examples/semantic/__init__.py b/examples/semantic/__init__.py new file mode 100644 index 00000000..90202df4 --- /dev/null +++ b/examples/semantic/__init__.py @@ -0,0 +1,98 @@ +"""Semantic experiments for analyzing gradient-based embeddings. 
+ +This package provides tools for: +- Creating reworded datasets in different styles (Shakespeare, Pirate) +- Computing pairwise similarity scores from gradient embeddings +- Analyzing semantic vs stylistic similarity patterns +- Comparing different preconditioning strategies +- Asymmetric style distribution experiments for style suppression validation +""" + +from .asymmetric import ( + AsymmetricConfig, + AsymmetricMetrics, + compute_asymmetric_metrics, + compute_style_preconditioner, + create_asymmetric_dataset, + create_asymmetric_index, + print_metrics, + run_asymmetric_experiment, + score_asymmetric_eval, +) +from .attribute_preservation import ( + AttributePreservationConfig, + AttributePreservationMetrics, + compute_attribute_metrics, + create_attribute_dataset, + create_attribute_index, + create_styled_datasets, + print_attribute_metrics, + run_attribute_preservation_experiment, + score_attribute_eval, +) +from .data import create_data, create_qwen_only_dataset, reword +from .experiment import ( + create_index, + finetune, + main, + run_preconditioner_comparison, +) +from .metrics import build_style_lookup, compute_metrics, compute_metrics_groupwise +from .preconditioners import ( + build_style_indices, + compute_between_preconditioner, + compute_between_preconditioner_covariance, + compute_between_preconditioner_means, + compute_mixed_preconditioner, +) +from .scoring import ( + compute_scores_fast, + compute_scores_with_bergson, + load_scores_matrix, +) + +__all__ = [ + # Data creation + "reword", + "create_data", + "create_qwen_only_dataset", + # Scoring + "load_scores_matrix", + "compute_scores_fast", + "compute_scores_with_bergson", + # Metrics + "build_style_lookup", + "compute_metrics_groupwise", + "compute_metrics", + # Preconditioners + "build_style_indices", + "compute_between_preconditioner_covariance", + "compute_between_preconditioner_means", + "compute_between_preconditioner", + "compute_mixed_preconditioner", + # Experiment + 
"create_index", + "finetune", + "run_preconditioner_comparison", + "main", + # Asymmetric style experiment + "AsymmetricConfig", + "AsymmetricMetrics", + "create_asymmetric_dataset", + "create_asymmetric_index", + "score_asymmetric_eval", + "compute_asymmetric_metrics", + "compute_style_preconditioner", + "print_metrics", + "run_asymmetric_experiment", + # Attribute preservation experiment + "AttributePreservationConfig", + "AttributePreservationMetrics", + "create_attribute_dataset", + "create_attribute_index", + "create_styled_datasets", + "score_attribute_eval", + "compute_attribute_metrics", + "print_attribute_metrics", + "run_attribute_preservation_experiment", +] diff --git a/examples/semantic/asymmetric.py b/examples/semantic/asymmetric.py new file mode 100644 index 00000000..ae6d3178 --- /dev/null +++ b/examples/semantic/asymmetric.py @@ -0,0 +1,3350 @@ +"""Asymmetric style distribution experiment for style suppression validation. + +This module creates datasets where semantic matches are only available in the dominant +style, forcing attribution to choose between style similarity and semantic similarity. +""" + +from dataclasses import dataclass +from pathlib import Path + +import ml_dtypes # noqa: F401 # registers bfloat16 dtype with numpy +import numpy as np +from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk + +from examples.semantic.data import ( + HF_ANALYSIS_MODEL, + load_experiment_data, +) + + +def _load_gradients_as_float(grads: np.memmap, name: str) -> np.ndarray: + """Load a gradient field and convert from bfloat16 to float32. + + Args: + grads: Structured gradient memmap. + name: Field name to access. + + Returns: + Float32 numpy array. 
+ """ + g = grads[name] + # Gradients are stored as bfloat16 (2-byte void) + if g.dtype == np.dtype("|V2"): + g = g.view(ml_dtypes.bfloat16).astype(np.float32) + return g + + +@dataclass +class AsymmetricConfig: + """Configuration for asymmetric style experiment.""" + + dominant_style: str = "shakespeare" + minority_style: str = "pirate" + dominant_ratio: float = 0.95 # Fraction of training in dominant style + exclusive_ratio: float = 0.5 # Fraction of facts exclusive to dominant style + seed: int = 42 + # HuggingFace dataset repo. If set, skips local generation and downloads from HF. + hf_dataset: str | None = None + # Template split for train/test segregation (only used for local generation) + # Train uses templates < train_template_cutoff, eval majority uses templates >= cutoff + train_template_cutoff: int = 5 + + +def create_asymmetric_dataset( + config: AsymmetricConfig, + output_dir: Path | str, +) -> tuple[Dataset, Dataset]: + """Create asymmetric training and evaluation datasets. + + Splits facts into: + - Exclusive facts: only appear in dominant style (for testing semantic matching) + - Shared facts: appear in both styles (for style ratio control) + + For train/test segregation: + - Training uses templates < train_template_cutoff (default: 0-4) + - Eval majority style control uses templates >= cutoff (default: 5+) + This ensures no exact text overlap between train and eval majority control. + + Args: + config: Experiment configuration. + output_dir: Directory to save datasets. + + Returns: + (train_dataset, eval_dataset) tuple. 
+ """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + train_path = output_dir / "train.hf" + eval_path = output_dir / "eval.hf" + + # Return cached if exists + if train_path.exists() and eval_path.exists(): + print(f"Loading cached datasets from {output_dir}") + return load_from_disk(str(train_path)), load_from_disk(str(eval_path)) + + # Load original facts to get metadata columns + original = load_from_disk("data/facts_dataset.hf") + if isinstance(original, DatasetDict): + original = original["train"] + fact_to_meta = {row["fact"]: row for row in original} + + # Load style-specific datasets (Qwen only for consistency) + style_datasets = { + "shakespeare": load_from_disk( + "data/facts_dataset_shakespeare-Qwen3-8B-Base.hf" + ), + "pirate": load_from_disk("data/facts_dataset_pirate-Qwen3-8B-Base.hf"), + } + for name in style_datasets: + if isinstance(style_datasets[name], DatasetDict): + style_datasets[name] = style_datasets[name]["train"] + + # Add back metadata columns from original + ds = style_datasets[name] + for col in original.column_names: + if col not in ds.column_names: + restored_col = [fact_to_meta[row["fact"]][col] for row in ds] + ds = ds.add_column(col, restored_col) + style_datasets[name] = ds + + dominant_ds = style_datasets[config.dominant_style] + minority_ds = style_datasets[config.minority_style] + + # Get unique (identifier, field) pairs - these represent underlying semantic facts + # Each pair has multiple templates (different surface forms of the same fact) + semantic_facts = list({(row["identifier"], row["field"]) for row in original}) + n_semantic_facts = len(semantic_facts) + + # Split into exclusive (dominant-only) and shared by semantic fact + rng = np.random.default_rng(config.seed) + rng.shuffle(semantic_facts) + + n_exclusive = int(n_semantic_facts * config.exclusive_ratio) + exclusive_semantic_facts = set(semantic_facts[:n_exclusive]) + shared_semantic_facts = set(semantic_facts[n_exclusive:]) + + 
print(f"Total unique semantic facts (identifier, field pairs): {n_semantic_facts}") + print(f"Exclusive to {config.dominant_style}: {len(exclusive_semantic_facts)}") + print(f"Shared between styles: {len(shared_semantic_facts)}") + print(f"Template cutoff for train/eval split: {config.train_template_cutoff}") + + # Build training set with template filtering + # 1. Dominant style: only templates < cutoff (templates >= cutoff are reserved for eval) + train_dominant_indices = [ + i + for i, row in enumerate(dominant_ds) + if row["template"] < config.train_template_cutoff + ] + train_dominant = dominant_ds.select(train_dominant_indices) + + # 2. Minority style: only shared facts, with the same template cutoff as the dominant style + minority_shared_indices = [ + i + for i, row in enumerate(minority_ds) + if (row["identifier"], row["field"]) in shared_semantic_facts + and row["template"] < config.train_template_cutoff + ] + train_minority = minority_ds.select(minority_shared_indices) + + # Add style column + train_dominant = train_dominant.add_column( + "style", [config.dominant_style] * len(train_dominant) + ) + train_minority = train_minority.add_column( + "style", [config.minority_style] * len(train_minority) + ) + + # Combine and shuffle + train_ds = concatenate_datasets([train_dominant, train_minority]) + train_ds = train_ds.shuffle(seed=config.seed) + + print("\nTraining set composition:") + print(f" {config.dominant_style}: {len(train_dominant)} samples") + print(f" {config.minority_style}: {len(train_minority)} samples") + print(f" Total: {len(train_ds)} samples") + print(f" Dominant ratio: {len(train_dominant) / len(train_ds):.2%}") + + # Build eval set: query exclusive facts in minority style + # Use templates >= cutoff to ensure no overlap with train + # These facts don't exist in minority style in training, so the model + # must use semantic matching (not style matching) to find them + eval_minority_indices = [ + i + for i, row in enumerate(minority_ds) + if 
(row["identifier"], row["field"]) in exclusive_semantic_facts + and row["template"] >= config.train_template_cutoff + ] + eval_ds = minority_ds.select(eval_minority_indices) + eval_ds = eval_ds.add_column("style", [config.minority_style] * len(eval_ds)) + + # Add expected_match_style to indicate where the ground truth is + eval_ds = eval_ds.add_column( + "expected_match_style", [config.dominant_style] * len(eval_ds) + ) + + print("\nEval set:") + print(f" Queries in {config.minority_style} style: {len(eval_ds)}") + print(f" Ground truth only in {config.dominant_style} style") + print( + f" Using templates >= {config.train_template_cutoff} (no overlap with train)" + ) + + # Save datasets + train_ds.save_to_disk(str(train_path)) + eval_ds.save_to_disk(str(eval_path)) + print(f"\nSaved datasets to {output_dir}") + + return train_ds, eval_ds + + +def create_asymmetric_index( + config: AsymmetricConfig, + base_path: Path | str, + analysis_model: str | None = None, +) -> Path: + """Create bergson index for asymmetric training set. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + analysis_model: Model to use for gradient collection. Defaults to HF_ANALYSIS_MODEL. + + Returns: + Path to the created index. 
+ """ + import subprocess + + if analysis_model is None: + analysis_model = HF_ANALYSIS_MODEL + + base_path = Path(base_path) + data_path = base_path / "data" + index_path = base_path / "index" + + # Load or create dataset + if config.hf_dataset: + # Download from HuggingFace and save locally for bergson + print(f"Loading dataset from HuggingFace: {config.hf_dataset}") + dataset_dict = load_experiment_data(hf_repo=config.hf_dataset) + data_path.mkdir(parents=True, exist_ok=True) + for split_name, split_ds in dataset_dict.items(): + split_path = data_path / f"{split_name}.hf" + if not split_path.exists(): + split_ds.save_to_disk(str(split_path)) + print(f" Saved {split_name} to {split_path}") + else: + # Generate locally + create_asymmetric_dataset(config, data_path) + + if index_path.exists(): + print(f"Index already exists at {index_path}, skipping...") + return index_path + + cmd = [ + "bergson", + "build", + str(index_path), + "--model", + analysis_model, + "--dataset", + str(data_path / "train.hf"), + "--drop_columns", + "False", + "--prompt_column", + "fact", + "--completion_column", + "reworded", + "--fsdp", + "--projection_dim", + "16", + "--token_batch_size", + "6000", + ] + + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError("bergson build failed") + print(result.stdout) + + return index_path + + +def score_asymmetric_eval( + config: AsymmetricConfig, + base_path: Path | str, + preconditioner_name: str | None = None, + damping_factor: float = 0.1, + regularizer_name: str | None = None, + eval_prompt_column: str = "fact", + eval_completion_column: str = "reworded", +) -> np.ndarray: + """Score eval queries against training index. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. 
+ preconditioner_name: Name of preconditioner subdirectory (None for no precond). + damping_factor: Damping factor for matrix inversion (default: 0.1). + regularizer_name: Name of preconditioner to use as regularizer instead of identity. + If provided, computes inv(H + damping_factor * H_regularizer). + Useful for regularizing rank-deficient preconditioners like r_between + with a well-conditioned matrix like H_train or H_eval. + eval_prompt_column: Column to use as prompt for eval gradients (default: "fact"). + eval_completion_column: Column to use as completion for eval gradients (default: "reworded"). + Set to "question"/"answer" for semantic-only attribution where gradients + only come from the answer tokens. + + Returns: + Score matrix of shape (n_eval, n_train). + """ + import json + import subprocess + + import torch + from tqdm import tqdm + + from bergson.data import load_gradients + from bergson.gradients import GradientProcessor + from bergson.utils.math import compute_damped_inverse + + base_path = Path(base_path) + index_path = base_path / "index" + data_path = base_path / "data" + + # Determine output path (include damping factor, regularizer, and eval columns in cache key) + damping_suffix = f"_d{damping_factor:.0e}" if damping_factor != 0.1 else "" + reg_suffix = f"_reg_{regularizer_name}" if regularizer_name else "" + # Add eval column suffix if not using default columns + eval_col_suffix = "" + if eval_prompt_column != "fact" or eval_completion_column != "reworded": + eval_col_suffix = f"_{eval_prompt_column}_{eval_completion_column}" + if preconditioner_name: + scores_path = ( + base_path + / f"scores_{preconditioner_name}{damping_suffix}{reg_suffix}{eval_col_suffix}" + ) + precond_path = base_path / preconditioner_name + else: + scores_path = ( + base_path + / f"scores_no_precond{damping_suffix}{reg_suffix}{eval_col_suffix}" + ) + precond_path = None + + # Return cached if exists + if (scores_path / "scores.npy").exists(): + print(f"Loading 
cached scores from {scores_path}") + return np.load(scores_path / "scores.npy") + + scores_path.mkdir(parents=True, exist_ok=True) + + # Load train and eval datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval.hf")) + + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + n_train = len(train_ds) + n_eval = len(eval_ds) + + print(f"Scoring {n_eval} eval queries against {n_train} train samples") + + # Load train gradients + print("Loading train gradients...") + train_grads = load_gradients(index_path, structured=True) + + with open(index_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # Load preconditioner if specified + h_inv = {} + if precond_path and (precond_path / "preconditioners.pth").exists(): + # Load regularizer preconditioner if specified + reg_proc = None + if regularizer_name: + reg_path = base_path / regularizer_name + if (reg_path / "preconditioners.pth").exists(): + print(f"Loading regularizer from {reg_path}") + reg_proc = GradientProcessor.load(reg_path) + else: + print( + f"Warning: regularizer {regularizer_name} not found at {reg_path}" + ) + + print(f"Loading preconditioner from {precond_path} (damping={damping_factor})") + proc = GradientProcessor.load(precond_path) + device = torch.device("cuda:0") + for name in tqdm(module_names, desc="Computing H^(-1)"): + H = proc.preconditioners[name].to(device=device) + regularizer = None + if reg_proc is not None and name in reg_proc.preconditioners: + regularizer = reg_proc.preconditioners[name].to(device=device) + h_inv[name] = compute_damped_inverse( + H, damping_factor=damping_factor, regularizer=regularizer + ) + + # Concatenate train gradients + print("Preparing train gradients...") + train_grad_list = [] + for name in tqdm(module_names, desc="Loading train grads"): + g = _load_gradients_as_float(train_grads, 
name) + train_grad_list.append(torch.from_numpy(g)) + train_grad_tensor = torch.cat(train_grad_list, dim=1) + + # Unit normalize train grads (as index) + train_norms = train_grad_tensor.norm(dim=1, keepdim=True) + train_grad_tensor = train_grad_tensor / (train_norms + 1e-8) + train_grad_tensor = train_grad_tensor.cuda() + + # For eval, we need to compute gradients on the fly + # Use bergson to compute eval gradients with same projection + print("Computing eval gradients...") + + # Use different cache path based on eval columns + if eval_prompt_column == "fact" and eval_completion_column == "reworded": + eval_grads_path = base_path / "eval_grads" + else: + eval_grads_path = ( + base_path / f"eval_grads_{eval_prompt_column}_{eval_completion_column}" + ) + + if not eval_grads_path.exists(): + with open(index_path / "index_config.json") as f: + index_cfg = json.load(f) + + cmd = [ + "bergson", + "build", + str(eval_grads_path), + "--model", + index_cfg["model"], + "--dataset", + str(data_path / "eval.hf"), + "--drop_columns", + "False", + "--prompt_column", + eval_prompt_column, + "--completion_column", + eval_completion_column, + "--fsdp", + "--projection_dim", + str(index_cfg.get("projection_dim", 16)), + "--token_batch_size", + "6000", + "--skip_preconditioners", + ] + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError("bergson build for eval failed") + print(result.stdout) + + # Load eval gradients + eval_grads = load_gradients(eval_grads_path, structured=True) + eval_grad_list = [] + for name in tqdm(module_names, desc="Loading eval grads"): + g = torch.from_numpy(_load_gradients_as_float(eval_grads, name)) + if h_inv: + # Apply preconditioning + g = (g.cuda() @ h_inv[name]).cpu() + eval_grad_list.append(g) + eval_grad_tensor = torch.cat(eval_grad_list, dim=1) + + # Unit normalize eval grads (as query) + 
eval_norms = eval_grad_tensor.norm(dim=1, keepdim=True) + eval_grad_tensor = eval_grad_tensor / (eval_norms + 1e-8) + eval_grad_tensor = eval_grad_tensor.cuda() + + # Compute scores: eval @ train.T gives (n_eval, n_train) + print("Computing scores...") + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + + # Save scores + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + + return scores + + +@dataclass +class AsymmetricMetrics: + """Metrics for asymmetric style suppression experiment.""" + + # Semantic accuracy (same subject AND field = same underlying fact) + top1_semantic_accuracy: float # Top-1 same subject AND field + top5_semantic_recall: float # Any of top-5 same subject AND field + top10_semantic_recall: float # Any of top-10 same subject AND field + + # Style leakage (lower is better - means not matching on style) + top1_style_leakage: float # Top-1 is same (minority) style + top5_style_leakage: float # Fraction of top-5 in same (minority) style + top10_style_leakage: float # Fraction of top-10 in same (minority) style + + # Breakdown by attribute + top1_subject_accuracy: float # Top-1 same subject + top1_field_accuracy: float # Top-1 same field type + + +def compute_asymmetric_metrics( + config: AsymmetricConfig, + base_path: Path | str, + preconditioner_name: str | None = None, + damping_factor: float = 0.1, + regularizer_name: str | None = None, + eval_prompt_column: str = "fact", + eval_completion_column: str = "reworded", +) -> AsymmetricMetrics: + """Compute metrics for asymmetric style suppression. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + preconditioner_name: Name of preconditioner to use. + damping_factor: Damping factor for matrix inversion (default: 0.1). + regularizer_name: Name of preconditioner to use as regularizer instead of identity. + eval_prompt_column: Column to use as prompt for eval gradients (default: "fact"). 
+ eval_completion_column: Column to use as completion for eval gradients (default: "reworded"). + Set to "question"/"answer" for semantic-only attribution. + + Returns: + AsymmetricMetrics dataclass. + """ + base_path = Path(base_path) + data_path = base_path / "data" + + # Load datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval.hf")) + + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + # Load scores + scores = score_asymmetric_eval( + config, + base_path, + preconditioner_name, + damping_factor=damping_factor, + regularizer_name=regularizer_name, + eval_prompt_column=eval_prompt_column, + eval_completion_column=eval_completion_column, + ) + + n_eval = len(eval_ds) + + # Extract metadata + train_styles = train_ds["style"] + train_identifiers = train_ds["identifier"] + train_fields = train_ds["field"] + + eval_identifiers = eval_ds["identifier"] + eval_fields = eval_ds["field"] + + # Get top-k indices for each query + top_k = 10 + top_indices = np.argsort(-scores, axis=1)[:, :top_k] + + # Compute metrics + semantic_top1 = 0 + semantic_top5 = 0 + semantic_top10 = 0 + style_leak_top1 = 0 + style_leak_top5 = 0 + style_leak_top10 = 0 + subject_top1 = 0 + field_top1 = 0 + + for i in range(n_eval): + query_identifier = eval_identifiers[i] + query_field = eval_fields[i] + + top_k_idx = top_indices[i] + + # Check semantic matching (same identifier AND field = same underlying fact) + for k, idx in enumerate(top_k_idx): + if ( + train_identifiers[idx] == query_identifier + and train_fields[idx] == query_field + ): + if k == 0: + semantic_top1 += 1 + if k < 5: + semantic_top5 += 1 + break + if k < 10: + semantic_top10 += 1 + break + + # Check style leakage + top1_style = train_styles[top_k_idx[0]] + if top1_style == config.minority_style: + style_leak_top1 += 1 + + top5_minority = sum( + 1 for idx in top_k_idx[:5] if 
train_styles[idx] == config.minority_style + ) + style_leak_top5 += top5_minority / 5 + + top10_minority = sum( + 1 for idx in top_k_idx[:10] if train_styles[idx] == config.minority_style + ) + style_leak_top10 += top10_minority / 10 + + # Check attribute matching for top-1 + top1_idx = top_k_idx[0] + if train_identifiers[top1_idx] == query_identifier: + subject_top1 += 1 + if train_fields[top1_idx] == query_field: + field_top1 += 1 + + return AsymmetricMetrics( + top1_semantic_accuracy=semantic_top1 / n_eval, + top5_semantic_recall=semantic_top5 / n_eval, + top10_semantic_recall=semantic_top10 / n_eval, + top1_style_leakage=style_leak_top1 / n_eval, + top5_style_leakage=style_leak_top5 / n_eval, + top10_style_leakage=style_leak_top10 / n_eval, + top1_subject_accuracy=subject_top1 / n_eval, + top1_field_accuracy=field_top1 / n_eval, + ) + + +def print_metrics(metrics: AsymmetricMetrics, name: str) -> None: + """Print metrics in a formatted way.""" + print(f"\n{'=' * 60}") + print(f"RESULTS: {name}") + print("=" * 60) + + print("\nSemantic Accuracy (higher is better):") + print(f" Top-1 accuracy: {metrics.top1_semantic_accuracy:.2%}") + print(f" Top-5 recall: {metrics.top5_semantic_recall:.2%}") + print(f" Top-10 recall: {metrics.top10_semantic_recall:.2%}") + + print("\nStyle Leakage (lower is better):") + print(f" Top-1 leakage: {metrics.top1_style_leakage:.2%}") + print(f" Top-5 leakage: {metrics.top5_style_leakage:.2%}") + print(f" Top-10 leakage: {metrics.top10_style_leakage:.2%}") + + print("\nAttribute Breakdown (Top-1):") + print(f" Same subject: {metrics.top1_subject_accuracy:.2%}") + print(f" Same field: {metrics.top1_field_accuracy:.2%}") + + +def compute_style_preconditioner( + base_path: Path | str, + config: AsymmetricConfig, +) -> Path: + """Compute R_between preconditioner that isolates the style direction. + + This creates a rank-1 preconditioner from the difference in style means. + When used for scoring, this should downweight the style direction. 
+ + Args: + base_path: Base path for experiment outputs. + config: Experiment configuration. + + Returns: + Path to the preconditioner. + """ + import json + + import torch + from tqdm import tqdm + + from bergson.data import load_gradients + from bergson.gradients import GradientProcessor + + base_path = Path(base_path) + index_path = base_path / "index" + data_path = base_path / "data" + output_path = base_path / "r_between" + + if (output_path / "preconditioners.pth").exists(): + print(f"Loading cached R_between from {output_path}") + return output_path + + print("Computing R_between preconditioner from style means...") + + # Load training data and gradients + train_ds = load_from_disk(str(data_path / "train.hf")) + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + + train_styles = train_ds["style"] + train_grads = load_gradients(index_path, structured=True) + + with open(index_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # Separate indices by style + dominant_indices = [ + i for i, s in enumerate(train_styles) if s == config.dominant_style + ] + minority_indices = [ + i for i, s in enumerate(train_styles) if s == config.minority_style + ] + + print(f" {config.dominant_style}: {len(dominant_indices)} samples") + print(f" {config.minority_style}: {len(minority_indices)} samples") + + # Load a processor to get metadata + base_proc = GradientProcessor.load(index_path) + + # Compute per-module rank-1 preconditioners + between_precs = {} + print(f" Computing per-module R_between for {len(module_names)} modules...") + + for name in tqdm(module_names): + g_all = torch.from_numpy(_load_gradients_as_float(train_grads, name)) + + # Get style-specific gradients + g_dominant = g_all[dominant_indices] + g_minority = g_all[minority_indices] + + # Compute means + mu_dominant = g_dominant.mean(dim=0) + mu_minority = g_minority.mean(dim=0) + + # Style direction + delta = mu_dominant - mu_minority + + # Rank-1 
preconditioner: outer product + between_precs[name] = torch.outer(delta, delta) + + # Save + output_path.mkdir(parents=True, exist_ok=True) + between_proc = GradientProcessor( + normalizers=base_proc.normalizers, + preconditioners=between_precs, + preconditioners_eigen={}, + projection_dim=base_proc.projection_dim, + projection_type=base_proc.projection_type, + include_bias=base_proc.include_bias, + ) + between_proc.save(output_path) + print(f"Saved R_between preconditioner to {output_path}") + + return output_path + + +def score_asymmetric_eval_with_pca_projection( + config: AsymmetricConfig, + base_path: Path | str, + style_subspace: dict[str, tuple], + top_k: int = 10, + preconditioner_name: str | None = None, + damping_factor: float = 0.1, + eval_prompt_column: str = "fact", + eval_completion_column: str = "reworded", +) -> np.ndarray: + """Score eval queries using PCA projection to remove style direction. + + Instead of using matrix-inverse preconditioning, this projects eval gradients + onto the orthogonal complement of the style subspace before computing scores. + Can optionally combine with a preconditioner applied after projection. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + style_subspace: Dictionary from compute_pca_style_subspace(). + top_k: Number of principal components used (for cache naming). + preconditioner_name: Optional preconditioner to apply after projection. + damping_factor: Damping factor for matrix inversion (default: 0.1). + eval_prompt_column: Column to use as prompt for eval gradients (default: "fact"). + eval_completion_column: Column to use as completion for eval gradients (default: "reworded"). + Set to "question"/"answer" for semantic-only attribution. + + Returns: + Score matrix of shape (n_eval, n_train). 
+ """ + import json + import subprocess + + import torch + from tqdm import tqdm + + from bergson.data import load_gradients + from bergson.gradients import GradientProcessor + from bergson.utils.math import compute_damped_inverse + + from .preconditioners import project_orthogonal_to_style_subspace + + base_path = Path(base_path) + index_path = base_path / "index" + data_path = base_path / "data" + + # Build cache path including preconditioner, damping factor, and eval columns + damping_suffix = f"_d{damping_factor:.0e}" if damping_factor != 0.1 else "" + eval_col_suffix = "" + if eval_prompt_column != "fact" or eval_completion_column != "reworded": + eval_col_suffix = f"_{eval_prompt_column}_{eval_completion_column}" + if preconditioner_name: + scores_path = ( + base_path + / f"scores_pca_k{top_k}_{preconditioner_name}{damping_suffix}{eval_col_suffix}" + ) + precond_path = base_path / preconditioner_name + else: + scores_path = ( + base_path / f"scores_pca_k{top_k}{damping_suffix}{eval_col_suffix}" + ) + precond_path = None + + # Return cached if exists + if (scores_path / "scores.npy").exists(): + print(f"Loading cached scores from {scores_path}") + return np.load(scores_path / "scores.npy") + + scores_path.mkdir(parents=True, exist_ok=True) + + # Load train and eval datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval.hf")) + + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + n_train = len(train_ds) + n_eval = len(eval_ds) + + print( + f"Scoring {n_eval} eval queries against {n_train} train samples (PCA projection)" + ) + + # Load train gradients + print("Loading train gradients...") + train_grads = load_gradients(index_path, structured=True) + + with open(index_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # Load preconditioner if specified + h_inv = {} + if 
precond_path and (precond_path / "preconditioners.pth").exists(): + print(f"Loading preconditioner from {precond_path} (damping={damping_factor})") + proc = GradientProcessor.load(precond_path) + device = torch.device("cuda:0") + for name in tqdm(module_names, desc="Computing H^(-1)"): + H = proc.preconditioners[name].to(device=device) + h_inv[name] = compute_damped_inverse(H, damping_factor=damping_factor) + + # Concatenate train gradients + print("Preparing train gradients...") + train_grad_list = [] + for name in tqdm(module_names, desc="Loading train grads"): + g = _load_gradients_as_float(train_grads, name) + train_grad_list.append(torch.from_numpy(g)) + train_grad_tensor = torch.cat(train_grad_list, dim=1) + + # Unit normalize train grads + train_norms = train_grad_tensor.norm(dim=1, keepdim=True) + train_grad_tensor = train_grad_tensor / (train_norms + 1e-8) + train_grad_tensor = train_grad_tensor.cuda() + + # Compute eval gradients if needed + print("Computing eval gradients...") + if eval_prompt_column == "fact" and eval_completion_column == "reworded": + eval_grads_path = base_path / "eval_grads" + else: + eval_grads_path = ( + base_path / f"eval_grads_{eval_prompt_column}_{eval_completion_column}" + ) + + if not eval_grads_path.exists(): + with open(index_path / "index_config.json") as f: + index_cfg = json.load(f) + + cmd = [ + "bergson", + "build", + str(eval_grads_path), + "--model", + index_cfg["model"], + "--dataset", + str(data_path / "eval.hf"), + "--drop_columns", + "False", + "--prompt_column", + eval_prompt_column, + "--completion_column", + eval_completion_column, + "--fsdp", + "--projection_dim", + str(index_cfg.get("projection_dim", 16)), + "--token_batch_size", + "6000", + "--skip_preconditioners", + ] + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError("bergson build for eval 
failed") + print(result.stdout) + + # Load eval gradients and apply PCA projection + eval_grads = load_gradients(eval_grads_path, structured=True) + eval_grad_list = [] + + # Track cumulative dimension for concatenation + for name in tqdm(module_names, desc="Loading and projecting eval grads"): + g = torch.from_numpy(_load_gradients_as_float(eval_grads, name)) + + # Apply PCA projection if we have the subspace for this module + if name in style_subspace: + eigvecs, _ = style_subspace[name] + g = g.cuda() + eigvecs = eigvecs.cuda() + g = project_orthogonal_to_style_subspace(g, eigvecs) + # Apply preconditioning after projection if specified + if h_inv: + g = g @ h_inv[name] + g = g.cpu() + elif h_inv: + # Apply preconditioning even without PCA projection + g = (g.cuda() @ h_inv[name]).cpu() + + eval_grad_list.append(g) + + eval_grad_tensor = torch.cat(eval_grad_list, dim=1) + + # Unit normalize eval grads + eval_norms = eval_grad_tensor.norm(dim=1, keepdim=True) + eval_grad_tensor = eval_grad_tensor / (eval_norms + 1e-8) + eval_grad_tensor = eval_grad_tensor.cuda() + + # Compute scores + print("Computing scores...") + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + + # Save scores + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + + return scores + + +def compute_asymmetric_metrics_with_pca( + config: AsymmetricConfig, + base_path: Path | str, + style_subspace: dict[str, tuple], + top_k: int = 10, + preconditioner_name: str | None = None, + damping_factor: float = 0.1, + eval_prompt_column: str = "fact", + eval_completion_column: str = "reworded", +) -> "AsymmetricMetrics": + """Compute metrics using PCA projection style suppression. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + style_subspace: Dictionary from compute_pca_style_subspace(). + top_k: Number of principal components. + preconditioner_name: Optional preconditioner to combine with PCA. 
+        damping_factor: Damping factor for matrix inversion (default: 0.1).
+        eval_prompt_column: Column to use as prompt for eval gradients (default: "fact").
+        eval_completion_column: Column to use as completion for eval gradients (default: "reworded").
+            Set to "question"/"answer" for semantic-only attribution.
+
+    Returns:
+        AsymmetricMetrics dataclass.
+    """
+    base_path = Path(base_path)
+    data_path = base_path / "data"
+
+    # Load datasets
+    train_ds = load_from_disk(str(data_path / "train.hf"))
+    eval_ds = load_from_disk(str(data_path / "eval.hf"))
+
+    if isinstance(train_ds, DatasetDict):
+        train_ds = train_ds["train"]
+    if isinstance(eval_ds, DatasetDict):
+        eval_ds = eval_ds["train"]
+
+    # Load scores (using PCA projection)
+    scores = score_asymmetric_eval_with_pca_projection(
+        config,
+        base_path,
+        style_subspace,
+        top_k,
+        preconditioner_name,
+        damping_factor=damping_factor,
+        eval_prompt_column=eval_prompt_column,
+        eval_completion_column=eval_completion_column,
+    )
+
+    n_eval = len(eval_ds)
+
+    # Extract metadata
+    train_styles = train_ds["style"]
+    train_identifiers = train_ds["identifier"]
+    train_fields = train_ds["field"]
+
+    eval_identifiers = eval_ds["identifier"]
+    eval_fields = eval_ds["field"]
+
+    # Get top-k indices for each query
+    top_k_results = 10
+    top_indices = np.argsort(-scores, axis=1)[:, :top_k_results]
+
+    # Compute metrics (same logic as compute_asymmetric_metrics)
+    semantic_top1 = 0
+    semantic_top5 = 0
+    semantic_top10 = 0
+    style_leak_top1 = 0
+    style_leak_top5 = 0
+    style_leak_top10 = 0
+    subject_top1 = 0
+    field_top1 = 0
+
+    for i in range(n_eval):
+        query_identifier = eval_identifiers[i]
+        query_field = eval_fields[i]
+
+        top_k_idx = top_indices[i]
+
+        # Check semantic matching (same identifier AND field = same underlying fact)
+        for k, idx in enumerate(top_k_idx):
+            if (
+                train_identifiers[idx] == query_identifier
+                and train_fields[idx] == query_field
+            ):
+                # Recall is cumulative: a hit at rank k counts toward every
+                # cutoff >= k. (Breaking inside the k < 5 branch would skip
+                # the top-10 counter for top-5 hits, undercounting recall@10.)
+                if k == 0:
+                    semantic_top1 += 1
+                if k < 5:
+                    semantic_top5 += 1
+                if k < 10:
+                    semantic_top10 += 1
+                break
+
+        # Check style leakage
+        top1_style = train_styles[top_k_idx[0]]
+        if top1_style == config.minority_style:
+            style_leak_top1 += 1
+
+        top5_minority = sum(
+            1 for idx in top_k_idx[:5] if train_styles[idx] == config.minority_style
+        )
+        style_leak_top5 += top5_minority / 5
+
+        top10_minority = sum(
+            1 for idx in top_k_idx[:10] if train_styles[idx] == config.minority_style
+        )
+        style_leak_top10 += top10_minority / 10
+
+        # Check attribute matching for top-1
+        top1_idx = top_k_idx[0]
+        if train_identifiers[top1_idx] == query_identifier:
+            subject_top1 += 1
+        if train_fields[top1_idx] == query_field:
+            field_top1 += 1
+
+    return AsymmetricMetrics(
+        top1_semantic_accuracy=semantic_top1 / n_eval,
+        top5_semantic_recall=semantic_top5 / n_eval,
+        top10_semantic_recall=semantic_top10 / n_eval,
+        top1_style_leakage=style_leak_top1 / n_eval,
+        top5_style_leakage=style_leak_top5 / n_eval,
+        top10_style_leakage=style_leak_top10 / n_eval,
+        top1_subject_accuracy=subject_top1 / n_eval,
+        top1_field_accuracy=field_top1 / n_eval,
+    )
+
+
+def create_majority_style_eval(
+    config: AsymmetricConfig,
+    base_path: Path | str,
+    force_regenerate: bool = False,
+) -> tuple[Path, bool]:
+    """Create eval set using majority style (control for style mismatch).
+
+    Instead of using minority style queries, uses dominant style queries
+    for the exclusive facts. This shows baseline performance without style mismatch.
+
+    IMPORTANT: Uses templates >= train_template_cutoff to ensure NO overlap with
+    training data. This provides a proper train/test split where eval majority
+    style items test semantic matching (same fact, different surface form) rather
+    than exact text matching.
+
+    Args:
+        config: Experiment configuration.
+        base_path: Base path for experiment outputs.
+        force_regenerate: If True, regenerate even if cached version exists.
+
+    Returns:
+        Tuple of (path to the majority style eval dataset, has_leakage flag).
+ has_leakage is True if there's train/test overlap (e.g., from HF data). + """ + base_path = Path(base_path) + data_path = base_path / "data" + majority_eval_path = data_path / "eval_majority_style.hf" + + # Check for existing cached version + if majority_eval_path.exists() and not force_regenerate: + print(f"Loading cached majority style eval from {majority_eval_path}") + + # Check for train/test leakage by comparing reworded texts + train_ds = load_from_disk(str(data_path / "train.hf")) + majority_eval_ds = load_from_disk(str(majority_eval_path)) + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(majority_eval_ds, DatasetDict): + majority_eval_ds = majority_eval_ds["train"] + + train_reworded = set(train_ds["reworded"]) + eval_reworded = set(majority_eval_ds["reworded"]) + overlap = train_reworded & eval_reworded + has_leakage = len(overlap) > 0 + + if has_leakage: + print( + f" WARNING: {len(overlap)}/{len(eval_reworded)} eval items have " + "exact text match in train (train/test leakage)" + ) + print(" Use force_regenerate=True with local data to fix") + + return majority_eval_path, has_leakage + + print("Creating majority style eval set (control)...") + + # Check if local styled datasets exist for proper template segregation + local_styled_path = Path( + f"data/facts_dataset_{config.dominant_style}-Qwen3-8B-Base.hf" + ) + if not local_styled_path.exists(): + print(f" WARNING: Local styled dataset not found at {local_styled_path}") + print(" Cannot create properly segregated majority eval") + print(" Using HF eval_majority_style (may have train/test leakage)") + return majority_eval_path, True # Return existing HF version with leakage flag + + # Load the minority style eval to get the semantic facts (identifier, field pairs) + eval_ds = load_from_disk(str(data_path / "eval.hf")) + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + # Get semantic facts from eval (identifier, field pairs) + 
eval_semantic_facts = {(row["identifier"], row["field"]) for row in eval_ds} + + # Load dominant style dataset + dominant_ds = load_from_disk(str(local_styled_path)) + if isinstance(dominant_ds, DatasetDict): + dominant_ds = dominant_ds["train"] + + # Add back metadata columns from original + original = load_from_disk("data/facts_dataset.hf") + if isinstance(original, DatasetDict): + original = original["train"] + fact_to_meta = {row["fact"]: row for row in original} + + for col in original.column_names: + if col not in dominant_ds.column_names: + restored_col = [fact_to_meta[row["fact"]][col] for row in dominant_ds] + dominant_ds = dominant_ds.add_column(col, restored_col) + + # Select dominant style versions of eval semantic facts + # Use templates >= cutoff to ensure NO overlap with training data + dominant_eval_indices = [ + i + for i, row in enumerate(dominant_ds) + if (row["identifier"], row["field"]) in eval_semantic_facts + and row["template"] >= config.train_template_cutoff + ] + majority_eval_ds = dominant_ds.select(dominant_eval_indices) + + print( + f" Using templates >= {config.train_template_cutoff} (no overlap with train)" + ) + print(f" Found {len(majority_eval_ds)} majority style eval samples") + + # Add style columns if not present + if "style" not in majority_eval_ds.column_names: + majority_eval_ds = majority_eval_ds.add_column( + "style", [config.dominant_style] * len(majority_eval_ds) + ) + if "expected_match_style" not in majority_eval_ds.column_names: + majority_eval_ds = majority_eval_ds.add_column( + "expected_match_style", [config.dominant_style] * len(majority_eval_ds) + ) + + majority_eval_ds.save_to_disk(str(majority_eval_path)) + print(f"Saved majority style eval to {majority_eval_path}") + + return majority_eval_path, False # No leakage with proper segregation + + +def score_majority_style_eval( + config: AsymmetricConfig, + base_path: Path | str, + preconditioner_name: str | None = None, +) -> np.ndarray: + """Score majority style 
eval queries (control for style mismatch). + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + preconditioner_name: Name of preconditioner subdirectory (None for no precond). + + Returns: + Score matrix of shape (n_eval, n_train). + """ + import json + import subprocess + + import torch + from tqdm import tqdm + + from bergson.data import load_gradients + from bergson.gradients import GradientProcessor + from bergson.utils.math import compute_damped_inverse + + base_path = Path(base_path) + index_path = base_path / "index" + data_path = base_path / "data" + + # Create majority style eval if needed + _, has_leakage = create_majority_style_eval(config, base_path) + if has_leakage: + print( + " Note: Majority control may show inflated accuracy due to train/test leakage" + ) + + # Determine output path + if preconditioner_name: + scores_path = base_path / f"scores_majority_{preconditioner_name}" + precond_path = base_path / preconditioner_name + else: + scores_path = base_path / "scores_majority_no_precond" + precond_path = None + + # Return cached if exists + if (scores_path / "scores.npy").exists(): + print(f"Loading cached scores from {scores_path}") + return np.load(scores_path / "scores.npy") + + scores_path.mkdir(parents=True, exist_ok=True) + + # Load train and eval datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval_majority_style.hf")) + + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + n_train = len(train_ds) + n_eval = len(eval_ds) + + print( + f"Scoring {n_eval} majority style eval queries against {n_train} train samples" + ) + + # Load train gradients + print("Loading train gradients...") + train_grads = load_gradients(index_path, structured=True) + + with open(index_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # 
Load preconditioner if specified + h_inv = {} + if precond_path and (precond_path / "preconditioners.pth").exists(): + print(f"Loading preconditioner from {precond_path}") + proc = GradientProcessor.load(precond_path) + device = torch.device("cuda:0") + for name in tqdm(module_names, desc="Computing H^(-1)"): + H = proc.preconditioners[name].to(device=device) + h_inv[name] = compute_damped_inverse(H) + + # Concatenate train gradients + print("Preparing train gradients...") + train_grad_list = [] + for name in tqdm(module_names, desc="Loading train grads"): + g = _load_gradients_as_float(train_grads, name) + train_grad_list.append(torch.from_numpy(g)) + train_grad_tensor = torch.cat(train_grad_list, dim=1) + + # Unit normalize train grads + train_norms = train_grad_tensor.norm(dim=1, keepdim=True) + train_grad_tensor = train_grad_tensor / (train_norms + 1e-8) + train_grad_tensor = train_grad_tensor.cuda() + + # Compute eval gradients for majority style + print("Computing majority style eval gradients...") + majority_eval_grads_path = base_path / "eval_grads_majority" + if not majority_eval_grads_path.exists(): + with open(index_path / "index_config.json") as f: + index_cfg = json.load(f) + + cmd = [ + "bergson", + "build", + str(majority_eval_grads_path), + "--model", + index_cfg["model"], + "--dataset", + str(data_path / "eval_majority_style.hf"), + "--drop_columns", + "False", + "--prompt_column", + "fact", + "--completion_column", + "reworded", + "--fsdp", + "--projection_dim", + str(index_cfg.get("projection_dim", 16)), + "--token_batch_size", + "6000", + "--skip_preconditioners", + ] + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError("bergson build for majority eval failed") + print(result.stdout) + + # Load eval gradients + eval_grads = load_gradients(majority_eval_grads_path, structured=True) + 
eval_grad_list = [] + for name in tqdm(module_names, desc="Loading eval grads"): + g = torch.from_numpy(_load_gradients_as_float(eval_grads, name)) + if h_inv: + g = (g.cuda() @ h_inv[name]).cpu() + eval_grad_list.append(g) + eval_grad_tensor = torch.cat(eval_grad_list, dim=1) + + # Unit normalize eval grads + eval_norms = eval_grad_tensor.norm(dim=1, keepdim=True) + eval_grad_tensor = eval_grad_tensor / (eval_norms + 1e-8) + eval_grad_tensor = eval_grad_tensor.cuda() + + # Compute scores + print("Computing scores...") + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + + # Save scores + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + + return scores + + +def compute_majority_style_metrics( + config: AsymmetricConfig, + base_path: Path | str, + preconditioner_name: str | None = None, +) -> AsymmetricMetrics: + """Compute metrics for majority style eval (control). + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + preconditioner_name: Name of preconditioner to use. + + Returns: + AsymmetricMetrics dataclass. 
+ """ + base_path = Path(base_path) + data_path = base_path / "data" + + # Create majority style eval if needed + _, _ = create_majority_style_eval(config, base_path) + + # Load datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval_majority_style.hf")) + + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + # Load scores + scores = score_majority_style_eval(config, base_path, preconditioner_name) + + n_eval = len(eval_ds) + + # Extract metadata + train_styles = train_ds["style"] + train_identifiers = train_ds["identifier"] + train_fields = train_ds["field"] + + eval_identifiers = eval_ds["identifier"] + eval_fields = eval_ds["field"] + + # Get top-k indices for each query + top_k = 10 + top_indices = np.argsort(-scores, axis=1)[:, :top_k] + + # Compute metrics (same logic as compute_asymmetric_metrics) + semantic_top1 = 0 + semantic_top5 = 0 + semantic_top10 = 0 + style_leak_top1 = 0 + style_leak_top5 = 0 + style_leak_top10 = 0 + subject_top1 = 0 + field_top1 = 0 + + for i in range(n_eval): + query_identifier = eval_identifiers[i] + query_field = eval_fields[i] + + top_k_idx = top_indices[i] + + # Check semantic matching (same identifier AND field = same underlying fact) + for k, idx in enumerate(top_k_idx): + if ( + train_identifiers[idx] == query_identifier + and train_fields[idx] == query_field + ): + if k == 0: + semantic_top1 += 1 + if k < 5: + semantic_top5 += 1 + break + if k < 10: + semantic_top10 += 1 + break + + # Check style leakage - for majority style, "leakage" means NOT matching dominant + # We flip the interpretation: matching minority style would be leakage + top1_style = train_styles[top_k_idx[0]] + if top1_style == config.minority_style: + style_leak_top1 += 1 + + top5_minority = sum( + 1 for idx in top_k_idx[:5] if train_styles[idx] == config.minority_style + ) + style_leak_top5 += top5_minority / 
5 + + top10_minority = sum( + 1 for idx in top_k_idx[:10] if train_styles[idx] == config.minority_style + ) + style_leak_top10 += top10_minority / 10 + + # Check attribute matching for top-1 + top1_idx = top_k_idx[0] + if train_identifiers[top1_idx] == query_identifier: + subject_top1 += 1 + if train_fields[top1_idx] == query_field: + field_top1 += 1 + + return AsymmetricMetrics( + top1_semantic_accuracy=semantic_top1 / n_eval, + top5_semantic_recall=semantic_top5 / n_eval, + top10_semantic_recall=semantic_top10 / n_eval, + top1_style_leakage=style_leak_top1 / n_eval, + top5_style_leakage=style_leak_top5 / n_eval, + top10_style_leakage=style_leak_top10 / n_eval, + top1_subject_accuracy=subject_top1 / n_eval, + top1_field_accuracy=field_top1 / n_eval, + ) + + +def score_summed_eval( + config: AsymmetricConfig, + base_path: Path | str, + preconditioner_name: str | None = None, +) -> np.ndarray: + """Score using summed eval gradients (minority + majority style for each fact). + + Instead of using minority-style eval gradients alone, this sums the gradients + from both style versions of each fact. This makes the query "style-neutral" + since style-specific components should cancel while semantic components reinforce. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + preconditioner_name: Name of preconditioner subdirectory (None for no precond). + + Returns: + Score matrix of shape (n_eval, n_train). 
+ """ + import json + import subprocess + + import torch + from tqdm import tqdm + + from bergson.data import load_gradients + from bergson.gradients import GradientProcessor + from bergson.utils.math import compute_damped_inverse + + base_path = Path(base_path) + index_path = base_path / "index" + data_path = base_path / "data" + + # Determine output path + if preconditioner_name: + scores_path = base_path / f"scores_summed_eval_{preconditioner_name}" + precond_path = base_path / preconditioner_name + else: + scores_path = base_path / "scores_summed_eval_no_precond" + precond_path = None + + # Return cached if exists + if (scores_path / "scores.npy").exists(): + print(f"Loading cached scores from {scores_path}") + return np.load(scores_path / "scores.npy") + + scores_path.mkdir(parents=True, exist_ok=True) + + # Load train dataset + train_ds = load_from_disk(str(data_path / "train.hf")) + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + n_train = len(train_ds) + + # Load eval datasets (need both minority and majority style versions) + eval_minority_ds = load_from_disk(str(data_path / "eval.hf")) + if isinstance(eval_minority_ds, DatasetDict): + eval_minority_ds = eval_minority_ds["train"] + + # Create majority style eval if needed + _, _ = create_majority_style_eval(config, base_path) + eval_majority_ds = load_from_disk(str(data_path / "eval_majority_style.hf")) + if isinstance(eval_majority_ds, DatasetDict): + eval_majority_ds = eval_majority_ds["train"] + + n_eval = len(eval_minority_ds) + print( + f"Scoring {n_eval} summed eval queries (minority + majority) against {n_train} train samples" + ) + + # Build semantic fact mapping for alignment (identifier, field pairs) + # This works even when templates differ between minority and majority eval + minority_semantic_facts = [ + (row["identifier"], row["field"]) for row in eval_minority_ds + ] + majority_semantic_to_idx = { + (row["identifier"], row["field"]): i for i, row in 
enumerate(eval_majority_ds) + } + + # Verify alignment by semantic fact + assert len(eval_minority_ds) == len( + eval_majority_ds + ), "Eval datasets must have same size" + for sf in minority_semantic_facts: + assert ( + sf in majority_semantic_to_idx + ), f"Semantic fact {sf} not found in majority eval" + + # Load train gradients + print("Loading train gradients...") + train_grads = load_gradients(index_path, structured=True) + + with open(index_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # Load preconditioner if specified + h_inv = {} + if precond_path and (precond_path / "preconditioners.pth").exists(): + print(f"Loading preconditioner from {precond_path}") + proc = GradientProcessor.load(precond_path) + device = torch.device("cuda:0") + for name in tqdm(module_names, desc="Computing H^(-1)"): + H = proc.preconditioners[name].to(device=device) + h_inv[name] = compute_damped_inverse(H) + + # Concatenate train gradients + print("Preparing train gradients...") + train_grad_list = [] + for name in tqdm(module_names, desc="Loading train grads"): + g = _load_gradients_as_float(train_grads, name) + train_grad_list.append(torch.from_numpy(g)) + train_grad_tensor = torch.cat(train_grad_list, dim=1) + + # Unit normalize train grads + train_norms = train_grad_tensor.norm(dim=1, keepdim=True) + train_grad_tensor = train_grad_tensor / (train_norms + 1e-8) + train_grad_tensor = train_grad_tensor.cuda() + + # Ensure both eval gradient sets exist + eval_minority_grads_path = base_path / "eval_grads" + eval_majority_grads_path = base_path / "eval_grads_majority" + + with open(index_path / "index_config.json") as f: + index_cfg = json.load(f) + + # Build minority eval grads if needed + if not eval_minority_grads_path.exists(): + print("Computing minority style eval gradients...") + cmd = [ + "bergson", + "build", + str(eval_minority_grads_path), + "--model", + index_cfg["model"], + "--dataset", + str(data_path / "eval.hf"), + 
"--drop_columns", + "False", + "--prompt_column", + "fact", + "--completion_column", + "reworded", + "--fsdp", + "--projection_dim", + str(index_cfg.get("projection_dim", 16)), + "--token_batch_size", + "6000", + "--skip_preconditioners", + ] + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError("bergson build for minority eval failed") + print(result.stdout) + + # Build majority eval grads if needed + if not eval_majority_grads_path.exists(): + print("Computing majority style eval gradients...") + cmd = [ + "bergson", + "build", + str(eval_majority_grads_path), + "--model", + index_cfg["model"], + "--dataset", + str(data_path / "eval_majority_style.hf"), + "--drop_columns", + "False", + "--prompt_column", + "fact", + "--completion_column", + "reworded", + "--fsdp", + "--projection_dim", + str(index_cfg.get("projection_dim", 16)), + "--token_batch_size", + "6000", + "--skip_preconditioners", + ] + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError("bergson build for majority eval failed") + print(result.stdout) + + # Load both eval gradient sets + print("Loading eval gradients (minority + majority)...") + minority_grads = load_gradients(eval_minority_grads_path, structured=True) + majority_grads = load_gradients(eval_majority_grads_path, structured=True) + + # Sum gradients: for each eval fact, sum minority + majority style gradients + # Align by semantic fact (identifier, field) since templates may differ + summed_grad_list = [] + for name in tqdm(module_names, desc="Summing eval grads"): + g_minority = torch.from_numpy(_load_gradients_as_float(minority_grads, name)) + g_majority = torch.from_numpy(_load_gradients_as_float(majority_grads, name)) 
+ + # Align majority grads to minority semantic fact order + aligned_majority_indices = [ + majority_semantic_to_idx[sf] for sf in minority_semantic_facts + ] + g_majority_aligned = g_majority[aligned_majority_indices] + + # Sum the gradients + g_summed = g_minority + g_majority_aligned + + # Apply preconditioning if specified + if h_inv: + g_summed = (g_summed.cuda() @ h_inv[name]).cpu() + + summed_grad_list.append(g_summed) + + eval_grad_tensor = torch.cat(summed_grad_list, dim=1) + + # Unit normalize summed eval grads + eval_norms = eval_grad_tensor.norm(dim=1, keepdim=True) + eval_grad_tensor = eval_grad_tensor / (eval_norms + 1e-8) + eval_grad_tensor = eval_grad_tensor.cuda() + + # Compute scores + print("Computing scores...") + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + + # Save scores + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + + return scores + + +def compute_summed_eval_metrics( + config: AsymmetricConfig, + base_path: Path | str, + preconditioner_name: str | None = None, +) -> AsymmetricMetrics: + """Compute metrics for summed eval gradient approach. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + preconditioner_name: Name of preconditioner to use. + + Returns: + AsymmetricMetrics dataclass. 
+ """ + base_path = Path(base_path) + data_path = base_path / "data" + + # Load datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval.hf")) + + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + # Load scores (computed with summed gradients) + scores = score_summed_eval(config, base_path, preconditioner_name) + + n_eval = len(eval_ds) + + # Extract metadata + train_styles = train_ds["style"] + train_identifiers = train_ds["identifier"] + train_fields = train_ds["field"] + + eval_identifiers = eval_ds["identifier"] + eval_fields = eval_ds["field"] + + # Get top-k indices for each query + top_k = 10 + top_indices = np.argsort(-scores, axis=1)[:, :top_k] + + # Compute metrics + semantic_top1 = 0 + semantic_top5 = 0 + semantic_top10 = 0 + style_leak_top1 = 0 + style_leak_top5 = 0 + style_leak_top10 = 0 + subject_top1 = 0 + field_top1 = 0 + + for i in range(n_eval): + query_identifier = eval_identifiers[i] + query_field = eval_fields[i] + + top_k_idx = top_indices[i] + + # Check semantic matching (same identifier AND field = same underlying fact) + for k, idx in enumerate(top_k_idx): + if ( + train_identifiers[idx] == query_identifier + and train_fields[idx] == query_field + ): + if k == 0: + semantic_top1 += 1 + if k < 5: + semantic_top5 += 1 + break + if k < 10: + semantic_top10 += 1 + break + + # Check style leakage (still measured against minority style) + top1_style = train_styles[top_k_idx[0]] + if top1_style == config.minority_style: + style_leak_top1 += 1 + + top5_minority = sum( + 1 for idx in top_k_idx[:5] if train_styles[idx] == config.minority_style + ) + style_leak_top5 += top5_minority / 5 + + top10_minority = sum( + 1 for idx in top_k_idx[:10] if train_styles[idx] == config.minority_style + ) + style_leak_top10 += top10_minority / 10 + + # Check attribute matching for top-1 + top1_idx = top_k_idx[0] + if 
train_identifiers[top1_idx] == query_identifier: + subject_top1 += 1 + if train_fields[top1_idx] == query_field: + field_top1 += 1 + + return AsymmetricMetrics( + top1_semantic_accuracy=semantic_top1 / n_eval, + top5_semantic_recall=semantic_top5 / n_eval, + top10_semantic_recall=semantic_top10 / n_eval, + top1_style_leakage=style_leak_top1 / n_eval, + top5_style_leakage=style_leak_top5 / n_eval, + top10_style_leakage=style_leak_top10 / n_eval, + top1_subject_accuracy=subject_top1 / n_eval, + top1_field_accuracy=field_top1 / n_eval, + ) + + +def sweep_pca_k( + config: AsymmetricConfig | None = None, + base_path: Path | str = "runs/asymmetric_style", + k_values: list[int] | None = None, + preconditioners: list[str | None] | None = None, +) -> dict[str, AsymmetricMetrics]: + """Sweep over k values and preconditioners for PCA projection approach. + + Args: + config: Experiment configuration (uses defaults if None). + base_path: Base path for experiment outputs. + k_values: List of k values to test (default: [1, 5, 10, 20, 50, 100]). + preconditioners: List of preconditioner names to combine with PCA. + None means no preconditioner. Default: [None, "index"]. + + Returns: + Dictionary mapping strategy names to their metrics. + """ + from .preconditioners import compute_pca_style_subspace + + if config is None: + config = AsymmetricConfig() + + base_path = Path(base_path) + + if k_values is None: + k_values = [1, 5, 10, 20, 50, 100] + + if preconditioners is None: + preconditioners = [ + None, + "index", + ] # None = no precond, "index" = train second moment + + # Check that style indices exist + pirate_idx = Path("runs/precond_comparison/pirate") + shakespeare_idx = Path("runs/precond_comparison/shakespeare") + if not (pirate_idx.exists() and shakespeare_idx.exists()): + raise FileNotFoundError( + f"Style-specific indices not found at {pirate_idx} and {shakespeare_idx}. " + "Run build_style_indices() first." 
+ ) + + all_metrics: dict[str, AsymmetricMetrics] = {} + + # Compute style subspaces for each k value + print("=" * 70) + print("PCA K-VALUE AND PRECONDITIONER SWEEP") + print("=" * 70) + + for k in k_values: + print(f"\n--- Computing style subspace for k={k} ---") + style_subspace = compute_pca_style_subspace( + pirate_idx, shakespeare_idx, base_path / "pca_subspace", top_k=k + ) + + for precond_name in preconditioners: + precond_display = precond_name if precond_name else "no_precond" + strategy_name = f"pca_k{k}_{precond_display}" + + print(f"\n--- Strategy: {strategy_name} ---") + metrics = compute_asymmetric_metrics_with_pca( + config, + base_path, + style_subspace, + top_k=k, + preconditioner_name=precond_name, + ) + print(f" Top-1 Semantic: {metrics.top1_semantic_accuracy:.2%}") + print(f" Top-1 Style Leak: {metrics.top1_style_leakage:.2%}") + all_metrics[strategy_name] = metrics + + # Print summary + print("\n" + "=" * 70) + print("SWEEP SUMMARY") + print("=" * 70) + + print(f"\n{'Strategy':<30} {'Top-1 Semantic':<15} {'Top-1 Style Leak':<17}") + print("-" * 65) + + for name, m in sorted(all_metrics.items()): + print( + f"{name:<30} {m.top1_semantic_accuracy:<15.2%} {m.top1_style_leakage:<17.2%}" + ) + + return all_metrics + + +def run_asymmetric_experiment( + config: AsymmetricConfig | None = None, + base_path: Path | str = "runs/asymmetric_style", + analysis_model: str | None = None, + include_pca: bool = True, + pca_top_k: int = 10, + include_summed_loss: bool = True, + include_second_moments: bool = True, + include_majority_control: bool = True, + include_summed_eval: bool = True, + include_semantic_eval: bool = True, + damping_factor: float = 0.1, +) -> dict[str, AsymmetricMetrics]: + """Run the full asymmetric style suppression experiment. + + Compares attribution quality with and without style suppression preconditioning. + + Args: + config: Experiment configuration (uses defaults if None). 
Set config.hf_dataset + to load data from HuggingFace instead of generating locally. + base_path: Base path for experiment outputs. + analysis_model: Model to use for gradient collection. Defaults to HF_ANALYSIS_MODEL. + include_pca: Whether to include PCA projection strategy. + pca_top_k: Number of principal components for PCA projection. + include_summed_loss: Whether to include summed loss preconditioner strategy. + include_second_moments: Whether to include train/eval/mixed second moment strategies. + include_majority_control: Whether to include majority style eval as control. + include_summed_eval: Whether to include summed eval gradient approach (minority + majority). + include_semantic_eval: Whether to include semantic-only eval using question/answer columns. + This tests attribution when gradients only come from the semantic content (answer tokens), + ignoring style in the eval query entirely. + damping_factor: Damping factor for matrix inversion (default: 0.1). + + Returns: + Dictionary mapping preconditioner names to their metrics. 
+ """ + from .preconditioners import ( + compute_eval_preconditioner, + compute_pca_style_subspace, + compute_summed_loss_preconditioner, + compute_train_eval_mixed_preconditioner, + ) + + if config is None: + config = AsymmetricConfig() + + base_path = Path(base_path) + + print("=" * 70) + print("ASYMMETRIC STYLE SUPPRESSION EXPERIMENT") + print("=" * 70) + print("\nConfiguration:") + print(f" Dominant style: {config.dominant_style} ({config.dominant_ratio:.0%})") + print(f" Minority style: {config.minority_style}") + print(f" Exclusive facts: {config.exclusive_ratio:.0%}") + + # Step 1: Create dataset and index + print("\n" + "-" * 60) + print("STEP 1: Creating asymmetric dataset and index") + print("-" * 60) + create_asymmetric_index(config, base_path, analysis_model) + + # Step 2: Compute R_between preconditioner + print("\n" + "-" * 60) + print("STEP 2: Computing style suppression preconditioner (R_between)") + print("-" * 60) + compute_style_preconditioner(base_path, config) + + # Step 2b: Compute summed loss preconditioner if requested + summed_loss_proc = None + if include_summed_loss: + print("\n" + "-" * 60) + print("STEP 2b: Computing summed loss preconditioner") + print("-" * 60) + # We need the style-specific indices for this + # Use the runs/precond_comparison indices if they exist + pirate_idx = Path("runs/precond_comparison/pirate") + shakespeare_idx = Path("runs/precond_comparison/shakespeare") + if pirate_idx.exists() and shakespeare_idx.exists(): + summed_loss_path = base_path / "summed_loss" + compute_summed_loss_preconditioner( + pirate_idx, shakespeare_idx, summed_loss_path + ) + else: + print( + " Style-specific indices not found, skipping summed loss preconditioner" + ) + print(f" (Expected: {pirate_idx} and {shakespeare_idx})") + include_summed_loss = False + + # Step 2c: Compute PCA style subspace if requested + style_subspace = None + if include_pca: + print("\n" + "-" * 60) + print(f"STEP 2c: Computing PCA style subspace 
(top_k={pca_top_k})") + print("-" * 60) + pirate_idx = Path("runs/precond_comparison/pirate") + shakespeare_idx = Path("runs/precond_comparison/shakespeare") + if pirate_idx.exists() and shakespeare_idx.exists(): + style_subspace = compute_pca_style_subspace( + pirate_idx, shakespeare_idx, base_path / "pca_subspace", top_k=pca_top_k + ) + else: + print(" Style-specific indices not found, skipping PCA projection") + print(f" (Expected: {pirate_idx} and {shakespeare_idx})") + include_pca = False + + # Step 2d: Compute second moment preconditioners if requested + if include_second_moments: + print("\n" + "-" * 60) + print("STEP 2d: Computing second moment preconditioners (train/eval/mixed)") + print("-" * 60) + + index_path = base_path / "index" + eval_grads_path = base_path / "eval_grads" + + # Note: train second moment is already computed during index build + # We just need to use it directly from the index + + # Compute eval second moment + if eval_grads_path.exists(): + compute_eval_preconditioner( + eval_grads_path, + base_path / "eval_second_moment", + reference_proc_path=index_path, # Use train index for metadata + ) + + # Compute 50:50 train-eval mixed + compute_train_eval_mixed_preconditioner( + index_path, + eval_grads_path, + base_path / "train_eval_mixed", + train_weight=0.5, + ) + else: + print(" Eval grads not found, will compute during scoring") + include_second_moments = False + + # Step 3: Score and evaluate with each strategy + print("\n" + "-" * 60) + print("STEP 3: Evaluating preconditioner strategies") + print("-" * 60) + + # Basic strategies using matrix-inverse preconditioning + strategies = [ + (None, "no_precond"), + ("r_between", "r_between"), + ] + + # Add summed loss if available + if include_summed_loss: + strategies.append(("summed_loss", "summed_loss")) + + # Add second moment strategies if available + if include_second_moments: + # Train second moment (use the index's preconditioner directly) + strategies.append(("index", 
"train_second_moment")) + # Eval second moment + strategies.append(("eval_second_moment", "eval_second_moment")) + # 50:50 train-eval mixed + strategies.append(("train_eval_mixed", "train_eval_mixed")) + + all_metrics: dict[str, AsymmetricMetrics] = {} + + for precond_name, display_name in strategies: + print(f"\n--- Strategy: {display_name} ---") + metrics = compute_asymmetric_metrics( + config, base_path, precond_name, damping_factor=damping_factor + ) + print_metrics(metrics, display_name) + all_metrics[display_name] = metrics + + # Evaluate PCA projection strategy (different approach - not matrix-inverse) + if include_pca and style_subspace is not None: + print(f"\n--- Strategy: pca_projection_k{pca_top_k} ---") + metrics = compute_asymmetric_metrics_with_pca( + config, + base_path, + style_subspace, + top_k=pca_top_k, + damping_factor=damping_factor, + ) + print_metrics(metrics, f"pca_projection_k{pca_top_k}") + all_metrics[f"pca_projection_k{pca_top_k}"] = metrics + + # Evaluate majority style control (no style mismatch) + if include_majority_control: + print("\n" + "-" * 60) + print("MAJORITY STYLE CONTROL (no style mismatch)") + print("-" * 60) + print("\n--- Control: majority_style_no_precond ---") + metrics = compute_majority_style_metrics(config, base_path, None) + print_metrics(metrics, "majority_no_precond") + all_metrics["majority_no_precond"] = metrics + + # Evaluate summed eval gradient approach (minority + majority style) + if include_summed_eval: + print("\n" + "-" * 60) + print("SUMMED EVAL GRADIENTS (minority + majority style)") + print("-" * 60) + print("\n--- Strategy: summed_eval_no_precond ---") + metrics = compute_summed_eval_metrics(config, base_path, None) + print_metrics(metrics, "summed_eval") + all_metrics["summed_eval"] = metrics + + # Evaluate semantic-only approach (question/answer columns - gradients only from answer) + if include_semantic_eval: + print("\n" + "-" * 60) + print("SEMANTIC-ONLY EVAL (gradients only from answer 
tokens)") + print("-" * 60) + + # Standard influence function approach: semantic mask + H_train preconditioner + # This is the "correct" way to compute influence functions + print("\n--- Strategy: semantic_index (standard IF with H_train) ---") + metrics = compute_asymmetric_metrics( + config, + base_path, + "index", # H_train - the standard IF preconditioner + damping_factor=damping_factor, + eval_prompt_column="question", + eval_completion_column="answer", + ) + print_metrics(metrics, "semantic_index") + all_metrics["semantic_index"] = metrics + + print("\n--- Strategy: semantic_no_precond ---") + metrics = compute_asymmetric_metrics( + config, + base_path, + None, + damping_factor=damping_factor, + eval_prompt_column="question", + eval_completion_column="answer", + ) + print_metrics(metrics, "semantic_no_precond") + all_metrics["semantic_no_precond"] = metrics + + # Also try with r_between preconditioner + print("\n--- Strategy: semantic_r_between ---") + metrics = compute_asymmetric_metrics( + config, + base_path, + "r_between", + damping_factor=damping_factor, + eval_prompt_column="question", + eval_completion_column="answer", + ) + print_metrics(metrics, "semantic_r_between") + all_metrics["semantic_r_between"] = metrics + + # Print summary comparison + print("\n" + "=" * 70) + print("SUMMARY COMPARISON") + print("=" * 70) + + print(f"\n{'Strategy':<25} {'Top-1 Semantic':<15} {'Top-1 Style Leak':<17}") + print("-" * 60) + + for name, m in all_metrics.items(): + print( + f"{name:<25} {m.top1_semantic_accuracy:<15.2%} {m.top1_style_leakage:<17.2%}" + ) + + print("\n" + "=" * 70) + print("INTERPRETATION") + print("=" * 70) + print("\nSuccess criteria:") + print(" - Higher semantic accuracy = preconditioner helps find correct facts") + print(" - Lower style leakage = preconditioner reduces style matching") + print("\nStrategies:") + print(" - no_precond: Baseline without any style suppression") + print(" - r_between: Rank-1 preconditioner from style mean 
difference") + if include_summed_loss: + print(" - summed_loss: Preconditioner from summed gradients across pairs") + if include_second_moments: + print(" - train_second_moment: Second moment matrix from train gradients") + print(" - eval_second_moment: Second moment matrix from eval gradients") + print(" - train_eval_mixed: 50:50 mixture of train and eval second moments") + if include_pca: + print(f" - pca_projection_k{pca_top_k}: Project out top-{pca_top_k} style PCs") + if include_majority_control: + print( + " - majority_no_precond: Control using majority style for eval (no mismatch)" + ) + if include_summed_eval: + print( + " - summed_eval: Sum minority + majority style eval gradients (style-neutral query)" + ) + if include_semantic_eval: + print( + " - semantic_*: Eval gradients only from answer tokens (question/answer format)" + ) + print( + " Tests if attribution works when query has no style information at all" + ) + + return all_metrics + + +# ============================================================================= +# RAW INNER PRODUCT COMPARISON +# ============================================================================= +# Compare cosine similarity vs raw inner product scoring + + +def score_with_inner_product( + config: AsymmetricConfig, + base_path: Path | str, + eval_style: str = "minority", + preconditioner_name: str | None = None, +) -> np.ndarray: + """Score using raw inner product (no unit normalization). + + This matches bergson's default behavior where unit_normalize=False. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + eval_style: Which eval set to use ("minority", "majority", "summed"). + preconditioner_name: Name of preconditioner subdirectory. + + Returns: + Score matrix of shape (n_eval, n_train). 
+ """ + import json + + import torch + from tqdm import tqdm + + from bergson.data import load_gradients + from bergson.gradients import GradientProcessor + from bergson.utils.math import compute_damped_inverse + + base_path = Path(base_path) + index_path = base_path / "index" + data_path = base_path / "data" + + # Determine output path + suffix = f"_innerproduct_{eval_style}" + if preconditioner_name: + scores_path = base_path / f"scores{suffix}_{preconditioner_name}" + precond_path = base_path / preconditioner_name + else: + scores_path = base_path / f"scores{suffix}_no_precond" + precond_path = None + + # Return cached if exists + if (scores_path / "scores.npy").exists(): + print(f"Loading cached scores from {scores_path}") + return np.load(scores_path / "scores.npy") + + scores_path.mkdir(parents=True, exist_ok=True) + + # Load train gradients + print("Loading train gradients...") + train_grads = load_gradients(index_path, structured=True) + + with open(index_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # Load preconditioner if specified + h_inv = {} + if precond_path and (precond_path / "preconditioners.pth").exists(): + print(f"Loading preconditioner from {precond_path}") + proc = GradientProcessor.load(precond_path) + device = torch.device("cuda:0") + for name in tqdm(module_names, desc="Computing H^(-1)"): + H = proc.preconditioners[name].to(device=device) + h_inv[name] = compute_damped_inverse(H) + + # Concatenate train gradients - NO NORMALIZATION + print("Preparing train gradients (no normalization)...") + train_grad_list = [] + for name in tqdm(module_names, desc="Loading train grads"): + g = _load_gradients_as_float(train_grads, name) + train_grad_list.append(torch.from_numpy(g)) + train_grad_tensor = torch.cat(train_grad_list, dim=1).cuda() + + # Load eval gradients based on style + if eval_style == "minority": + eval_grads_path = base_path / "eval_grads" + elif eval_style == "majority": + eval_grads_path 
= base_path / "eval_grads_majority" + elif eval_style == "summed": + # Need to sum minority + majority + minority_grads = load_gradients(base_path / "eval_grads", structured=True) + majority_grads = load_gradients( + base_path / "eval_grads_majority", structured=True + ) + + # Load eval datasets for alignment + eval_minority_ds = load_from_disk(str(data_path / "eval.hf")) + eval_majority_ds = load_from_disk(str(data_path / "eval_majority_style.hf")) + if isinstance(eval_minority_ds, DatasetDict): + eval_minority_ds = eval_minority_ds["train"] + if isinstance(eval_majority_ds, DatasetDict): + eval_majority_ds = eval_majority_ds["train"] + + # Use semantic fact alignment (identifier, field) since templates may differ + minority_semantic_facts = [ + (row["identifier"], row["field"]) for row in eval_minority_ds + ] + majority_semantic_to_idx = { + (row["identifier"], row["field"]): i + for i, row in enumerate(eval_majority_ds) + } + + summed_grad_list = [] + for name in tqdm(module_names, desc="Summing eval grads"): + g_minority = torch.from_numpy( + _load_gradients_as_float(minority_grads, name) + ) + g_majority = torch.from_numpy( + _load_gradients_as_float(majority_grads, name) + ) + + aligned_majority_indices = [ + majority_semantic_to_idx[sf] for sf in minority_semantic_facts + ] + g_majority_aligned = g_majority[aligned_majority_indices] + + g_summed = g_minority + g_majority_aligned + + if h_inv: + g_summed = (g_summed.cuda() @ h_inv[name]).cpu() + + summed_grad_list.append(g_summed) + + eval_grad_tensor = torch.cat(summed_grad_list, dim=1).cuda() + + # NO NORMALIZATION - raw inner product + print("Computing scores (raw inner product)...") + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + return scores + else: + raise ValueError(f"Unknown eval_style: {eval_style}") + + # Load eval gradients + print(f"Loading {eval_style} eval gradients...") + eval_grads = 
load_gradients(eval_grads_path, structured=True) + + eval_grad_list = [] + for name in tqdm(module_names, desc="Loading eval grads"): + g = torch.from_numpy(_load_gradients_as_float(eval_grads, name)) + if h_inv: + g = (g.cuda() @ h_inv[name]).cpu() + eval_grad_list.append(g) + + eval_grad_tensor = torch.cat(eval_grad_list, dim=1).cuda() + + # NO NORMALIZATION - raw inner product + print("Computing scores (raw inner product)...") + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + + return scores + + +def run_inner_product_comparison( + config: AsymmetricConfig | None = None, + base_path: Path | str = "runs/asymmetric_style", +) -> dict[str, "AsymmetricMetrics"]: + """Compare key strategies using raw inner product instead of cosine similarity. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + + Returns: + Dictionary mapping strategy names to their metrics. 
+ """ + if config is None: + config = AsymmetricConfig() + + base_path = Path(base_path) + data_path = base_path / "data" + + print("=" * 70) + print("INNER PRODUCT VS COSINE SIMILARITY COMPARISON") + print("=" * 70) + print("\nRunning key strategies with raw inner product (bergson default)") + print() + + # Load datasets for metrics computation + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval.hf")) + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + n_eval = len(eval_ds) + train_styles = train_ds["style"] + train_identifiers = train_ds["identifier"] + train_fields = train_ds["field"] + eval_identifiers = eval_ds["identifier"] + eval_fields = eval_ds["field"] + + def compute_metrics_from_scores(scores): + top_indices = np.argsort(-scores, axis=1)[:, :10] + + sem_top1 = sem_top5 = leak_top1 = 0 + + for i in range(n_eval): + top_k_idx = top_indices[i] + + # Check semantic matching (same identifier AND field = same underlying fact) + for k, idx in enumerate(top_k_idx): + if ( + train_identifiers[idx] == eval_identifiers[i] + and train_fields[idx] == eval_fields[i] + ): + if k == 0: + sem_top1 += 1 + if k < 5: + sem_top5 += 1 + break + + if train_styles[top_k_idx[0]] == config.minority_style: + leak_top1 += 1 + + return AsymmetricMetrics( + top1_semantic_accuracy=sem_top1 / n_eval, + top5_semantic_recall=sem_top5 / n_eval, + top10_semantic_recall=0, # Not computed + top1_style_leakage=leak_top1 / n_eval, + top5_style_leakage=0, + top10_style_leakage=0, + top1_subject_accuracy=0, + top1_field_accuracy=0, + ) + + all_metrics = {} + + strategies = [ + ("minority_no_precond", "minority", None), + ("majority_no_precond", "majority", None), + ("summed_no_precond", "summed", None), + ("minority_index", "minority", "index"), + ] + + for name, eval_style, precond in strategies: + print(f"\n--- Strategy: {name} (inner product) ---") 
+ scores = score_with_inner_product(config, base_path, eval_style, precond) + metrics = compute_metrics_from_scores(scores) + print(f" Top-1 Semantic: {metrics.top1_semantic_accuracy:.2%}") + print(f" Top-1 Style Leak: {metrics.top1_style_leakage:.2%}") + all_metrics[f"ip_{name}"] = metrics + + # Print comparison summary + print("\n" + "=" * 70) + print("COMPARISON: Inner Product vs Cosine Similarity") + print("=" * 70) + + print(f"\n{'Strategy':<30} {'Cosine Top-1':<15} {'InnerProd Top-1':<15}") + print("-" * 60) + + # Load cosine results for comparison + cosine_results = { + "minority_no_precond": 0.87, + "majority_no_precond": 100.0, + "summed_no_precond": 92.71, + "minority_index": 1.04, + } + + for name, _, _ in strategies: + cosine = cosine_results.get(name, 0) + ip = all_metrics[f"ip_{name}"].top1_semantic_accuracy * 100 + print(f"{name:<30} {cosine:<15.2f}% {ip:<15.2f}%") + + return all_metrics + + +# ============================================================================= +# REWRITE ABLATION EXPERIMENT +# ============================================================================= +# Test: what if we sum two rewrite styles (shakespeare + pirate) that are both +# different from training? This tests whether summed_eval works because of +# general style cancellation or because one component matches training. + + +def create_original_style_eval( + config: AsymmetricConfig, + base_path: Path | str, +) -> Path: + """Create eval set using original un-stylized facts. + + Creates a dataset where both prompt and completion use the original fact text + (no stylization). This represents the "true" eval data before any rewriting. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + + Returns: + Path to the original style eval dataset. 
+ """ + base_path = Path(base_path) + data_path = base_path / "data" + original_eval_path = data_path / "eval_original_style.hf" + + if original_eval_path.exists(): + print(f"Loading cached original style eval from {original_eval_path}") + return original_eval_path + + print("Creating original style eval set...") + + # Load the minority style eval to get the facts we need + eval_ds = load_from_disk(str(data_path / "eval.hf")) + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + eval_facts = list(eval_ds["fact"]) + + # Load original facts dataset to get metadata + original = load_from_disk("data/facts_dataset.hf") + if isinstance(original, DatasetDict): + original = original["train"] + fact_to_row = {row["fact"]: row for row in original} + + # Build original style eval dataset (fact = reworded = original text) + rows = [] + for fact in eval_facts: + if fact not in fact_to_row: + print(f"Warning: fact not found in original dataset: {fact[:50]}...") + continue + row = dict(fact_to_row[fact]) + row["reworded"] = fact # Use original fact as "reworded" too + row["style"] = "original" + row["expected_match_style"] = config.dominant_style + rows.append(row) + + original_eval_ds = Dataset.from_list(rows) + original_eval_ds.save_to_disk(str(original_eval_path)) + print( + f"Saved original style eval ({len(original_eval_ds)} samples) to {original_eval_path}" + ) + + return original_eval_path + + +def create_pirate_style_eval( + config: AsymmetricConfig, + base_path: Path | str, +) -> Path: + """Create eval set using pirate style for the exclusive facts. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + + Returns: + Path to the pirate style eval dataset. 
+ """ + base_path = Path(base_path) + data_path = base_path / "data" + pirate_eval_path = data_path / "eval_pirate_style.hf" + + if pirate_eval_path.exists(): + print(f"Loading cached pirate style eval from {pirate_eval_path}") + return pirate_eval_path + + print("Creating pirate style eval set...") + + # Load the minority style eval to get the facts we need + eval_ds = load_from_disk(str(data_path / "eval.hf")) + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + eval_facts = set(eval_ds["fact"]) + + # Load pirate dataset + pirate_ds = load_from_disk("data/facts_dataset_pirate-Qwen3-8B-Base.hf") + if isinstance(pirate_ds, DatasetDict): + pirate_ds = pirate_ds["train"] + + # Add back metadata columns from original + original = load_from_disk("data/facts_dataset.hf") + if isinstance(original, DatasetDict): + original = original["train"] + fact_to_meta = {row["fact"]: row for row in original} + + for col in original.column_names: + if col not in pirate_ds.column_names: + restored_col = [fact_to_meta[row["fact"]][col] for row in pirate_ds] + pirate_ds = pirate_ds.add_column(col, restored_col) + + # Select only the exclusive facts (same facts as in minority eval) + pirate_eval_indices = [ + i for i, row in enumerate(pirate_ds) if row["fact"] in eval_facts + ] + pirate_eval_ds = pirate_ds.select(pirate_eval_indices) + + # Add style columns + pirate_eval_ds = pirate_eval_ds.add_column( + "style", ["pirate"] * len(pirate_eval_ds) + ) + pirate_eval_ds = pirate_eval_ds.add_column( + "expected_match_style", [config.dominant_style] * len(pirate_eval_ds) + ) + + pirate_eval_ds.save_to_disk(str(pirate_eval_path)) + print( + f"Saved pirate style eval ({len(pirate_eval_ds)} samples) to {pirate_eval_path}" + ) + + return pirate_eval_path + + +def score_summed_rewrites( + config: AsymmetricConfig, + base_path: Path | str, + preconditioner_name: str | None = None, +) -> np.ndarray: + """Score using summed rewrite gradients (shakespeare + pirate). 
def score_summed_rewrites(
    config: AsymmetricConfig,
    base_path: Path | str,
    preconditioner_name: str | None = None,
) -> np.ndarray:
    """Score using summed rewrite gradients (shakespeare + pirate).

    Tests whether summing two different rewrite styles helps with style
    invariance, even when neither rewrite matches the training distribution.

    Args:
        config: Experiment configuration.
        base_path: Base path for experiment outputs.
        preconditioner_name: Name of preconditioner subdirectory (None for no
            preconditioning).

    Returns:
        Score matrix of shape (n_eval, n_train).

    Raises:
        RuntimeError: If a `bergson build` subprocess fails.
    """
    import json
    import subprocess

    import torch
    from tqdm import tqdm

    from bergson.data import load_gradients
    from bergson.gradients import GradientProcessor
    from bergson.utils.math import compute_damped_inverse

    base_path = Path(base_path)
    index_path = base_path / "index"
    data_path = base_path / "data"

    # Output location depends on whether a preconditioner is applied.
    if preconditioner_name:
        scores_path = base_path / f"scores_summed_rewrites_{preconditioner_name}"
        precond_path = base_path / preconditioner_name
    else:
        scores_path = base_path / "scores_summed_rewrites_no_precond"
        precond_path = None

    # Return cached if exists
    if (scores_path / "scores.npy").exists():
        print(f"Loading cached scores from {scores_path}")
        return np.load(scores_path / "scores.npy")

    scores_path.mkdir(parents=True, exist_ok=True)

    # Load train dataset (only its size is needed here).
    train_ds = load_from_disk(str(data_path / "train.hf"))
    if isinstance(train_ds, DatasetDict):
        train_ds = train_ds["train"]
    n_train = len(train_ds)

    # Create shakespeare and pirate eval datasets if needed.
    # Shakespeare is already in eval.hf (minority style).
    create_pirate_style_eval(config, base_path)

    # Load eval datasets
    shakespeare_eval_ds = load_from_disk(str(data_path / "eval.hf"))
    if isinstance(shakespeare_eval_ds, DatasetDict):
        shakespeare_eval_ds = shakespeare_eval_ds["train"]

    pirate_eval_ds = load_from_disk(str(data_path / "eval_pirate_style.hf"))
    if isinstance(pirate_eval_ds, DatasetDict):
        pirate_eval_ds = pirate_eval_ds["train"]

    n_eval = len(shakespeare_eval_ds)
    print(
        f"Scoring {n_eval} summed rewrite queries "
        f"(shakespeare + pirate) against {n_train} train"
    )

    # Map each fact to its row in the pirate eval set so the two gradient
    # sets can be summed row-by-row in shakespeare order.
    shakespeare_facts = shakespeare_eval_ds["fact"]
    pirate_fact_to_idx = {f: i for i, f in enumerate(pirate_eval_ds["fact"])}

    # Verify alignment
    for f in shakespeare_facts:
        assert f in pirate_fact_to_idx, f"Fact {f} not found in pirate eval"

    # FIX: the alignment permutation is loop-invariant; the previous version
    # rebuilt it inside the per-module loop (O(modules * n_eval)).
    aligned_pirate_indices = [pirate_fact_to_idx[f] for f in shakespeare_facts]

    # Load train gradients
    print("Loading train gradients...")
    train_grads = load_gradients(index_path, structured=True)

    with open(index_path / "info.json") as f:
        info = json.load(f)
    module_names = info["dtype"]["names"]

    # Load preconditioner if specified; precompute damped inverses on GPU.
    h_inv = {}
    if precond_path and (precond_path / "preconditioners.pth").exists():
        print(f"Loading preconditioner from {precond_path}")
        proc = GradientProcessor.load(precond_path)
        device = torch.device("cuda:0")
        for name in tqdm(module_names, desc="Computing H^(-1)"):
            H = proc.preconditioners[name].to(device=device)
            h_inv[name] = compute_damped_inverse(H)

    # Concatenate and unit-normalize train gradients.
    print("Preparing train gradients...")
    train_grad_list = []
    for name in tqdm(module_names, desc="Loading train grads"):
        g = _load_gradients_as_float(train_grads, name)
        train_grad_list.append(torch.from_numpy(g))
    train_grad_tensor = torch.cat(train_grad_list, dim=1)
    train_norms = train_grad_tensor.norm(dim=1, keepdim=True)
    train_grad_tensor = (train_grad_tensor / (train_norms + 1e-8)).cuda()

    # Build gradient paths
    shakespeare_grads_path = base_path / "eval_grads"  # minority = shakespeare
    pirate_grads_path = base_path / "eval_grads_pirate"

    with open(index_path / "index_config.json") as f:
        index_cfg = json.load(f)

    def _build_eval_grads(grads_path: Path, dataset_path: Path, label: str) -> None:
        """Run `bergson build` for one eval dataset if its grads are missing."""
        if grads_path.exists():
            return
        print(f"Computing {label} style eval gradients...")
        cmd = [
            "bergson",
            "build",
            str(grads_path),
            "--model",
            index_cfg["model"],
            "--dataset",
            str(dataset_path),
            "--drop_columns",
            "False",
            "--prompt_column",
            "fact",
            "--completion_column",
            "reworded",
            "--fsdp",
            "--projection_dim",
            str(index_cfg.get("projection_dim", 16)),
            "--token_batch_size",
            "6000",
            "--skip_preconditioners",
        ]
        print("Running:", " ".join(cmd))
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print("STDOUT:", result.stdout)
            print("STDERR:", result.stderr)
            raise RuntimeError(f"bergson build for {label} eval failed")
        print(result.stdout)

    _build_eval_grads(shakespeare_grads_path, data_path / "eval.hf", "shakespeare")
    _build_eval_grads(pirate_grads_path, data_path / "eval_pirate_style.hf", "pirate")

    # Load both eval gradient sets
    print("Loading eval gradients (shakespeare + pirate)...")
    shakespeare_grads = load_gradients(shakespeare_grads_path, structured=True)
    pirate_grads = load_gradients(pirate_grads_path, structured=True)

    # Sum gradients: for each eval fact, sum shakespeare + pirate style grads.
    summed_grad_list = []
    for name in tqdm(module_names, desc="Summing rewrite grads"):
        g_shakespeare = torch.from_numpy(
            _load_gradients_as_float(shakespeare_grads, name)
        )
        g_pirate = torch.from_numpy(_load_gradients_as_float(pirate_grads, name))

        # Align pirate grads to shakespeare fact order, then sum.
        g_summed = g_shakespeare + g_pirate[aligned_pirate_indices]

        # Apply preconditioning if specified
        if h_inv:
            g_summed = (g_summed.cuda() @ h_inv[name]).cpu()

        summed_grad_list.append(g_summed)

    eval_grad_tensor = torch.cat(summed_grad_list, dim=1)

    # Unit normalize summed eval grads
    eval_norms = eval_grad_tensor.norm(dim=1, keepdim=True)
    eval_grad_tensor = (eval_grad_tensor / (eval_norms + 1e-8)).cuda()

    # Compute scores
    print("Computing scores...")
    scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy()

    # Save scores
    np.save(scores_path / "scores.npy", scores)
    print(f"Saved scores to {scores_path}")

    return scores
def score_original_style_eval(
    config: AsymmetricConfig,
    base_path: Path | str,
    preconditioner_name: str | None = None,
) -> np.ndarray:
    """Score using original un-stylized eval gradients.

    Args:
        config: Experiment configuration.
        base_path: Base path for experiment outputs.
        preconditioner_name: Name of preconditioner subdirectory (None for no precond).

    Returns:
        Score matrix of shape (n_eval, n_train).
    """
    import json
    import subprocess

    import torch
    from tqdm import tqdm

    from bergson.data import load_gradients
    from bergson.gradients import GradientProcessor
    from bergson.utils.math import compute_damped_inverse

    base_path = Path(base_path)
    index_path = base_path / "index"
    data_path = base_path / "data"

    # Pick output / preconditioner directories for the requested variant.
    if preconditioner_name:
        scores_path = base_path / f"scores_original_style_{preconditioner_name}"
        precond_path = base_path / preconditioner_name
    else:
        scores_path = base_path / "scores_original_style_no_precond"
        precond_path = None

    # Reuse cached scores when this configuration was already computed.
    cached = scores_path / "scores.npy"
    if cached.exists():
        print(f"Loading cached scores from {scores_path}")
        return np.load(cached)

    scores_path.mkdir(parents=True, exist_ok=True)

    # Make sure the un-stylized eval split exists before anything else.
    create_original_style_eval(config, base_path)

    def _single_split(path):
        """Load a dataset from disk, unwrapping a DatasetDict if necessary."""
        ds = load_from_disk(str(path))
        return ds["train"] if isinstance(ds, DatasetDict) else ds

    train_ds = _single_split(data_path / "train.hf")
    n_train = len(train_ds)

    original_eval_ds = _single_split(data_path / "eval_original_style.hf")
    n_eval = len(original_eval_ds)

    print(
        f"Scoring {n_eval} original style eval queries against {n_train} train samples"
    )

    print("Loading train gradients...")
    train_grads = load_gradients(index_path, structured=True)

    with open(index_path / "info.json") as fp:
        module_names = json.load(fp)["dtype"]["names"]

    # Damped inverse of each module's preconditioner, if one was trained.
    h_inv = {}
    if precond_path and (precond_path / "preconditioners.pth").exists():
        print(f"Loading preconditioner from {precond_path}")
        proc = GradientProcessor.load(precond_path)
        device = torch.device("cuda:0")
        for name in tqdm(module_names, desc="Computing H^(-1)"):
            h_inv[name] = compute_damped_inverse(
                proc.preconditioners[name].to(device=device)
            )

    # Concatenate per-module train gradients, unit-normalize, move to GPU.
    print("Preparing train gradients...")
    train_parts = [
        torch.from_numpy(_load_gradients_as_float(train_grads, name))
        for name in tqdm(module_names, desc="Loading train grads")
    ]
    train_mat = torch.cat(train_parts, dim=1)
    train_mat = train_mat / (train_mat.norm(dim=1, keepdim=True) + 1e-8)
    train_mat = train_mat.cuda()

    original_grads_path = base_path / "eval_grads_original"

    with open(index_path / "index_config.json") as fp:
        index_cfg = json.load(fp)

    # Build the original-style eval gradients via bergson if missing.
    if not original_grads_path.exists():
        print("Computing original style eval gradients...")
        cmd = [
            "bergson",
            "build",
            str(original_grads_path),
            "--model",
            index_cfg["model"],
            "--dataset",
            str(data_path / "eval_original_style.hf"),
            "--drop_columns",
            "False",
            "--prompt_column",
            "fact",
            "--completion_column",
            "reworded",
            "--fsdp",
            "--projection_dim",
            str(index_cfg.get("projection_dim", 16)),
            "--token_batch_size",
            "6000",
            "--skip_preconditioners",
        ]
        print("Running:", " ".join(cmd))
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print("STDOUT:", result.stdout)
            print("STDERR:", result.stderr)
            raise RuntimeError("bergson build for original style eval failed")
        print(result.stdout)

    print("Loading original style eval gradients...")
    original_grads = load_gradients(original_grads_path, structured=True)

    eval_parts = []
    for name in tqdm(module_names, desc="Loading original eval grads"):
        mod_grad = torch.from_numpy(_load_gradients_as_float(original_grads, name))
        if h_inv:
            # Right-multiply by H^(-1) on GPU, then move back to CPU.
            mod_grad = (mod_grad.cuda() @ h_inv[name]).cpu()
        eval_parts.append(mod_grad)

    eval_mat = torch.cat(eval_parts, dim=1)
    eval_mat = eval_mat / (eval_mat.norm(dim=1, keepdim=True) + 1e-8)
    eval_mat = eval_mat.cuda()

    print("Computing scores...")
    scores = (eval_mat @ train_mat.T).cpu().numpy()

    np.save(scores_path / "scores.npy", scores)
    print(f"Saved scores to {scores_path}")

    return scores
def compute_rewrite_ablation_metrics(
    config: AsymmetricConfig,
    base_path: Path | str,
    strategy: str,
    preconditioner_name: str | None = None,
) -> "AsymmetricMetrics":
    """Compute metrics for rewrite ablation strategies.

    Args:
        config: Experiment configuration.
        base_path: Base path for experiment outputs.
        strategy: One of "original", "summed_rewrites", "shakespeare_only",
            "pirate_only".
        preconditioner_name: Name of preconditioner subdirectory (None for no
            preconditioning).

    Returns:
        AsymmetricMetrics with accuracy measurements.

    Raises:
        ValueError: If `strategy` is not one of the supported names.
    """
    base_path = Path(base_path)
    data_path = base_path / "data"

    # Load train dataset for ground truth mapping
    train_ds = load_from_disk(str(data_path / "train.hf"))
    if isinstance(train_ds, DatasetDict):
        train_ds = train_ds["train"]

    # Load eval dataset (use minority style for fact mapping)
    eval_ds = load_from_disk(str(data_path / "eval.hf"))
    if isinstance(eval_ds, DatasetDict):
        eval_ds = eval_ds["train"]

    # Get scores based on strategy
    if strategy == "original":
        scores = score_original_style_eval(config, base_path, preconditioner_name)
    elif strategy == "summed_rewrites":
        scores = score_summed_rewrites(config, base_path, preconditioner_name)
    elif strategy == "shakespeare_only":
        # Shakespeare is the minority style, use standard scoring
        scores = score_asymmetric_eval(config, base_path, preconditioner_name)
    elif strategy == "pirate_only":
        # Score using pirate eval gradients only
        scores = _score_single_style_eval(
            config, base_path, "pirate", preconditioner_name
        )
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    n_eval = len(eval_ds)

    # Extract metadata
    train_styles = train_ds["style"]
    train_identifiers = train_ds["identifier"]
    train_fields = train_ds["field"]

    eval_identifiers = eval_ds["identifier"]
    eval_fields = eval_ds["field"]

    # Get top-k indices for each query
    top_k = 10
    top_indices = np.argsort(-scores, axis=1)[:, :top_k]

    # Accumulators for the per-query metrics.
    semantic_top1 = 0
    semantic_top5 = 0
    semantic_top10 = 0
    style_leak_top1 = 0
    style_leak_top5 = 0.0
    style_leak_top10 = 0.0
    subject_top1 = 0
    field_top1 = 0

    for i in range(n_eval):
        query_identifier = eval_identifiers[i]
        query_field = eval_fields[i]

        top_k_idx = top_indices[i]

        # Semantic matching: same identifier AND field = same underlying fact.
        # BUG FIX: the previous version broke out of the loop right after
        # counting the top-5 bucket, so a hit at rank < 5 was never counted
        # toward top-10 recall (allowing top-10 < top-5). Count every
        # applicable bucket for the first hit, then stop scanning.
        for k, idx in enumerate(top_k_idx):
            if (
                train_identifiers[idx] == query_identifier
                and train_fields[idx] == query_field
            ):
                if k == 0:
                    semantic_top1 += 1
                if k < 5:
                    semantic_top5 += 1
                if k < 10:
                    semantic_top10 += 1
                break

        # Check style leakage: fraction of minority-style hits in top-k.
        top1_style = train_styles[top_k_idx[0]]
        if top1_style == config.minority_style:
            style_leak_top1 += 1

        top5_minority = sum(
            1 for idx in top_k_idx[:5] if train_styles[idx] == config.minority_style
        )
        style_leak_top5 += top5_minority / 5

        top10_minority = sum(
            1 for idx in top_k_idx[:10] if train_styles[idx] == config.minority_style
        )
        style_leak_top10 += top10_minority / 10

        # Check attribute matching for top-1
        top1_idx = top_k_idx[0]
        if train_identifiers[top1_idx] == query_identifier:
            subject_top1 += 1
        if train_fields[top1_idx] == query_field:
            field_top1 += 1

    return AsymmetricMetrics(
        top1_semantic_accuracy=semantic_top1 / n_eval,
        top5_semantic_recall=semantic_top5 / n_eval,
        top10_semantic_recall=semantic_top10 / n_eval,
        top1_style_leakage=style_leak_top1 / n_eval,
        top5_style_leakage=style_leak_top5 / n_eval,
        top10_style_leakage=style_leak_top10 / n_eval,
        top1_subject_accuracy=subject_top1 / n_eval,
        top1_field_accuracy=field_top1 / n_eval,
    )
+ """ + import json + import subprocess + + import torch + from tqdm import tqdm + + from bergson.data import load_gradients + from bergson.gradients import GradientProcessor + from bergson.utils.math import compute_damped_inverse + + base_path = Path(base_path) + index_path = base_path / "index" + data_path = base_path / "data" + + # Determine paths + if style == "pirate": + create_pirate_style_eval(config, base_path) + eval_path = data_path / "eval_pirate_style.hf" + grads_path = base_path / "eval_grads_pirate" + else: + raise ValueError(f"Unsupported style: {style}") + + if preconditioner_name: + scores_path = base_path / f"scores_{style}_only_{preconditioner_name}" + precond_path = base_path / preconditioner_name + else: + scores_path = base_path / f"scores_{style}_only_no_precond" + precond_path = None + + # Return cached if exists + if (scores_path / "scores.npy").exists(): + print(f"Loading cached scores from {scores_path}") + return np.load(scores_path / "scores.npy") + + scores_path.mkdir(parents=True, exist_ok=True) + + # Load datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + n_train = len(train_ds) + + eval_ds = load_from_disk(str(eval_path)) + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + n_eval = len(eval_ds) + + print( + f"Scoring {n_eval} {style} style eval queries against {n_train} train samples" + ) + + # Load train gradients + train_grads = load_gradients(index_path, structured=True) + + with open(index_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # Load preconditioner if specified + h_inv = {} + if precond_path and (precond_path / "preconditioners.pth").exists(): + proc = GradientProcessor.load(precond_path) + device = torch.device("cuda:0") + for name in tqdm(module_names, desc="Computing H^(-1)"): + H = proc.preconditioners[name].to(device=device) + h_inv[name] = compute_damped_inverse(H) 
+ + # Prepare train gradients + train_grad_list = [] + for name in tqdm(module_names, desc="Loading train grads"): + g = _load_gradients_as_float(train_grads, name) + train_grad_list.append(torch.from_numpy(g)) + train_grad_tensor = torch.cat(train_grad_list, dim=1) + train_norms = train_grad_tensor.norm(dim=1, keepdim=True) + train_grad_tensor = train_grad_tensor / (train_norms + 1e-8) + train_grad_tensor = train_grad_tensor.cuda() + + # Build eval grads if needed + with open(index_path / "index_config.json") as f: + index_cfg = json.load(f) + + if not grads_path.exists(): + print(f"Computing {style} style eval gradients...") + cmd = [ + "bergson", + "build", + str(grads_path), + "--model", + index_cfg["model"], + "--dataset", + str(eval_path), + "--drop_columns", + "False", + "--prompt_column", + "fact", + "--completion_column", + "reworded", + "--fsdp", + "--projection_dim", + str(index_cfg.get("projection_dim", 16)), + "--token_batch_size", + "6000", + "--skip_preconditioners", + ] + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError(f"bergson build for {style} eval failed") + + # Load eval gradients + eval_grads = load_gradients(grads_path, structured=True) + eval_grad_list = [] + for name in tqdm(module_names, desc=f"Loading {style} eval grads"): + g = torch.from_numpy(_load_gradients_as_float(eval_grads, name)) + if h_inv: + g = (g.cuda() @ h_inv[name]).cpu() + eval_grad_list.append(g) + + eval_grad_tensor = torch.cat(eval_grad_list, dim=1) + eval_norms = eval_grad_tensor.norm(dim=1, keepdim=True) + eval_grad_tensor = eval_grad_tensor / (eval_norms + 1e-8) + eval_grad_tensor = eval_grad_tensor.cuda() + + # Compute scores + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + + return scores + + +def 
def run_rewrite_ablation_experiment(
    config: AsymmetricConfig | None = None,
    base_path: Path | str = "runs/asymmetric_style",
) -> dict[str, "AsymmetricMetrics"]:
    """Run the rewrite ablation experiment.

    Compares:
    - original: Score with un-stylized eval gradients
    - summed_rewrites: Sum of shakespeare + pirate eval gradients
    - shakespeare_only: Just shakespeare eval gradients (baseline)
    - pirate_only: Just pirate eval gradients
    - summed_eval (reference): Sum of minority + majority style (from main
      experiment)

    Args:
        config: Experiment configuration.
        base_path: Base path for experiment outputs.

    Returns:
        Dictionary mapping strategy names to their metrics.
    """
    if config is None:
        config = AsymmetricConfig()

    base_path = Path(base_path)

    print("=" * 70)
    print("REWRITE ABLATION EXPERIMENT")
    print("=" * 70)
    print("\nThis tests whether summing two different rewrite styles helps,")
    print("even when neither rewrite matches the training distribution.")
    print("\nSetup:")
    print(f"  - Training: {config.dominant_style} style (majority)")
    print("  - Eval strategies:")
    print("    - original: un-stylized facts")
    print("    - shakespeare_only: just shakespeare rewrite")
    print("    - pirate_only: just pirate rewrite")
    print("    - summed_rewrites: shakespeare + pirate rewrites summed")
    print("    - summed_eval (reference): minority + majority style summed")
    print()

    all_metrics: dict[str, AsymmetricMetrics] = {}

    # FIX: the old list of (name, strategy) pairs was redundant — every pair
    # was identical. A flat tuple of strategy names is equivalent.
    for strategy in ("original", "shakespeare_only", "pirate_only", "summed_rewrites"):
        print(f"\n--- Strategy: {strategy} ---")
        metrics = compute_rewrite_ablation_metrics(config, base_path, strategy)
        print(f"  Top-1 Semantic: {metrics.top1_semantic_accuracy:.2%}")
        print(f"  Top-1 Style Leak: {metrics.top1_style_leakage:.2%}")
        all_metrics[strategy] = metrics

    # Add summed_eval reference (minority + majority, from the main experiment).
    print("\n--- Strategy: summed_eval (reference) ---")
    summed_metrics = compute_summed_eval_metrics(config, base_path)
    print(f"  Top-1 Semantic: {summed_metrics.top1_semantic_accuracy:.2%}")
    print(f"  Top-1 Style Leak: {summed_metrics.top1_style_leakage:.2%}")
    all_metrics["summed_eval_reference"] = summed_metrics

    # Print summary table across all strategies.
    print("\n" + "=" * 70)
    print("REWRITE ABLATION SUMMARY")
    print("=" * 70)

    print(f"\n{'Strategy':<25} {'Top-1 Semantic':<15} {'Top-1 Style Leak':<17}")
    print("-" * 60)

    for name, m in all_metrics.items():
        print(
            f"{name:<25} {m.top1_semantic_accuracy:<15.2%} "
            f"{m.top1_style_leakage:<17.2%}"
        )

    return all_metrics


if __name__ == "__main__":
    run_asymmetric_experiment()
# ==============================================================================
# Attribute Cluster Definitions
# ==============================================================================

# Occupational clusters with correlated attributes: each occupation draws its
# employers, universities, degrees, and titles from its own pool, so attribute
# values carry occupation signal.
OCCUPATION_CLUSTERS = {
    "scientist": {
        "employers": [
            "MIT", "Stanford Research Institute", "NASA", "CERN", "Caltech",
            "Lawrence Berkeley Lab", "Fermilab", "Max Planck Institute",
            "Cambridge Research", "Oxford Physics Lab",
        ],
        "universities": [
            "MIT", "Stanford University", "Caltech", "Princeton University",
            "Harvard University", "UC Berkeley", "Cambridge University",
            "Oxford University", "ETH Zurich", "Imperial College London",
        ],
        "degrees": [
            "PhD in Physics", "PhD in Chemistry", "PhD in Biology",
            "MSc in Mathematics",
        ],
        "titles": ["Dr.", "Professor", "Research Scientist", "Principal Investigator"],
    },
    "business": {
        "employers": [
            "Goldman Sachs", "JPMorgan Chase", "McKinsey", "Bain & Company",
            "Microsoft", "Amazon", "Deloitte", "PwC",
            "Boston Consulting Group", "Morgan Stanley",
        ],
        "universities": [
            "Harvard Business School", "Wharton School", "Stanford GSB",
            "Columbia Business School", "Chicago Booth", "INSEAD",
            "London Business School", "Kellogg School", "MIT Sloan",
            "Yale School of Management",
        ],
        "degrees": [
            "MBA", "MS in Finance", "BS in Economics",
            "MA in Business Administration",
        ],
        "titles": ["CEO", "CFO", "Managing Director", "Vice President", "Partner"],
    },
    "creative": {
        "employers": [
            "Netflix", "Disney", "Pixar", "Warner Bros", "Universal Studios",
            "Sony Pictures", "HBO", "Paramount", "DreamWorks", "Lionsgate",
        ],
        "universities": [
            "USC School of Cinematic Arts", "NYU Tisch School",
            "UCLA School of Film", "AFI Conservatory", "CalArts",
            "Parsons School of Design", "Rhode Island School of Design",
            "Pratt Institute", "School of Visual Arts", "Royal College of Art",
        ],
        "degrees": [
            "MFA in Film", "BFA in Animation", "MFA in Creative Writing",
            "BA in Fine Arts",
        ],
        "titles": [
            "Director", "Producer", "Creative Director", "Lead Designer",
            "Showrunner",
        ],
    },
}

# Fact templates that reveal occupation through correlated attributes; each
# template takes {name} and {value} placeholders.
FACT_TEMPLATES = {
    "employer": [
        "{name} works at {value}.",
        "{name} is employed by {value}.",
        "{name} has been working at {value} for several years.",
        "{name} currently holds a position at {value}.",
    ],
    "university": [
        "{name} studied at {value}.",
        "{name} graduated from {value}.",
        "{name} received their degree from {value}.",
        "{name} is an alumnus of {value}.",
    ],
    "degree": [
        "{name} earned a {value}.",
        "{name} holds a {value}.",
        "{name} completed a {value}.",
        "{name} was awarded a {value}.",
    ],
    "title": [
        "{name} serves as {value}.",
        "{name} holds the position of {value}.",
        "{name} works as a {value}.",
        "{name} is a {value}.",
    ],
}

# Name pools for synthetic people (one entry per letter A-Z).
FIRST_NAMES = [
    "Alice", "Bob", "Carol", "David", "Emma", "Frank", "Grace", "Henry",
    "Iris", "Jack", "Kate", "Leo", "Maya", "Noah", "Olivia", "Peter",
    "Quinn", "Rachel", "Sam", "Tara", "Uma", "Victor", "Wendy", "Xavier",
    "Yara", "Zach",
]

LAST_NAMES = [
    "Anderson", "Brown", "Chen", "Davis", "Evans", "Fischer", "Garcia",
    "Harris", "Ibrahim", "Johnson", "Kim", "Lee", "Martinez", "Nguyen",
    "O'Brien", "Patel", "Quinn", "Rodriguez", "Smith", "Taylor", "Ueno",
    "Volkov", "Wang", "Xavier", "Yamamoto", "Zhang",
]
"Nguyen", + "O'Brien", + "Patel", + "Quinn", + "Rodriguez", + "Smith", + "Taylor", + "Ueno", + "Volkov", + "Wang", + "Xavier", + "Yamamoto", + "Zhang", +] + + +@dataclass +class AttributePreservationConfig: + """Configuration for attribute preservation experiment.""" + + # Style assignment: which occupation gets which style in training + style_occupation_map: dict[str, str] = field( + default_factory=lambda: { + "scientist": "shakespeare", # Scientists in Shakespeare style + "business": "pirate", # Business in Pirate style + "creative": "shakespeare", # Creative in Shakespeare style (same as scientist) + } + ) + + # Eval: query scientists in pirate style (wrong style for this occupation) + eval_occupation: str = "scientist" + eval_style: str = "pirate" + + # Data size + people_per_occupation: int = 50 + facts_per_person: int = 4 # employer, university, degree, title + templates_per_fact: int = 2 + + seed: int = 42 + + # HuggingFace dataset repo. If set, skips local generation and downloads from HF. + hf_dataset: str | None = None + + +def generate_correlated_facts( + config: AttributePreservationConfig, +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Generate synthetic facts with correlated attributes. + + Creates facts where each person belongs to an occupation cluster, and + their attributes (employer, university, degree, title) are drawn from + that cluster's pool. + + Args: + config: Experiment configuration. + + Returns: + (train_facts, eval_facts) tuple of fact dictionaries. 
+ """ + rng = np.random.default_rng(config.seed) + + train_facts: list[dict[str, Any]] = [] + eval_facts: list[dict[str, Any]] = [] + + person_id = 0 + + for occupation, cluster_attrs in OCCUPATION_CLUSTERS.items(): + style = config.style_occupation_map[occupation] + + for _ in range(config.people_per_occupation): + # Generate a person + first_name = rng.choice(FIRST_NAMES) + last_name = rng.choice(LAST_NAMES) + name = f"{first_name} {last_name}" + + # Sample correlated attributes from this occupation's pool + employer = rng.choice(cluster_attrs["employers"]) + university = rng.choice(cluster_attrs["universities"]) + degree = rng.choice(cluster_attrs["degrees"]) + title = rng.choice(cluster_attrs["titles"]) + + attributes = { + "employer": employer, + "university": university, + "degree": degree, + "title": title, + } + + # Generate facts for each attribute + for field_name, value in attributes.items(): + templates = FACT_TEMPLATES[field_name] + selected_templates = rng.choice( + len(templates), + size=min(config.templates_per_fact, len(templates)), + replace=False, + ) + + for template_idx in selected_templates: + template = templates[template_idx] + fact_text = template.format(name=name, value=value) + + fact = { + "fact": fact_text, + "field": field_name, + "identifier": person_id, + "name": name, + "value": value, + "occupation": occupation, + "style": style, + "template": template_idx, + } + + # Determine if this fact goes to train or eval + if occupation == config.eval_occupation: + # This occupation's facts go to both: + # - Train: in the "correct" style (shakespeare) + # - Eval: in the "wrong" style (pirate) for later rewording + fact["style"] = config.style_occupation_map[occupation] + train_facts.append(fact.copy()) + + # Mark for eval (will be reworded to wrong style) + fact["style"] = config.eval_style + eval_facts.append(fact.copy()) + else: + # Other occupations only in train + train_facts.append(fact) + + person_id += 1 + + return train_facts, 
def create_attribute_dataset(
    config: AttributePreservationConfig,
    output_dir: Path | str,
) -> tuple[Dataset, Dataset]:
    """Create datasets for attribute preservation experiment.

    Args:
        config: Experiment configuration.
        output_dir: Directory to save datasets.

    Returns:
        (train_dataset, eval_dataset) tuple.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    base_train_path = output_dir / "base_train.hf"
    base_eval_path = output_dir / "base_eval.hf"

    def _unwrap(ds):
        """Return the single split, unwrapping a DatasetDict if present."""
        return ds["train"] if isinstance(ds, DatasetDict) else ds

    # Reuse the cached base facts (pre style-rewording) when both splits
    # already exist on disk; otherwise generate and persist them.
    if base_train_path.exists() and base_eval_path.exists():
        print(f"Loading cached base datasets from {output_dir}")
        base_train = load_from_disk(str(base_train_path))
        base_eval = load_from_disk(str(base_eval_path))
    else:
        print("Generating correlated facts...")
        train_facts, eval_facts = generate_correlated_facts(config)

        print(f"  Train facts: {len(train_facts)}")
        print(f"  Eval facts: {len(eval_facts)}")

        base_train = Dataset.from_list(train_facts)
        base_eval = Dataset.from_list(eval_facts)

        base_train.save_to_disk(str(base_train_path))
        base_eval.save_to_disk(str(base_eval_path))

    return _unwrap(base_train), _unwrap(base_eval)
+ """ + import torch + from tqdm import tqdm + from transformers import AutoModelForCausalLM, AutoTokenizer + + style_prompts = { + "shakespeare": ( + "Reword the following fact in a Shakespearean style, adding flair and poetry.\n" + "Do not include other text in your response, just the contents of the reworded fact.\n" + "Fact: {fact}\n" + "Your rewrite:" + ), + "pirate": ( + "Reword the following fact like it's coming from a pirate. Be creative!\n" + "Do not include any other text in your response, just the contents of the reworded fact.\n" + "Fact: {fact}\n" + "Your rewrite:" + ), + } + + prompt_template = style_prompts[style] + + device = "cuda:0" + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, + ) + model.eval() + + new_facts = [] + new_reworded = [] + + data_list = list(dataset) + + print(f"Rewording {len(data_list)} facts to {style} style...") + + for i in tqdm(range(0, len(data_list), batch_size)): + batch_items = data_list[i : i + batch_size] + prompts = [prompt_template.format(fact=item["fact"]) for item in batch_items] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + input_len = inputs.input_ids.shape[1] + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=128, + pad_token_id=tokenizer.eos_token_id, + do_sample=True, + temperature=0.7, + top_p=0.8, + ) + + generated_tokens = outputs[:, input_len:] + decoded_batch = tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + ) + + for item, output_text in zip(batch_items, decoded_batch): + new_facts.append(item["fact"]) + new_reworded.append(output_text.strip()) + + # Build new dataset with all original columns plus 'reworded' + new_data = {col: dataset[col] for col in dataset.column_names} + new_data["reworded"] = new_reworded + + 
    return Dataset.from_dict(new_data)


def create_styled_datasets(
    config: AttributePreservationConfig,
    output_dir: Path | str,
    model_name: str = "Qwen/Qwen3-8B-Base",
) -> tuple[Dataset, Dataset]:
    """Create style-reworded training and eval datasets.

    Results are cached on disk: the final splits as ``train.hf``/``eval.hf``,
    and per-style intermediate rewordings as ``train_<style>.hf`` /
    ``eval_<style>.hf``, so repeated calls skip the expensive LLM rewording.

    Args:
        config: Experiment configuration.
        output_dir: Directory for outputs.
        model_name: Model for rewording.

    Returns:
        (styled_train, styled_eval) tuple.
    """
    output_dir = Path(output_dir)

    train_path = output_dir / "train.hf"
    eval_path = output_dir / "eval.hf"

    # Fast path: both final splits already materialized.
    if train_path.exists() and eval_path.exists():
        print(f"Loading cached styled datasets from {output_dir}")
        return load_from_disk(str(train_path)), load_from_disk(str(eval_path))

    # Get base facts
    base_train, base_eval = create_attribute_dataset(config, output_dir)

    # Group train facts by style and reword
    print("\nRewording training data by style...")
    styled_train_parts = []

    for style in set(config.style_occupation_map.values()):
        # Filter facts for this style
        style_indices = [i for i, s in enumerate(base_train["style"]) if s == style]
        if not style_indices:
            continue

        style_subset = base_train.select(style_indices)
        print(f"  {style}: {len(style_subset)} facts")

        # Check for cached reworded data
        style_cache = output_dir / f"train_{style}.hf"
        if style_cache.exists():
            reworded = load_from_disk(str(style_cache))
        else:
            reworded = reword_dataset_with_style(style_subset, style, model_name)
            reworded.save_to_disk(str(style_cache))

        styled_train_parts.append(reworded)

    styled_train = concatenate_datasets(styled_train_parts)
    styled_train = styled_train.shuffle(seed=config.seed)

    # Reword eval data to the "wrong" style
    print(f"\nRewording eval data to {config.eval_style} style...")
    eval_cache = output_dir / f"eval_{config.eval_style}.hf"
    if eval_cache.exists():
        styled_eval = load_from_disk(str(eval_cache))
    else:
        styled_eval = reword_dataset_with_style(
            base_eval,
            config.eval_style, model_name
        )
        styled_eval.save_to_disk(str(eval_cache))

    # Save final datasets
    styled_train.save_to_disk(str(train_path))
    styled_eval.save_to_disk(str(eval_path))

    print("\nFinal datasets:")
    print(f"  Train: {len(styled_train)} samples")
    print(f"  Eval: {len(styled_eval)} samples")

    return styled_train, styled_eval


def create_attribute_index(
    config: AttributePreservationConfig,
    base_path: Path | str,
    analysis_model: str | None = None,
) -> Path:
    """Create bergson index for attribute preservation training set.

    Shells out to the ``bergson build`` CLI; the index is cached and the
    build is skipped if ``<base_path>/index`` already exists.

    Args:
        config: Experiment configuration.
        base_path: Base path for experiment outputs.
        analysis_model: Model for gradient collection. Defaults to HF_ANALYSIS_MODEL.

    Returns:
        Path to the created index.
    """
    import subprocess

    if analysis_model is None:
        analysis_model = HF_ANALYSIS_MODEL

    base_path = Path(base_path)
    data_path = base_path / "data"
    index_path = base_path / "index"

    # Load or create dataset
    if config.hf_dataset:
        # Download from HuggingFace and save locally for bergson
        print(f"Loading dataset from HuggingFace: {config.hf_dataset}")
        dataset_dict = load_experiment_data(hf_repo=config.hf_dataset)
        data_path.mkdir(parents=True, exist_ok=True)
        for split_name, split_ds in dataset_dict.items():
            split_path = data_path / f"{split_name}.hf"
            if not split_path.exists():
                split_ds.save_to_disk(str(split_path))
                print(f"  Saved {split_name} to {split_path}")
    else:
        # Generate locally with rewording
        create_styled_datasets(config, data_path)

    if index_path.exists():
        print(f"Index already exists at {index_path}, skipping...")
        return index_path

    cmd = [
        "bergson",
        "build",
        str(index_path),
        "--model",
        analysis_model,
        "--dataset",
        str(data_path / "train.hf"),
        "--drop_columns",
        "False",
        "--prompt_column",
        "fact",
        "--completion_column",
        "reworded",
        "--fsdp",
        "--projection_dim",
        "16",
        "--token_batch_size",
        "6000",
    ]

    print("Running:", " ".join(cmd))
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print("STDOUT:", result.stdout)
        print("STDERR:", result.stderr)
        raise RuntimeError("bergson build failed")
    print(result.stdout)

    return index_path


@dataclass
class AttributePreservationMetrics:
    """Metrics for attribute preservation experiment."""

    # Semantic accuracy (same fact)
    top1_fact_accuracy: float
    top5_fact_recall: float
    top10_fact_recall: float

    # Attribute preservation (same occupation cluster)
    top1_occupation_accuracy: float
    top5_occupation_recall: float
    top10_occupation_recall: float

    # Within-occupation attribute matching
    top1_same_employer_type: float  # Same employer from cluster
    top1_same_university_type: float  # Same university from cluster

    # Style-only matches (style matches but occupation doesn't - lower is better)
    top1_style_only_match: float
    top5_style_only_match: float
    top10_style_only_match: float

    # Per-field accuracy
    top1_by_field: dict[str, float] = field(default_factory=dict)


def score_attribute_eval(
    config: AttributePreservationConfig,
    base_path: Path | str,
    preconditioner_name: str | None = None,
) -> "np.ndarray":
    """Score eval queries against training index.

    Scores are cosine similarities between unit-normalized eval and train
    gradients. NOTE(review): the preconditioner inverse is applied to the
    eval (query) gradients only, before normalization — confirm this
    asymmetry is intentional. Requires a CUDA device ("cuda:0").

    Args:
        config: Experiment configuration.
        base_path: Base path for experiment outputs.
        preconditioner_name: Name of preconditioner (None for no precond).

    Returns:
        Score matrix of shape (n_eval, n_train).
    """
    import json
    import subprocess

    import ml_dtypes  # noqa: F401
    import torch
    from tqdm import tqdm

    from bergson.data import load_gradients
    from bergson.gradients import GradientProcessor
    from bergson.utils.math import compute_damped_inverse

    base_path = Path(base_path)
    index_path = base_path / "index"
    data_path = base_path / "data"

    # Determine output path
    if preconditioner_name:
        scores_path = base_path / f"scores_{preconditioner_name}"
        precond_path = base_path / preconditioner_name
    else:
        scores_path = base_path / "scores_no_precond"
        precond_path = None

    # Return cached
    if (scores_path / "scores.npy").exists():
        print(f"Loading cached scores from {scores_path}")
        return np.load(scores_path / "scores.npy")

    scores_path.mkdir(parents=True, exist_ok=True)

    # Load datasets
    train_ds = load_from_disk(str(data_path / "train.hf"))
    eval_ds = load_from_disk(str(data_path / "eval.hf"))

    if isinstance(train_ds, DatasetDict):
        train_ds = train_ds["train"]
    if isinstance(eval_ds, DatasetDict):
        eval_ds = eval_ds["train"]

    n_train = len(train_ds)
    n_eval = len(eval_ds)

    print(f"Scoring {n_eval} eval queries against {n_train} train samples")

    # Load train gradients
    print("Loading train gradients...")
    train_grads = load_gradients(index_path, structured=True)

    with open(index_path / "info.json") as f:
        info = json.load(f)
    module_names = info["dtype"]["names"]

    # Load preconditioner if specified
    h_inv = {}
    if precond_path and (precond_path / "preconditioners.pth").exists():
        print(f"Loading preconditioner from {precond_path}")
        proc = GradientProcessor.load(precond_path)
        device = torch.device("cuda:0")
        for name in tqdm(module_names, desc="Computing H^(-1)"):
            H = proc.preconditioners[name].to(device=device)
            h_inv[name] = compute_damped_inverse(H)

    def load_grad_as_float(grads: np.memmap, name: str) -> np.ndarray:
        # "|V2" is the raw 2-byte void dtype memmapped bfloat16 arrives as;
        # reinterpret then upcast to float32 for arithmetic.
        g = grads[name]
        if g.dtype == np.dtype("|V2"):
            g = g.view(ml_dtypes.bfloat16).astype(np.float32)
        return g

    # Prepare train gradients
    print("Preparing train gradients...")
    train_grad_list = []
    for name in tqdm(module_names, desc="Loading train grads"):
        g = load_grad_as_float(train_grads, name)
        train_grad_list.append(torch.from_numpy(g))
    train_grad_tensor = torch.cat(train_grad_list, dim=1)

    # Unit normalize
    train_norms = train_grad_tensor.norm(dim=1, keepdim=True)
    train_grad_tensor = train_grad_tensor / (train_norms + 1e-8)
    train_grad_tensor = train_grad_tensor.cuda()

    # Compute eval gradients
    print("Computing eval gradients...")
    eval_grads_path = base_path / "eval_grads"
    if not eval_grads_path.exists():
        with open(index_path / "index_config.json") as f:
            index_cfg = json.load(f)

        cmd = [
            "bergson",
            "build",
            str(eval_grads_path),
            "--model",
            index_cfg["model"],
            "--dataset",
            str(data_path / "eval.hf"),
            "--drop_columns",
            "False",
            "--prompt_column",
            "fact",
            "--completion_column",
            "reworded",
            "--fsdp",
            "--projection_dim",
            str(index_cfg.get("projection_dim", 16)),
            "--token_batch_size",
            "6000",
            "--skip_preconditioners",
        ]
        print("Running:", " ".join(cmd))
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print("STDOUT:", result.stdout)
            print("STDERR:", result.stderr)
            raise RuntimeError("bergson build for eval failed")
        print(result.stdout)

    # Load eval gradients
    eval_grads = load_gradients(eval_grads_path, structured=True)
    eval_grad_list = []
    for name in tqdm(module_names, desc="Loading eval grads"):
        g = torch.from_numpy(load_grad_as_float(eval_grads, name))
        if h_inv:
            # Precondition the query gradients per-module before concatenation.
            g = (g.cuda() @ h_inv[name]).cpu()
        eval_grad_list.append(g)
    eval_grad_tensor = torch.cat(eval_grad_list, dim=1)

    # Unit normalize
    eval_norms = eval_grad_tensor.norm(dim=1, keepdim=True)
    eval_grad_tensor = eval_grad_tensor / (eval_norms + 1e-8)
    eval_grad_tensor = eval_grad_tensor.cuda()

    # Compute 
scores + print("Computing scores...") + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + + return scores + + +def compute_attribute_metrics( + config: AttributePreservationConfig, + base_path: Path | str, + preconditioner_name: str | None = None, +) -> AttributePreservationMetrics: + """Compute metrics for attribute preservation experiment. + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + preconditioner_name: Name of preconditioner. + + Returns: + AttributePreservationMetrics dataclass. + """ + base_path = Path(base_path) + data_path = base_path / "data" + + # Load datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval.hf")) + + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + # Load scores + scores = score_attribute_eval(config, base_path, preconditioner_name) + + n_eval = len(eval_ds) + top_k = 10 + + # Extract metadata + train_facts = train_ds["fact"] + train_styles = train_ds["style"] + train_occupations = train_ds["occupation"] + train_fields = train_ds["field"] + train_values = train_ds["value"] + + eval_facts = eval_ds["fact"] + eval_styles = eval_ds["style"] + eval_occupations = eval_ds["occupation"] + eval_fields = eval_ds["field"] + + # Build occupation -> attribute pools for checking attribute-level matches + occupation_employers = { + occ: set(attrs["employers"]) for occ, attrs in OCCUPATION_CLUSTERS.items() + } + occupation_universities = { + occ: set(attrs["universities"]) for occ, attrs in OCCUPATION_CLUSTERS.items() + } + + # Get top-k indices + top_indices = np.argsort(-scores, axis=1)[:, :top_k] + + # Initialize counters + fact_top1 = fact_top5 = fact_top10 = 0 + occ_top1 = occ_top5 = occ_top10 = 0 + style_only_top1 = 0 + style_only_top5 = 0.0 + 
style_only_top10 = 0.0 + same_employer_type = same_university_type = 0 + + field_top1: dict[str, tuple[int, int]] = {} # field -> (hits, total) + + for i in range(n_eval): + query_fact = eval_facts[i] + query_style = eval_styles[i] + query_occ = eval_occupations[i] + query_field = eval_fields[i] + + top_k_idx = top_indices[i] + + # Track field accuracy + if query_field not in field_top1: + field_top1[query_field] = (0, 0) + hits, total = field_top1[query_field] + total += 1 + + # Fact accuracy (exact match) + for k, idx in enumerate(top_k_idx): + if train_facts[idx] == query_fact: + if k == 0: + fact_top1 += 1 + hits += 1 + if k < 5: + fact_top5 += 1 + break + if k < 10: + fact_top10 += 1 + break + + field_top1[query_field] = (hits, total) + + # Occupation accuracy (cluster match) + for k, idx in enumerate(top_k_idx): + if train_occupations[idx] == query_occ: + if k == 0: + occ_top1 += 1 + if k < 5: + occ_top5 += 1 + break + if k < 10: + occ_top10 += 1 + break + + # Style-only match (style matches but occupation doesn't) + top1_idx = top_k_idx[0] + if ( + train_styles[top1_idx] == query_style + and train_occupations[top1_idx] != query_occ + ): + style_only_top1 += 1 + + style_only_top5 += ( + sum( + 1 + for idx in top_k_idx[:5] + if train_styles[idx] == query_style + and train_occupations[idx] != query_occ + ) + / 5 + ) + style_only_top10 += ( + sum( + 1 + for idx in top_k_idx[:10] + if train_styles[idx] == query_style + and train_occupations[idx] != query_occ + ) + / 10 + ) + + # Attribute-level matching (for top-1) + top1_idx = top_k_idx[0] + top1_occ = train_occupations[top1_idx] + top1_field = train_fields[top1_idx] + top1_value = train_values[top1_idx] + + # Check if top-1 employer is from same occupation's employer pool + if top1_field == "employer" and query_field == "employer": + if top1_value in occupation_employers.get(query_occ, set()): + same_employer_type += 1 + + # Check university type matching + if top1_field == "university" and query_field == 
"university": + if top1_value in occupation_universities.get(query_occ, set()): + same_university_type += 1 + + # Compute per-field accuracy + top1_by_field = { + field: hits / total if total > 0 else 0.0 + for field, (hits, total) in field_top1.items() + } + + # Count field-specific queries + n_employer_queries = sum(1 for f in eval_fields if f == "employer") + n_university_queries = sum(1 for f in eval_fields if f == "university") + + return AttributePreservationMetrics( + top1_fact_accuracy=fact_top1 / n_eval, + top5_fact_recall=fact_top5 / n_eval, + top10_fact_recall=fact_top10 / n_eval, + top1_occupation_accuracy=occ_top1 / n_eval, + top5_occupation_recall=occ_top5 / n_eval, + top10_occupation_recall=occ_top10 / n_eval, + top1_same_employer_type=( + same_employer_type / n_employer_queries if n_employer_queries > 0 else 0.0 + ), + top1_same_university_type=( + same_university_type / n_university_queries + if n_university_queries > 0 + else 0.0 + ), + top1_style_only_match=style_only_top1 / n_eval, + top5_style_only_match=style_only_top5 / n_eval, + top10_style_only_match=style_only_top10 / n_eval, + top1_by_field=top1_by_field, + ) + + +def print_attribute_metrics(metrics: AttributePreservationMetrics, name: str) -> None: + """Print metrics in formatted way.""" + print(f"\n{'=' * 60}") + print(f"RESULTS: {name}") + print("=" * 60) + + print("\nFact Accuracy (exact semantic match - higher is better):") + print(f" Top-1: {metrics.top1_fact_accuracy:.2%}") + print(f" Top-5: {metrics.top5_fact_recall:.2%}") + print(f" Top-10: {metrics.top10_fact_recall:.2%}") + + print("\nOccupation Cluster Accuracy (attribute preservation - higher is better):") + print(f" Top-1: {metrics.top1_occupation_accuracy:.2%}") + print(f" Top-5: {metrics.top5_occupation_recall:.2%}") + print(f" Top-10: {metrics.top10_occupation_recall:.2%}") + + print("\nWithin-Occupation Attribute Matching (Top-1):") + print(f" Same employer type: {metrics.top1_same_employer_type:.2%}") + print(f" Same 
university type: {metrics.top1_same_university_type:.2%}")

    print("\nStyle-Only Match (style matches, occupation doesn't - lower is better):")
    print(f"  Top-1: {metrics.top1_style_only_match:.2%}")
    print(f"  Top-5: {metrics.top5_style_only_match:.2%}")
    print(f"  Top-10: {metrics.top10_style_only_match:.2%}")

    print("\nPer-Field Top-1 Accuracy:")
    for field, acc in sorted(metrics.top1_by_field.items()):
        print(f"  {field}: {acc:.2%}")


def compute_style_preconditioner_from_data(
    base_path: Path | str,
    config: AttributePreservationConfig,
) -> Path:
    """Compute R_between preconditioner from training data style means.

    For each module, R_between is the (average of) outer product(s) of
    differences between per-style mean gradients — a low-rank matrix
    spanning the between-style directions. Cached under
    ``<base_path>/r_between``.

    Args:
        base_path: Base path for experiment.
        config: Experiment configuration.

    Returns:
        Path to preconditioner.
    """
    import json

    import ml_dtypes  # noqa: F401
    import torch
    from tqdm import tqdm

    from bergson.data import load_gradients
    from bergson.gradients import GradientProcessor

    base_path = Path(base_path)
    index_path = base_path / "index"
    data_path = base_path / "data"
    output_path = base_path / "r_between"

    if (output_path / "preconditioners.pth").exists():
        print(f"Loading cached R_between from {output_path}")
        return output_path

    print("Computing R_between from training data style means...")

    # Load training data
    train_ds = load_from_disk(str(data_path / "train.hf"))
    if isinstance(train_ds, DatasetDict):
        train_ds = train_ds["train"]

    train_styles = train_ds["style"]
    train_grads = load_gradients(index_path, structured=True)

    with open(index_path / "info.json") as f:
        info = json.load(f)
    module_names = info["dtype"]["names"]

    # Get unique styles
    unique_styles = list(set(train_styles))
    style_indices = {
        style: [i for i, s in enumerate(train_styles) if s == style]
        for style in unique_styles
    }

    print(f"  Styles: {unique_styles}")
    for style, indices in style_indices.items():
        print(f"    {style}: {len(indices)} samples")

    # Load base processor
    base_proc = GradientProcessor.load(index_path)

    def load_grad_as_float(grads: np.memmap, name: str) -> np.ndarray:
        # "|V2" = raw 2-byte void dtype: reinterpret as bfloat16, upcast.
        g = grads[name]
        if g.dtype == np.dtype("|V2"):
            g = g.view(ml_dtypes.bfloat16).astype(np.float32)
        return g

    # Compute per-module style means and R_between
    between_precs = {}
    print(f"  Computing per-module R_between for {len(module_names)} modules...")

    for name in tqdm(module_names):
        g_all = torch.from_numpy(load_grad_as_float(train_grads, name))

        # Compute style means
        style_means = {}
        for style, indices in style_indices.items():
            style_means[style] = g_all[indices].mean(dim=0)

        # Compute pairwise differences and average
        # For 2 styles, this is just the difference
        if len(unique_styles) == 2:
            delta = style_means[unique_styles[0]] - style_means[unique_styles[1]]
            between_precs[name] = torch.outer(delta, delta)
        else:
            # For multiple styles, average all pairwise differences
            total_outer = torch.zeros(g_all.shape[1], g_all.shape[1])
            count = 0
            for i, s1 in enumerate(unique_styles):
                for s2 in unique_styles[i + 1 :]:
                    delta = style_means[s1] - style_means[s2]
                    total_outer += torch.outer(delta, delta)
                    count += 1
            between_precs[name] = total_outer / count

    # Save
    output_path.mkdir(parents=True, exist_ok=True)
    between_proc = GradientProcessor(
        normalizers=base_proc.normalizers,
        preconditioners=between_precs,
        preconditioners_eigen={},
        projection_dim=base_proc.projection_dim,
        projection_type=base_proc.projection_type,
        include_bias=base_proc.include_bias,
    )
    between_proc.save(output_path)
    print(f"Saved R_between to {output_path}")

    return output_path


def compute_eval_second_moment(
    base_path: Path | str,
    config: AttributePreservationConfig,
) -> Path:
    """Compute second moment matrix of eval gradients as preconditioner.

    H_eval = (1/n) * G_eval^T @ G_eval

    Requires the eval gradients produced by ``score_attribute_eval``;
    raises if they are missing. Cached under ``<base_path>/h_eval``.

    Args:
        base_path: Base path for experiment.
        config: Experiment configuration.

    Returns:
        Path to preconditioner.
    """
    import json

    import ml_dtypes  # noqa: F401
    import torch
    from tqdm import tqdm

    from bergson.data import load_gradients
    from bergson.gradients import GradientProcessor

    base_path = Path(base_path)
    index_path = base_path / "index"
    eval_grads_path = base_path / "eval_grads"
    output_path = base_path / "h_eval"

    if (output_path / "preconditioners.pth").exists():
        print(f"Loading cached H_eval from {output_path}")
        return output_path

    if not eval_grads_path.exists():
        raise RuntimeError("Eval grads not found - run score_attribute_eval first")

    print("Computing H_eval (second moment of eval gradients)...")

    eval_grads = load_gradients(eval_grads_path, structured=True)

    with open(eval_grads_path / "info.json") as f:
        info = json.load(f)
    module_names = info["dtype"]["names"]

    base_proc = GradientProcessor.load(index_path)

    def load_grad_as_float(grads: np.memmap, name: str) -> np.ndarray:
        g = grads[name]
        if g.dtype == np.dtype("|V2"):
            g = g.view(ml_dtypes.bfloat16).astype(np.float32)
        return g

    eval_precs = {}
    print(f"  Computing per-module H_eval for {len(module_names)} modules...")

    for name in tqdm(module_names):
        g = torch.from_numpy(load_grad_as_float(eval_grads, name))
        n = g.shape[0]
        # Second moment: (1/n) * G^T @ G
        R = g.T @ g / n
        eval_precs[name] = R

    output_path.mkdir(parents=True, exist_ok=True)
    eval_proc = GradientProcessor(
        normalizers=base_proc.normalizers,
        preconditioners=eval_precs,
        preconditioners_eigen={},
        projection_dim=base_proc.projection_dim,
        projection_type=base_proc.projection_type,
        include_bias=base_proc.include_bias,
    )
    eval_proc.save(output_path)
    print(f"Saved H_eval to {output_path}")

    return output_path


def create_majority_style_eval(
    config: AttributePreservationConfig,
    base_path: Path | str,
    reword_model: str = "Qwen/Qwen3-8B-Base",
) -> Dataset:
    """Create eval set using majority style (control for style mismatch).

    Instead of using minority style queries (pirate for scientists),
    uses the correct/majority style (shakespeare for scientists).
    This shows baseline performance without style mismatch.

    NOTE(review): expects ``base_eval.hf`` under the data dir — presumably
    written by create_attribute_dataset; confirm against that helper.

    Args:
        config: Experiment configuration.
        base_path: Base path for experiment outputs.
        reword_model: Model for rewording.

    Returns:
        Majority style eval dataset.
    """
    base_path = Path(base_path)
    data_path = base_path / "data"
    majority_eval_path = data_path / "eval_majority.hf"

    if majority_eval_path.exists():
        print(f"Loading cached majority style eval from {majority_eval_path}")
        ds = load_from_disk(str(majority_eval_path))
        if isinstance(ds, DatasetDict):
            ds = ds["train"]
        return ds

    print("Creating majority style eval set (control)...")

    # Load base eval (before style rewording)
    base_eval = load_from_disk(str(data_path / "base_eval.hf"))
    if isinstance(base_eval, DatasetDict):
        base_eval = base_eval["train"]

    # The majority style for eval_occupation is from the config
    majority_style = config.style_occupation_map[config.eval_occupation]
    print(f"  Rewording eval to majority style: {majority_style}")

    # Reword to majority style
    majority_eval = reword_dataset_with_style(base_eval, majority_style, reword_model)

    # Update style column
    majority_eval = majority_eval.remove_columns(["style"])
    majority_eval = majority_eval.add_column(
        "style", [majority_style] * len(majority_eval)
    )

    majority_eval.save_to_disk(str(majority_eval_path))
    print(f"Saved majority style eval to {majority_eval_path}")

    return majority_eval


def score_majority_style_eval(
    config: AttributePreservationConfig,
    base_path: Path | str,
    preconditioner_name: str | None = None,
) -> "np.ndarray":
    """Score majority style eval queries against training index.

    NOTE(review): near-duplicate of score_attribute_eval, differing only in
    dataset/score paths — candidate for extraction into a shared helper.

    Args:
        config: Experiment configuration.
        base_path: Base path for experiment outputs.
        preconditioner_name: Name of preconditioner (None for no precond).

    Returns:
        Score matrix of shape (n_eval, n_train).
    """
    import json
    import subprocess

    import ml_dtypes  # noqa: F401
    import torch
    from tqdm import tqdm

    from bergson.data import load_gradients
    from bergson.gradients import GradientProcessor
    from bergson.utils.math import compute_damped_inverse

    base_path = Path(base_path)
    index_path = base_path / "index"
    data_path = base_path / "data"

    # Determine output path
    if preconditioner_name:
        scores_path = base_path / f"scores_majority_{preconditioner_name}"
        precond_path = base_path / preconditioner_name
    else:
        scores_path = base_path / "scores_majority_no_precond"
        precond_path = None

    # Return cached
    if (scores_path / "scores.npy").exists():
        print(f"Loading cached scores from {scores_path}")
        return np.load(scores_path / "scores.npy")

    scores_path.mkdir(parents=True, exist_ok=True)

    # Load datasets
    train_ds = load_from_disk(str(data_path / "train.hf"))
    eval_ds = load_from_disk(str(data_path / "eval_majority.hf"))

    if isinstance(train_ds, DatasetDict):
        train_ds = train_ds["train"]
    if isinstance(eval_ds, DatasetDict):
        eval_ds = eval_ds["train"]

    n_train = len(train_ds)
    n_eval = len(eval_ds)

    print(
        f"Scoring {n_eval} majority style eval queries against {n_train} train samples"
    )

    # Load train gradients
    print("Loading train gradients...")
    train_grads = load_gradients(index_path, structured=True)

    with open(index_path / "info.json") as f:
        info = json.load(f)
    module_names = info["dtype"]["names"]

    # Load preconditioner if specified
    h_inv = {}
    if precond_path and (precond_path / "preconditioners.pth").exists():
        print(f"Loading preconditioner from {precond_path}")
        proc = GradientProcessor.load(precond_path)
        device = torch.device("cuda:0")
        for name in tqdm(module_names, desc="Computing H^(-1)"):
            H = proc.preconditioners[name].to(device=device)
            h_inv[name] = compute_damped_inverse(H)

    def load_grad_as_float(grads: np.memmap, name: str) -> np.ndarray:
        g = grads[name]
        if g.dtype == np.dtype("|V2"):
            g = g.view(ml_dtypes.bfloat16).astype(np.float32)
        return g

    # Prepare train gradients
    print("Preparing train gradients...")
    train_grad_list = []
    for name in tqdm(module_names, desc="Loading train grads"):
        g = load_grad_as_float(train_grads, name)
        train_grad_list.append(torch.from_numpy(g))
    train_grad_tensor = torch.cat(train_grad_list, dim=1)

    # Unit normalize
    train_norms = train_grad_tensor.norm(dim=1, keepdim=True)
    train_grad_tensor = train_grad_tensor / (train_norms + 1e-8)
    train_grad_tensor = train_grad_tensor.cuda()

    # Compute majority eval gradients
    print("Computing majority eval gradients...")
    majority_eval_grads_path = base_path / "eval_grads_majority"
    if not majority_eval_grads_path.exists():
        with open(index_path / "index_config.json") as f:
            index_cfg = json.load(f)

        cmd = [
            "bergson",
            "build",
            str(majority_eval_grads_path),
            "--model",
            index_cfg["model"],
            "--dataset",
            str(data_path / "eval_majority.hf"),
            "--drop_columns",
            "False",
            "--prompt_column",
            "fact",
            "--completion_column",
            "reworded",
            "--fsdp",
            "--projection_dim",
            str(index_cfg.get("projection_dim", 16)),
            "--token_batch_size",
            "6000",
            "--skip_preconditioners",
        ]
        print("Running:", " ".join(cmd))
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print("STDOUT:", result.stdout)
            print("STDERR:", result.stderr)
            raise RuntimeError("bergson build for majority eval failed")
        print(result.stdout)

    # Load eval gradients
    eval_grads = load_gradients(majority_eval_grads_path, structured=True)
    eval_grad_list = []
    for name in tqdm(module_names, desc="Loading eval grads"):
        g = torch.from_numpy(load_grad_as_float(eval_grads, name))
        if h_inv:
            g = (g.cuda() @ h_inv[name]).cpu()
        eval_grad_list.append(g)
eval_grad_tensor = torch.cat(eval_grad_list, dim=1) + + # Unit normalize + eval_norms = eval_grad_tensor.norm(dim=1, keepdim=True) + eval_grad_tensor = eval_grad_tensor / (eval_norms + 1e-8) + eval_grad_tensor = eval_grad_tensor.cuda() + + # Compute scores + print("Computing scores...") + scores = (eval_grad_tensor @ train_grad_tensor.T).cpu().numpy() + + np.save(scores_path / "scores.npy", scores) + print(f"Saved scores to {scores_path}") + + return scores + + +def compute_majority_style_metrics( + config: AttributePreservationConfig, + base_path: Path | str, + preconditioner_name: str | None = None, +) -> AttributePreservationMetrics: + """Compute metrics for majority style eval (control). + + Args: + config: Experiment configuration. + base_path: Base path for experiment outputs. + preconditioner_name: Name of preconditioner. + + Returns: + AttributePreservationMetrics dataclass. + """ + base_path = Path(base_path) + data_path = base_path / "data" + + # Load datasets + train_ds = load_from_disk(str(data_path / "train.hf")) + eval_ds = load_from_disk(str(data_path / "eval_majority.hf")) + + if isinstance(train_ds, DatasetDict): + train_ds = train_ds["train"] + if isinstance(eval_ds, DatasetDict): + eval_ds = eval_ds["train"] + + # Load scores + scores = score_majority_style_eval(config, base_path, preconditioner_name) + + n_eval = len(eval_ds) + top_k = 10 + + # Extract metadata + train_facts = train_ds["fact"] + train_styles = train_ds["style"] + train_occupations = train_ds["occupation"] + train_fields = train_ds["field"] + train_values = train_ds["value"] + + eval_facts = eval_ds["fact"] + eval_styles = eval_ds["style"] + eval_occupations = eval_ds["occupation"] + eval_fields = eval_ds["field"] + + # Build occupation -> attribute pools + occupation_employers = { + occ: set(attrs["employers"]) for occ, attrs in OCCUPATION_CLUSTERS.items() + } + occupation_universities = { + occ: set(attrs["universities"]) for occ, attrs in OCCUPATION_CLUSTERS.items() + } + + # 
Get top-k indices + top_indices = np.argsort(-scores, axis=1)[:, :top_k] + + # Initialize counters + fact_top1 = fact_top5 = fact_top10 = 0 + occ_top1 = occ_top5 = occ_top10 = 0 + style_only_top1 = 0 + style_only_top5 = 0.0 + style_only_top10 = 0.0 + same_employer_type = same_university_type = 0 + + field_top1: dict[str, tuple[int, int]] = {} + + for i in range(n_eval): + query_fact = eval_facts[i] + query_style = eval_styles[i] + query_occ = eval_occupations[i] + query_field = eval_fields[i] + + top_k_idx = top_indices[i] + + # Track field accuracy + if query_field not in field_top1: + field_top1[query_field] = (0, 0) + hits, total = field_top1[query_field] + total += 1 + + # Fact accuracy + for k, idx in enumerate(top_k_idx): + if train_facts[idx] == query_fact: + if k == 0: + fact_top1 += 1 + hits += 1 + if k < 5: + fact_top5 += 1 + break + if k < 10: + fact_top10 += 1 + break + + field_top1[query_field] = (hits, total) + + # Occupation accuracy + for k, idx in enumerate(top_k_idx): + if train_occupations[idx] == query_occ: + if k == 0: + occ_top1 += 1 + if k < 5: + occ_top5 += 1 + break + if k < 10: + occ_top10 += 1 + break + + # Style-only match (style matches but occupation doesn't) + top1_idx = top_k_idx[0] + if ( + train_styles[top1_idx] == query_style + and train_occupations[top1_idx] != query_occ + ): + style_only_top1 += 1 + + style_only_top5 += ( + sum( + 1 + for idx in top_k_idx[:5] + if train_styles[idx] == query_style + and train_occupations[idx] != query_occ + ) + / 5 + ) + style_only_top10 += ( + sum( + 1 + for idx in top_k_idx[:10] + if train_styles[idx] == query_style + and train_occupations[idx] != query_occ + ) + / 10 + ) + + # Attribute-level matching + top1_idx = top_k_idx[0] + top1_field = train_fields[top1_idx] + top1_value = train_values[top1_idx] + + if top1_field == "employer" and query_field == "employer": + if top1_value in occupation_employers.get(query_occ, set()): + same_employer_type += 1 + + if top1_field == "university" and 
def run_attribute_preservation_experiment(
    config: AttributePreservationConfig | None = None,
    base_path: Path | str = "runs/attribute_preservation",
    analysis_model: str | None = None,
    reword_model: str = "Qwen/Qwen3-8B-Base",
    include_h_eval: bool = True,
    include_majority_control: bool = True,
) -> dict[str, AttributePreservationMetrics]:
    """Run the full attribute preservation experiment.

    Tests whether style suppression damages the ability to match on
    content attributes (occupation clusters).

    Args:
        config: Experiment configuration. Set config.hf_dataset to load data
            from HuggingFace instead of generating locally.
        base_path: Base path for outputs.
        analysis_model: Model for gradient collection. Defaults to HF_ANALYSIS_MODEL.
        reword_model: Model for style rewording (only used if not using HF dataset).
        include_h_eval: Also compute and evaluate the H_eval preconditioner
            (second moment of eval gradients) as an extra strategy.
        include_majority_control: Also run the majority-style control, where the
            eval style matches the training style (no style mismatch).

    Returns:
        Dictionary mapping preconditioner names to metrics.
    """
    if config is None:
        config = AttributePreservationConfig()

    base_path = Path(base_path)

    print("=" * 70)
    print("ATTRIBUTE PRESERVATION UNDER STYLE SUPPRESSION EXPERIMENT")
    print("=" * 70)
    print("\nConfiguration:")
    print("  Style-occupation mapping:")
    for occ, style in config.style_occupation_map.items():
        print(f"    {occ}: {style}")
    print(
        f"  Eval occupation: {config.eval_occupation} (queried in {config.eval_style} style)"
    )
    print(f"  People per occupation: {config.people_per_occupation}")

    # Step 1: Create data and index
    print("\n" + "-" * 60)
    print("STEP 1: Creating attribute-correlated dataset and index")
    print("-" * 60)
    create_styled_datasets(config, base_path / "data", reword_model)
    create_attribute_index(config, base_path, analysis_model)

    # Step 2: Compute style suppression preconditioner
    print("\n" + "-" * 60)
    print("STEP 2: Computing style suppression preconditioner (R_between)")
    print("-" * 60)
    compute_style_preconditioner_from_data(base_path, config)

    # Step 3: Evaluate minority style (style mismatch) with different preconditioners
    print("\n" + "-" * 60)
    print("STEP 3: Evaluating preconditioner strategies (minority style eval)")
    print("-" * 60)

    strategies = [
        (None, "no_precond"),
        ("r_between", "r_between"),
    ]

    all_metrics: dict[str, AttributePreservationMetrics] = {}

    for precond_name, display_name in strategies:
        print(f"\n--- Strategy: {display_name} ---")
        metrics = compute_attribute_metrics(config, base_path, precond_name)
        print_attribute_metrics(metrics, display_name)
        all_metrics[display_name] = metrics

    # Step 3b: Compute and evaluate H_eval preconditioner
    if include_h_eval:
        print("\n" + "-" * 60)
        print("STEP 3b: Computing H_eval (second moment of eval gradients)")
        print("-" * 60)
        compute_eval_second_moment(base_path, config)

        print("\n--- Strategy: h_eval ---")
        metrics = compute_attribute_metrics(config, base_path, "h_eval")
        print_attribute_metrics(metrics, "h_eval")
        all_metrics["h_eval"] = metrics

    # Step 4: Majority style control (no style mismatch)
    if include_majority_control:
        print("\n" + "-" * 60)
        print("STEP 4: Majority style control (no style mismatch)")
        print("-" * 60)
        create_majority_style_eval(config, base_path, reword_model)

        print("\n--- Control: majority_style_no_precond ---")
        metrics = compute_majority_style_metrics(config, base_path, None)
        print_attribute_metrics(metrics, "majority_no_precond")
        all_metrics["majority_no_precond"] = metrics

    # Print summary comparison
    print("\n" + "=" * 70)
    print("SUMMARY: Style Suppression vs Attribute Preservation Trade-off")
    print("=" * 70)

    print(
        f"\n{'Strategy':<25} {'Fact Acc':<12} {'Occ Acc':<12} {'Style Only':<12} {'Trade-off':<12}"
    )
    print("-" * 73)

    for name, m in all_metrics.items():
        # Trade-off: we want high occupation accuracy and low style-only matches
        # A good trade-off is when occ_acc is high and style_only is low
        trade_off = m.top1_occupation_accuracy - m.top1_style_only_match
        print(
            f"{name:<25} {m.top1_fact_accuracy:<12.2%} {m.top1_occupation_accuracy:<12.2%} "
            f"{m.top1_style_only_match:<12.2%} {trade_off:<12.2%}"
        )

    print("\nInterpretation:")
    print("  - Fact Accuracy: How well we match exact facts (semantic matching)")
    print(
        "  - Occupation Accuracy: How well we match occupation cluster (attribute preservation)"
    )
    print(
        "  - Style Only: Matches where style matches but occupation doesn't (should be LOW)"
    )
    print("  - Trade-off: Occ Acc - Style Only (higher is better)")
    print(
        "  - majority_no_precond: Control showing baseline when eval style matches training"
    )

    baseline = all_metrics.get("no_precond")
    r_between = all_metrics.get("r_between")
    h_eval = all_metrics.get("h_eval")
    majority = all_metrics.get("majority_no_precond")

    print("\n" + "-" * 60)
    print("KEY FINDINGS")
    print("-" * 60)

    if baseline and r_between:
        # Check if R_between reduced style-only matches
        style_reduction = (
            baseline.top1_style_only_match - r_between.top1_style_only_match
        )
        print(f"\nR_between Style-Only Match Reduction: {style_reduction:.2%}")

        # Check if attribute preservation was damaged
        occ_change = (
            r_between.top1_occupation_accuracy - baseline.top1_occupation_accuracy
        )
        print(f"R_between Occupation Accuracy Change: {occ_change:+.2%}")

    if h_eval and baseline:
        style_reduction_h = (
            baseline.top1_style_only_match - h_eval.top1_style_only_match
        )
        occ_change_h = (
            h_eval.top1_occupation_accuracy - baseline.top1_occupation_accuracy
        )
        print(f"\nH_eval Style-Only Match Reduction: {style_reduction_h:.2%}")
        print(f"H_eval Occupation Accuracy Change: {occ_change_h:+.2%}")

    if majority:
        print("\nMajority Style Control (upper bound):")
        print(f"  Fact Accuracy: {majority.top1_fact_accuracy:.2%}")
        print(f"  Occupation Accuracy: {majority.top1_occupation_accuracy:.2%}")

    # NOTE(review): style_reduction / occ_change are only bound when both
    # baseline and r_between exist; the identical guard below keeps this safe.
    if baseline and r_between:
        if style_reduction > 0 and occ_change >= -0.05:
            print(
                "\n✓ SUCCESS: Style suppression works without damaging attribute preservation!"
            )
        elif style_reduction > 0 and occ_change < -0.05:
            print("\n⚠ PARTIAL: Style suppressed but attribute preservation damaged")
        elif style_reduction <= 0:
            print("\n✗ FAILURE: Style suppression not effective")

    return all_metrics


if __name__ == "__main__":
    run_attribute_preservation_experiment()


# --- examples/semantic/data.py (module header) ---
"""Dataset creation and rewording utilities for semantic experiments."""

from pathlib import Path

import torch
from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Default HuggingFace repos for experiments
HF_ASYMMETRIC_STYLE = "EleutherAI/bergson-asymmetric-style"
HF_ATTRIBUTE_PRESERVATION = "EleutherAI/bergson-attribute-preservation"
HF_ANALYSIS_MODEL = "EleutherAI/bergson-asymmetric-style-qwen3-8b-lora"
def load_experiment_data(
    base_path: Path | str | None = None,
    hf_repo: str | None = None,
    splits: list[str] | None = None,
) -> DatasetDict:
    """Load experiment data from HuggingFace or local disk.

    Args:
        base_path: Local path containing data/*.hf directories. Required if hf_repo is None.
        hf_repo: HuggingFace dataset repo ID (e.g., "EleutherAI/bergson-asymmetric-style").
            If provided, downloads from HF and ignores base_path.
        splits: Optional list of splits to load. If None, loads all available splits.

    Returns:
        DatasetDict with the requested splits.

    Examples:
        # Load from HuggingFace
        data = load_experiment_data(hf_repo="EleutherAI/bergson-asymmetric-style")

        # Load from local disk
        data = load_experiment_data(base_path="runs/asymmetric_style")

        # Load specific splits
        data = load_experiment_data(hf_repo="...", splits=["train", "eval"])
    """
    # Remote path: pull the whole repo, then optionally narrow to the
    # requested splits (silently skipping any that are absent).
    if hf_repo:
        remote = load_dataset(hf_repo)
        if not splits:
            return remote
        return DatasetDict({name: remote[name] for name in splits if name in remote})

    if base_path is None:
        raise ValueError("Either base_path or hf_repo must be provided")

    data_path = Path(base_path) / "data"
    if not data_path.exists():
        raise FileNotFoundError(f"Data directory not found: {data_path}")

    # Each split lives in its own <split>.hf directory on disk.
    found = [entry.stem for entry in data_path.glob("*.hf") if entry.is_dir()]
    if splits:
        found = [name for name in splits if name in found]
    if not found:
        raise FileNotFoundError(f"No .hf datasets found in {data_path}")

    loaded = {name: load_from_disk(str(data_path / f"{name}.hf")) for name in found}
    return DatasetDict(loaded)
def reword(
    dataset: Dataset,
    model_name: str,
    prompt_template: str,
    batch_size: int = 8,
    device: str = "cuda:3",
) -> Dataset:
    """Reword facts in a dataset using a language model.

    Args:
        dataset: Dataset containing a "fact" column.
        model_name: HuggingFace model name to use for rewording.
        prompt_template: Template string with {fact} placeholder.
        batch_size: Batch size for generation.
        device: Device to place the model on. FIX: this was a hard-coded
            "cuda:3" local; parameterized so the function works on machines
            without four GPUs. Default kept at "cuda:3" for backward
            compatibility — callers on other hosts should pass e.g. "cuda:0".

    Returns:
        Dataset with "fact" and "reworded" columns.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # REQUIRED for batched generation with Llama/Qwen/Mistral
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    model.eval()

    new_facts = []
    new_reworded = []

    # Convert dataset to list for easy slicing
    # (Assuming the dataset is small enough to fit in RAM, which 1000 items is)
    data_list = list(dataset)

    print(f"Starting generation with batch size: {batch_size}...")

    for i in tqdm(range(0, len(data_list), batch_size)):
        # 1. Prepare the batch
        batch_items = data_list[i : i + batch_size]
        prompts = [prompt_template.format(fact=item["fact"]) for item in batch_items]

        # 2. Tokenize (Batch mode)
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        input_len = inputs.input_ids.shape[1]

        # 3. Generate (sampling, so outputs are non-deterministic by design)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                temperature=0.7,
                top_p=0.8,
                min_p=0.0,
            )

        # 4. Slice output to remove prompt (all at once)
        # With left-padding, the prompt is always the first 'input_len' tokens
        generated_tokens = outputs[:, input_len:]

        # 5. Decode batch
        decoded_batch = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )

        # 6. Store results
        for item, output_text in zip(batch_items, decoded_batch):
            new_facts.append(item["fact"])
            new_reworded.append(output_text.strip())

    # Reconstruct dataset
    return Dataset.from_dict({"fact": new_facts, "reworded": new_reworded})
def create_data() -> None:
    """Create reworded datasets in Shakespeare and Pirate styles."""
    source = load_from_disk("data/facts_dataset.hf")
    if isinstance(source, DatasetDict):
        source = source["train"]

    prompt_shake = (
        "Reword the following fact in a Shakespearean style, adding flair and "
        "poetry.\n"
        "Do not include other text in your response, just the contents of the "
        "reworded fact.\n"
        "Fact: {fact}\n"
        "Your rewrite:"
    )
    prompt_pirate = (
        "Reword the following fact like it's coming from a pirate. Be creative!\n"
        "Do not include any other text in your response, just the contents of the "
        "reworded fact.\n"
        "Fact: {fact}\n"
        "Your rewrite:"
    )

    # (style slug, prompt, completion message) — processed in this order
    # for every generator model so existing outputs are skipped cheaply.
    jobs = [
        ("shakespeare", prompt_shake, "Shakespearean processing done."),
        ("pirate", prompt_pirate, "Pirate processing done."),
    ]

    # NOTE(review): "Meta-Llama/Meta-Llama-3-8B" — the org is usually spelled
    # "meta-llama" on the Hub; confirm this repo id resolves before relying on it.
    for model_name in ["Qwen/Qwen3-8B-Base", "Meta-Llama/Meta-Llama-3-8B"]:
        model_short = model_name.split("/")[-1]

        for style_slug, prompt, done_msg in jobs:
            out_path = f"data/facts_dataset_{style_slug}-{model_short}.hf"
            if Path(out_path).exists():
                continue

            styled = reword(source, model_name, prompt, batch_size=8)
            styled.save_to_disk(out_path)
            print(done_msg)
def create_qwen_only_dataset() -> Path:
    """Create a merged dataset with only Qwen-generated styles (pirate + shakespeare).

    Returns:
        Path to the created dataset.
    """
    qwen_dataset_path = Path("data/facts_dataset_reworded_qwen.hf")

    if qwen_dataset_path.exists():
        print(f"Qwen-only dataset already exists at {qwen_dataset_path}")
        return qwen_dataset_path

    print("Creating Qwen-only merged dataset...")
    original = load_from_disk("data/facts_dataset.hf")
    if isinstance(original, DatasetDict):
        original = original["train"]

    qwen_paths = [
        "data/facts_dataset_shakespeare-Qwen3-8B-Base.hf",
        "data/facts_dataset_pirate-Qwen3-8B-Base.hf",
    ]

    # FIX: the fact -> row lookup over the original dataset was rebuilt from
    # scratch for every missing column of every style dataset. Build it
    # lazily, at most once.
    orig_map: dict | None = None

    merged_datasets = []
    for path in qwen_paths:
        ds = load_from_disk(path)
        if isinstance(ds, DatasetDict):
            ds = ds["train"]

        # Add back any dropped columns from original, matching rows on "fact"
        for col in original.column_names:
            if col not in ds.column_names:
                if orig_map is None:
                    orig_map = {row["fact"]: row for row in original}
                restored_col = [orig_map[row["fact"]][col] for row in ds]
                ds = ds.add_column(col, restored_col)

        merged_datasets.append(ds)

    final_dataset = concatenate_datasets(merged_datasets)
    final_dataset = final_dataset.shuffle(seed=42)
    final_dataset.save_to_disk(str(qwen_dataset_path))
    print(f"Qwen-only dataset saved to: {qwen_dataset_path}")

    return qwen_dataset_path


# --- examples/semantic/experiment.py (module header) ---
"""Main experiment orchestration for semantic experiments."""

import subprocess
from pathlib import Path

from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk

from .data import create_data
from .metrics import compute_metrics
from .preconditioners import (
    build_style_indices,
    compute_between_preconditioner_means,
    compute_mixed_preconditioner,
)
from .scoring import compute_scores_fast
def create_index(dataset_name: str, analysis_model_name: str) -> None:
    """Create a bergson index for a dataset.

    Args:
        dataset_name: Name or path of the dataset.
        analysis_model_name: Model to use for gradient collection.
    """
    run_path = Path(f"runs/{dataset_name}")
    cmd = [
        "bergson",
        "build",
        str(run_path / "index"),
        "--model",
        analysis_model_name,
        "--dataset",
        dataset_name,
        "--drop_columns",
        "False",
        "--prompt_column",
        "fact",
        "--completion_column",
        "reworded",
        "--fsdp",
        "--projection_dim",
        "16",
        "--skip_preconditioners",
    ]

    # Print the command unconditionally so skipped reruns are still traceable;
    # the build itself only runs when the run directory does not exist yet.
    print(" ".join(cmd))
    if not run_path.exists():
        result = subprocess.run(cmd, capture_output=True, text=True)
        print(result.stdout)
        print(result.stderr)


def finetune(
    dataset_path: str, analysis_model_name: str, finetuned_model_path: str
) -> None:
    """Finetune a model on a dataset using LoRA.

    Args:
        dataset_path: Path to the training dataset.
        analysis_model_name: Base model to finetune.
        finetuned_model_path: Path to save the finetuned model.
    """
    cmd = [
        "torchrun",
        "--nproc_per_node=8",
        "--master_port=29500",
        "--standalone",
        "examples/train_lora.py",
        "--dataset_name",
        dataset_path,
        "--finetuned_model_path",
        finetuned_model_path,
        "--model_name",
        analysis_model_name,
        "--prompt_column",
        "fact",
        "--completion_column",
        "reworded",
    ]
    print(" ".join(cmd))

    # Stream training output line by line as it is produced.
    # BUG FIX: the original followed this block with a second
    # `subprocess.run(cmd, ...)`, which launched the entire training job a
    # second time just to capture output that had already been streamed.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    ) as process:
        for line in process.stdout:  # type: ignore
            print(line.strip())

    if process.returncode != 0:
        print(f"Warning: training exited with code {process.returncode}")
def run_preconditioner_comparison() -> dict[str, dict[str, float]]:
    """Compare four preconditioning strategies on pirate+shakespeare data.

    Strategies:
        1. baseline: Preconditioner computed on whole combined dataset
        2. mixed: 0.5 * R_pirate + 0.5 * R_shakespeare
        3. r_between: R_pirate + R_shakespeare - R_combined (isolates style direction)
        4. no_precond: No preconditioning (control)

    NOTE(review): the r_between formula above disagrees with the covariance
    formula documented in preconditioners.py (R_combined - (R_pirate +
    R_shakespeare) / 2), and the code below actually uses the means-based
    variant — confirm which description is current.

    Returns:
        Dictionary mapping strategy names to their computed statistics.
    """
    base_path = Path("runs/precond_comparison")

    # 1. Build indices if needed
    print("\n" + "=" * 60)
    print("STEP 1: Building indices")
    print("=" * 60)
    build_style_indices()

    # 2. Compute derived preconditioners
    print("\n" + "=" * 60)
    print("STEP 2: Computing derived preconditioners")
    print("=" * 60)
    compute_mixed_preconditioner(
        base_path / "pirate",
        base_path / "shakespeare",
        base_path / "mixed_50_50",
    )
    # Use means-based approach (more targeted at style direction)
    compute_between_preconditioner_means(
        base_path / "pirate",
        base_path / "shakespeare",
        base_path / "between",
    )

    # 3. Score with each preconditioner strategy (using fast index-vs-index scoring)
    print("\n" + "=" * 60)
    print("STEP 3: Computing scores with each strategy")
    print("=" * 60)
    strategies: list[tuple[str | None, str]] = [
        ("combined", "baseline"),  # Standard: precondition with combined R
        ("mixed_50_50", "mixed"),  # 50-50 mix of style-specific Rs
        ("between", "r_between"),  # Between-group preconditioner
        (None, "no_precond"),  # No preconditioning (control)
    ]

    for prec_path, name in strategies:
        print(f"\n--- Strategy: {name} ---")
        output_path = base_path / f"scores_{name}"
        compute_scores_fast(
            base_path / "combined",  # Use precomputed gradients from combined index
            output_path,
            preconditioner_path=(base_path / prec_path if prec_path else None),
        )

    # 4. Compare metrics across strategies
    print("\n" + "=" * 60)
    print("STEP 4: Comparing metrics across strategies")
    print("=" * 60)

    all_stats: dict[str, dict[str, float]] = {}
    for _, name in strategies:
        print(f"\n{'#' * 60}")
        print(f"# Strategy: {name}")
        print(f"{'#' * 60}")
        stats = compute_metrics(
            base_path / "combined",
            scores_path=base_path / f"scores_{name}",
            exclude_llama=True,
        )
        all_stats[name] = stats

    # Print summary comparison
    print("\n" + "=" * 60)
    print("SUMMARY: Style vs Fact Discrimination")
    print("=" * 60)
    print(f"{'Strategy':<15} {'Style Diff':<12} {'Fact Diff':<12} {'Subject Diff':<12}")
    print("-" * 51)
    for name in ["no_precond", "baseline", "mixed", "r_between"]:
        if name in all_stats and all_stats[name]:
            s = all_stats[name]
            style_diff = s.get("intra_style", 0) - s.get("inter_style", 0)
            fact_diff = s.get("intra_fact", 0) - s.get("inter_fact_same_subject", 0)
            subj_diff = s.get("intra_subject", 0) - s.get("inter_subject", 0)
            print(
                f"{name:<15} {style_diff:<12.4f} {fact_diff:<12.4f} {subj_diff:<12.4f}"
            )

    return all_stats


def main() -> None:
    """Main entry point for semantic experiments."""
    create_data()  # Skips if style datasets already exist
    dataset_paths = [
        "data/facts_dataset_shakespeare-Qwen3-8B-Base.hf",
        "data/facts_dataset_pirate-Qwen3-8B-Base.hf",
        "data/facts_dataset_shakespeare-Meta-Llama-3-8B.hf",
        "data/facts_dataset_pirate-Meta-Llama-3-8B.hf",
    ]

    final_dataset_path = "data/facts_dataset_reworded.hf"

    if not Path(final_dataset_path).exists():
        original = load_from_disk("data/facts_dataset.hf")
        if isinstance(original, DatasetDict):
            original = original["train"]

        merged_datasets: list[Dataset] = []

        # FIX: the fact -> row map over the original dataset was rebuilt for
        # every missing column of every style dataset; build it lazily once.
        orig_map: dict | None = None

        for path in dataset_paths:
            ds = load_from_disk(path)
            if isinstance(ds, DatasetDict):
                ds = ds["train"]

            # Add back any dropped columns from original, aligning rows on "fact"
            for col in original.column_names:
                if col not in ds.column_names:
                    if orig_map is None:
                        orig_map = {row["fact"]: row for row in original}
                    restored_col = [orig_map[row["fact"]][col] for row in ds]
                    ds = ds.add_column(col, restored_col)

            merged_datasets.append(ds)

        final_dataset = concatenate_datasets(merged_datasets)
        final_dataset = final_dataset.shuffle(seed=42)

        final_dataset.save_to_disk(final_dataset_path)
        print(f"Merged dataset saved to: {final_dataset_path}")

    # Run the preconditioner comparison experiment
    run_preconditioner_comparison()


if __name__ == "__main__":
    main()
# --- examples/semantic/metrics.py (module header) ---
"""Similarity metrics computation for semantic experiments."""

import json
from pathlib import Path
from typing import Any

import numpy as np
import torch
from datasets import DatasetDict, load_from_disk

from bergson import load_gradient_dataset
from bergson.data import load_gradients

from .scoring import compute_scores_with_bergson, load_scores_matrix


def build_style_lookup(include_llama: bool = False) -> dict[tuple[str, str], str]:
    """Build a lookup from (fact, reworded) -> style name.

    Args:
        include_llama: Whether to include Llama-generated styles.

    Returns:
        Dictionary mapping (fact, reworded) tuples to style names.
    """
    # Qwen-generated styles are always included; Llama ones are opt-in.
    sources = [
        ("data/facts_dataset_shakespeare-Qwen3-8B-Base.hf", "shakespeare"),
        ("data/facts_dataset_pirate-Qwen3-8B-Base.hf", "pirate"),
    ]
    if include_llama:
        sources += [
            (
                "data/facts_dataset_shakespeare-Meta-Llama-3-8B.hf",
                "shakespeare-llama",
            ),
            ("data/facts_dataset_pirate-Meta-Llama-3-8B.hf", "pirate-llama"),
        ]

    lookup: dict[tuple[str, str], str] = {}
    for path, style in sources:
        ds = load_from_disk(path)
        if isinstance(ds, DatasetDict):
            ds = ds["train"]
        lookup.update({(row["fact"], row["reworded"]): style for row in ds})
    return lookup
def compute_metrics_groupwise(
    index_path: Path | str,
    group_by: str = "field",  # "field" or "style"
    unit_normalize: bool = True,
) -> dict[str, Any]:
    """Compute intra/inter similarities using group-aggregated gradients.

    Groups by either field (birthdate, employer, etc.) or style (shakespeare, pirate).
    Only uses Qwen styles (excludes Llama).

    Args:
        index_path: Path to the gradient index.
        group_by: "field" or "style" - what to group by.
        unit_normalize: Whether to unit normalize gradients.

    Returns:
        Dictionary with groups, similarities matrix, and group counts.
    """
    index_path = Path(index_path)

    # Load gradient dataset with metadata
    print("Loading gradient dataset...")
    grad_ds = load_gradient_dataset(index_path, structured=True)
    print(f"  Loaded {len(grad_ds)} rows")

    # Get gradient column names
    with open(index_path / "info.json") as f:
        info = json.load(f)
    grad_columns = info["dtype"]["names"]
    print(f"  Gradient columns: {len(grad_columns)} modules")

    # Build style lookup (Qwen only, no Llama)
    print("Building style lookup (Qwen only)...")
    style_lookup = build_style_lookup(include_llama=False)

    # Use batch column access for speed
    facts = grad_ds["fact"]
    reworded = grad_ds["reworded"]
    fields = grad_ds["field"]

    print("Mapping styles...")
    styles: list[str | None] = [
        style_lookup.get((f, r), None) for f, r in zip(facts, reworded)
    ]

    # Filter to only Qwen styles (exclude Llama and unknown)
    print("Filtering to Qwen styles only...")
    keep_indices = [i for i, s in enumerate(styles) if s is not None]
    grad_ds = grad_ds.select(keep_indices)
    styles = [styles[i] for i in keep_indices]
    fields = [fields[i] for i in keep_indices]
    print(f"  Keeping {len(grad_ds)} rows")

    # Build group keys based on group_by parameter
    print(f"Building groups by {group_by}...")
    if group_by == "field":
        group_keys = fields
    elif group_by == "style":
        group_keys = styles
    else:
        raise ValueError(f"group_by must be 'field' or 'style', got {group_by}")

    # Filter out None values before sorting.
    # NOTE(review): the `if g is not None` filter below assumes no group key
    # is None after the Qwen filter; if a field were None, row_to_group would
    # silently misalign with the gradient rows — confirm fields are non-null.
    unique_groups = sorted(g for g in set(group_keys) if g is not None)
    group_to_idx = {g: i for i, g in enumerate(unique_groups)}
    row_to_group = torch.tensor([group_to_idx[g] for g in group_keys if g is not None])
    print(f"  Found {len(unique_groups)} unique groups: {unique_groups}")

    # Load gradients directly from memmap (much faster than HF dataset)
    print("Loading gradients from memmap...")
    grad_mmap = load_gradients(index_path, structured=False)
    # Select only the kept rows
    all_grads = torch.from_numpy(grad_mmap[keep_indices].copy()).float()
    print(f"  Gradient tensor shape: {all_grads.shape}")

    # Compute mean gradient per group
    print("Computing mean gradients per group...")
    num_groups = len(unique_groups)
    group_grads = torch.zeros(num_groups, all_grads.shape[1], dtype=torch.float32)
    group_counts = torch.zeros(num_groups, dtype=torch.float32)

    for g_idx in range(num_groups):
        mask = row_to_group == g_idx
        group_grads[g_idx] = all_grads[mask].sum(dim=0)
        group_counts[g_idx] = mask.sum().float()

    # Average
    group_grads = group_grads / group_counts.unsqueeze(1)

    # Unit normalize if requested
    if unit_normalize:
        norms = group_grads.norm(dim=1, keepdim=True)
        group_grads = group_grads / (norms + 1e-8)

    # Compute pairwise similarities between groups.
    # FIX: the original called .cuda() unconditionally, which raises on
    # CPU-only hosts; fall back to CPU when CUDA is unavailable.
    print("Computing pairwise similarities...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    group_grads = group_grads.to(device)
    similarities = group_grads @ group_grads.T
    similarities = similarities.cpu()
    print(f"  Similarity matrix shape: {similarities.shape}")

    # Report results
    print("\n" + "=" * 60)
    print(f"SIMILARITY MATRIX (grouped by {group_by})")
    print("=" * 60)

    # Print the full similarity matrix since it's small
    print(f"\nGroups: {unique_groups}")
    print("\nSimilarity matrix:")
    for i, g1 in enumerate(unique_groups):
        row_str = "  " + str(g1).ljust(15) + ": "
        row_str += " ".join(f"{similarities[i, j]:.3f}" for j in range(num_groups))
        print(row_str)

    # Compute intra vs inter group stats
    n = num_groups
    row_idx, col_idx = torch.triu_indices(n, n, offset=1)
    off_diag_sims = similarities[row_idx, col_idx]
    diag_sims = similarities.diag()

    print(f"\nDiagonal (self-similarity): {diag_sims.mean():.4f}")
    print(f"Off-diagonal (cross-group): {off_diag_sims.mean():.4f}")
    print(f"Difference: {diag_sims.mean() - off_diag_sims.mean():.4f}")

    return {
        "groups": unique_groups,
        "similarities": similarities,
        "group_counts": group_counts,
    }
def compute_metrics(
    index_path: Path | str,
    scores_path: Path | str | None = None,
    exclude_llama: bool = False,
    query_preconditioner_path: str | None = None,
    index_preconditioner_path: str | None = None,
) -> dict[str, float]:
    """Compute intra/inter similarities for subject (identifier) and style.

    Uses bergson score_dataset to compute pairwise similarities instead of
    custom gradient inner product implementation.

    Args:
        index_path: Path to the gradient index.
        scores_path: Optional path to precomputed scores.
        exclude_llama: Whether to exclude Llama-generated samples.
        query_preconditioner_path: Optional path to query preconditioner.
        index_preconditioner_path: Optional path to index preconditioner.

    Returns:
        Dictionary of similarity statistics.
    """
    index_path = Path(index_path)

    # Determine scores path
    if scores_path is None:
        scores_path = index_path.parent / "scores"
    else:
        scores_path = Path(scores_path)

    # Compute scores using bergson if not already done
    compute_scores_with_bergson(
        index_path,
        scores_path,
        query_preconditioner_path=query_preconditioner_path,
        index_preconditioner_path=index_preconditioner_path,
    )

    # Load metadata from HF dataset (fast)
    print("Loading metadata...")
    # Get dataset path from index config
    with open(index_path / "index_config.json") as f:
        index_cfg = json.load(f)
    dataset_path = index_cfg.get("data", {}).get("dataset", str(index_path / "data.hf"))
    meta_ds = load_from_disk(dataset_path)
    if isinstance(meta_ds, DatasetDict):
        meta_ds = meta_ds["train"]

    # Build style lookup from individual datasets.
    # NOTE(review): near-duplicate of build_style_lookup(), but with
    # "-qwen"-suffixed style names; unify once the naming is settled.
    print("Building style lookup...")
    style_lookup: dict[tuple[str, str], str] = {}
    style_datasets = [
        ("data/facts_dataset_shakespeare-Qwen3-8B-Base.hf", "shakespeare-qwen"),
        ("data/facts_dataset_pirate-Qwen3-8B-Base.hf", "pirate-qwen"),
        ("data/facts_dataset_shakespeare-Meta-Llama-3-8B.hf", "shakespeare-llama"),
        ("data/facts_dataset_pirate-Meta-Llama-3-8B.hf", "pirate-llama"),
    ]

    for path, style_name in style_datasets:
        ds = load_from_disk(path)
        if isinstance(ds, DatasetDict):
            ds = ds["train"]
        for row in ds:
            style_lookup[(row["fact"], row["reworded"])] = style_name

    # Extract metadata
    identifiers = meta_ds["identifier"]
    fields = meta_ds["field"]
    templates = meta_ds["template"]
    facts = meta_ds["fact"]
    reworded = meta_ds["reworded"]

    # Map each row to its style
    styles = [style_lookup.get((f, r), "unknown") for f, r in zip(facts, reworded)]

    # Load scores matrix from bergson output
    print("Loading scores matrix...")
    scores = load_scores_matrix(scores_path)
    n = len(scores)
    print(f"  Scores shape: {scores.shape}")

    # Filter out llama data if requested
    if exclude_llama:
        print("Excluding Llama data...")
        keep_indices = [i for i, s in enumerate(styles) if "llama" not in s]
        print(f"  Keeping {len(keep_indices)} / {len(styles)} samples")
        identifiers = [identifiers[i] for i in keep_indices]
        fields = [fields[i] for i in keep_indices]
        templates = [templates[i] for i in keep_indices]
        facts = [facts[i] for i in keep_indices]
        reworded = [reworded[i] for i in keep_indices]
        styles = [styles[i] for i in keep_indices]
        # Filter scores matrix (both rows and columns)
        scores = scores[np.ix_(keep_indices, keep_indices)]
        n = len(keep_indices)

    # FIX: the original called .cuda() unconditionally, which raises on
    # CPU-only hosts; fall back to CPU when CUDA is unavailable.
    print("Transferring scores to GPU...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    similarities = torch.from_numpy(scores).to(device)

    print(f"Computing statistics for {n} samples...")

    # Convert metadata to CPU tensors.
    # NOTE(review): torch.tensor(identifiers) assumes identifiers/templates
    # are numeric — confirm against the dataset schema.
    identifiers_t = torch.tensor(identifiers)
    templates_t = torch.tensor(templates)
    field_to_idx = {f: i for i, f in enumerate(set(fields))}
    style_to_idx = {s: i for i, s in enumerate(set(styles))}
    fields_t = torch.tensor([field_to_idx[f] for f in fields])
    styles_t = torch.tensor([style_to_idx[s] for s in styles])

    # Build masks for upper triangle (i < j to avoid double counting and self-similarity)
    row_idx, col_idx = torch.triu_indices(n, n, offset=1)

    # Get similarities for upper triangle pairs
    upper_sims = similarities[row_idx, col_idx].cpu()

    # Build condition masks for the pairs
    same_subject = identifiers_t[row_idx] == identifiers_t[col_idx]
    same_field = fields_t[row_idx] == fields_t[col_idx]
    same_template = templates_t[row_idx] == templates_t[col_idx]
    same_style = styles_t[row_idx] == styles_t[col_idx]

    def compute_mean(mask: torch.Tensor) -> float:
        # Empty masks would yield NaN from .mean(); report 0.0 instead.
        if mask.sum() == 0:
            return 0.0
        return upper_sims[mask].mean().item()

    # Compute statistics
    stats = {
        "intra_subject": compute_mean(same_subject),
        "inter_subject": compute_mean(~same_subject),
        "intra_fact": compute_mean(same_subject & same_field),
        "inter_fact_same_subject": compute_mean(same_subject & ~same_field),
        "intra_field": compute_mean(same_field),
        "inter_field": compute_mean(~same_field),
        "intra_template": compute_mean(same_template),
        "inter_template": compute_mean(~same_template),
        "intra_style": compute_mean(same_style),
        "inter_style": compute_mean(~same_style),
    }

    # Report results
    print("\n" + "=" * 60)
    print("SEMANTIC SIMILARITY RESULTS")
    print("=" * 60)

    print("\nSubject (same person vs different person):")
    print(f"  Intra-subject mean: {stats['intra_subject']:.4f}")
    print(f"  Inter-subject mean: {stats['inter_subject']:.4f}")
    print(f"  Difference: {stats['intra_subject'] - stats['inter_subject']:.4f}")

    print("\nFact (same person+field = same underlying fact):")
    print(f"  Intra-fact mean: {stats['intra_fact']:.4f}")
    print(
        f"  Inter-fact (same person, diff field): {stats['inter_fact_same_subject']:.4f}"
    )
    print(f"  Difference: {stats['intra_fact'] - stats['inter_fact_same_subject']:.4f}")

    print("\nField (same field type, e.g. birthdate, employer):")
    print(f"  Intra-field mean: {stats['intra_field']:.4f}")
    print(f"  Inter-field mean: {stats['inter_field']:.4f}")
    print(f"  Difference: {stats['intra_field'] - stats['inter_field']:.4f}")

    print("\nTemplate (same original phrasing template):")
    print(f"  Intra-template mean: {stats['intra_template']:.4f}")
    print(f"  Inter-template mean: {stats['inter_template']:.4f}")
    print(f"  Difference: {stats['intra_template'] - stats['inter_template']:.4f}")

    print("\nStyle (same rewording style):")
    print(f"  Intra-style mean: {stats['intra_style']:.4f}")
    print(f"  Inter-style mean: {stats['inter_style']:.4f}")
    print(f"  Difference: {stats['intra_style'] - stats['inter_style']:.4f}")

    # Interpretation:
    # - High fact difference = embeddings capture semantic content
    # - Low template difference = embeddings see through phrasing variations
    # - Low style difference = embeddings see through rewording styles
    print("\n" + "=" * 60)
    print("INTERPRETATION")
    print("=" * 60)
    print("If embeddings capture semantics well:")
    print("  - Fact difference should be HIGH (same fact clusters)")
    print("  - Template difference should be LOW (phrasing doesn't matter)")
    print("  - Style difference should be LOW (rewording doesn't matter)")

    return stats


# --- examples/semantic/preconditioners.py (module header) ---
"""Preconditioner computation and comparison utilities for semantic experiments."""

import subprocess
from pathlib import Path

import ml_dtypes  # noqa: F401 # registers bfloat16 dtype with numpy
import numpy as np
import torch
from tqdm import tqdm

from bergson.data import load_gradients
from bergson.gradients import GradientProcessor

from .data import create_qwen_only_dataset
gradient field and convert from bfloat16 to float32. + + Args: + grads: Structured gradient memmap. + name: Field name to access. + + Returns: + Float32 numpy array. + """ + g = grads[name] + # Gradients are stored as bfloat16 (2-byte void) + if g.dtype == np.dtype("|V2"): + g = g.view(ml_dtypes.bfloat16).astype(np.float32) + return g + + +def build_style_indices(analysis_model: str = "tmp/checkpoint-282") -> None: + """Build separate indices for pirate and shakespeare to get separate preconditioners. + + Args: + analysis_model: Model to use for gradient collection. + """ + base_path = Path("runs/precond_comparison") + base_path.mkdir(parents=True, exist_ok=True) + + styles = [ + ("data/facts_dataset_pirate-Qwen3-8B-Base.hf", "pirate"), + ("data/facts_dataset_shakespeare-Qwen3-8B-Base.hf", "shakespeare"), + ] + + for dataset_path, style_name in styles: + run_path = base_path / style_name + if run_path.exists(): + print(f"Index already exists at {run_path}, skipping...") + continue + + print(f"Building index for {style_name}...") + cmd = [ + "bergson", + "build", + str(run_path), + "--model", + analysis_model, + "--dataset", + dataset_path, + "--drop_columns", + "False", + "--prompt_column", + "fact", + "--completion_column", + "reworded", + "--fsdp", + "--projection_dim", + "16", + "--token_batch_size", + "6000", + # NOTE: Do NOT skip preconditioners - we need them! 
+ ] + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError(f"bergson build failed for {style_name}") + print(result.stdout) + + # Also build combined index on merged Qwen-only dataset + combined_path = base_path / "combined" + if not combined_path.exists(): + # Ensure Qwen-only dataset exists + qwen_dataset_path = create_qwen_only_dataset() + + print("Building combined index...") + cmd = [ + "bergson", + "build", + str(combined_path), + "--model", + analysis_model, + "--dataset", + str(qwen_dataset_path), + "--drop_columns", + "False", + "--prompt_column", + "fact", + "--completion_column", + "reworded", + "--fsdp", + "--projection_dim", + "16", + "--token_batch_size", + "6000", + ] + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError("bergson build failed for combined") + print(result.stdout) + else: + print(f"Combined index already exists at {combined_path}, skipping...") + + +def compute_between_preconditioner_covariance( + pirate_path: Path | str, + shakespeare_path: Path | str, + combined_path: Path | str, + output_path: Path | str, +) -> GradientProcessor: + """Compute R_between = R_combined - (R_pirate + R_shakespeare) / 2. + + Mathematical reasoning: + - R_pirate and R_shakespeare capture within-class variance only + - R_combined captures within-class + between-class variance + - R_between = R_combined - R_within isolates the between-class component + + This captures the "style" direction that differs between pirate and shakespeare. + Preconditioning with this should downweight the style direction. + + Args: + pirate_path: Path to pirate style preconditioner. + shakespeare_path: Path to shakespeare style preconditioner. 
+ combined_path: Path to combined preconditioner. + output_path: Path to save the between-class preconditioner. + + Returns: + The computed GradientProcessor. + """ + output_path = Path(output_path) + + # Check cache first + if (output_path / "preconditioners.pth").exists(): + print(f"Loading cached R_between (covariance) from {output_path}") + return GradientProcessor.load(output_path) + + print("Computing R_between preconditioner (covariance method)...") + pirate_proc = GradientProcessor.load(Path(pirate_path)) + shakespeare_proc = GradientProcessor.load(Path(shakespeare_path)) + combined_proc = GradientProcessor.load(Path(combined_path)) + + between_precs = {} + for name in pirate_proc.preconditioners: + R_pirate = pirate_proc.preconditioners[name] + R_shakespeare = shakespeare_proc.preconditioners[name] + R_combined = combined_proc.preconditioners[name] + + # R_within = average of within-class covariances + R_within = 0.5 * R_pirate + 0.5 * R_shakespeare + + # R_between = R_combined - R_within (isolates between-class variance) + between_precs[name] = R_combined - R_within + + # Create processor with required fields from one of the source processors + between_proc = GradientProcessor( + normalizers=pirate_proc.normalizers, + preconditioners=between_precs, + preconditioners_eigen={}, + projection_dim=pirate_proc.projection_dim, + projection_type=pirate_proc.projection_type, + include_bias=pirate_proc.include_bias, + ) + between_proc.save(output_path) + print(f"Saved R_between preconditioner to {output_path}") + return between_proc + + +def compute_between_preconditioner_means( + pirate_index_path: Path | str, + shakespeare_index_path: Path | str, + output_path: Path | str, +) -> GradientProcessor: + """Compute R_between = (mu_pirate - mu_shakespeare)(mu_pirate - mu_shakespeare)^T per module. + + This creates a rank-1 preconditioner from the difference in class means. + More targeted than the covariance method - captures exactly the "style direction". 
+ + Works per-module to avoid OOM from creating the full outer product. + + Args: + pirate_index_path: Path to pirate gradient index. + shakespeare_index_path: Path to shakespeare gradient index. + output_path: Path to save the between-class preconditioner. + + Returns: + The computed GradientProcessor. + """ + output_path = Path(output_path) + + # Check cache first + if (output_path / "preconditioners.pth").exists(): + print(f"Loading cached R_between (means) from {output_path}") + return GradientProcessor.load(output_path) + + print("Computing R_between preconditioner (class means method)...") + + pirate_path = Path(pirate_index_path) + shakespeare_path = Path(shakespeare_index_path) + + # Load structured gradients (per-module) instead of flattened + print(" Loading pirate gradients (structured)...") + pirate_grads = load_gradients(pirate_path, structured=True) + + print(" Loading shakespeare gradients (structured)...") + shakespeare_grads = load_gradients(shakespeare_path, structured=True) + + # Load a processor to get module names and metadata + pirate_proc = GradientProcessor.load(pirate_path) + + # Compute per-module rank-1 preconditioners + between_precs = {} + module_names = list(pirate_proc.preconditioners.keys()) + + print(f" Computing per-module R_between for {len(module_names)} modules...") + for name in tqdm(module_names): + # Get gradients for this module (numpy structured array access) + pirate_mod = torch.from_numpy(pirate_grads[name].copy()).float() + shakespeare_mod = torch.from_numpy(shakespeare_grads[name].copy()).float() + + # Compute means + mu_pirate = pirate_mod.mean(dim=0) + mu_shakespeare = shakespeare_mod.mean(dim=0) + + # Style direction for this module + delta = mu_pirate - mu_shakespeare + + # Rank-1 preconditioner: outer product + between_precs[name] = torch.outer(delta, delta) + + between_proc = GradientProcessor( + normalizers=pirate_proc.normalizers, + preconditioners=between_precs, + preconditioners_eigen={}, + 
projection_dim=pirate_proc.projection_dim, + projection_type=pirate_proc.projection_type, + include_bias=pirate_proc.include_bias, + ) + between_proc.save(output_path) + print(f"Saved R_between preconditioner (means) to {output_path}") + return between_proc + + +# Default to the means-based approach as it's more targeted +compute_between_preconditioner = compute_between_preconditioner_means + + +def compute_mixed_preconditioner( + pirate_path: Path | str, + shakespeare_path: Path | str, + output_path: Path | str, +) -> GradientProcessor: + """Compute R_mixed = 0.5 * R_pirate + 0.5 * R_shakespeare. + + Args: + pirate_path: Path to pirate style preconditioner. + shakespeare_path: Path to shakespeare style preconditioner. + output_path: Path to save the mixed preconditioner. + + Returns: + The computed GradientProcessor. + """ + output_path = Path(output_path) + + # Check cache first + if (output_path / "preconditioners.pth").exists(): + print(f"Loading cached mixed preconditioner from {output_path}") + return GradientProcessor.load(output_path) + + print("Computing mixed 50-50 preconditioner...") + pirate_proc = GradientProcessor.load(Path(pirate_path)) + shakespeare_proc = GradientProcessor.load(Path(shakespeare_path)) + + mixed_precs = {} + for name in pirate_proc.preconditioners: + mixed_precs[name] = ( + 0.5 * pirate_proc.preconditioners[name] + + 0.5 * shakespeare_proc.preconditioners[name] + ) + + mixed_proc = GradientProcessor( + normalizers=pirate_proc.normalizers, + preconditioners=mixed_precs, + preconditioners_eigen={}, + projection_dim=pirate_proc.projection_dim, + projection_type=pirate_proc.projection_type, + include_bias=pirate_proc.include_bias, + ) + mixed_proc.save(output_path) + print(f"Saved mixed preconditioner to {output_path}") + return mixed_proc + + +def compute_summed_loss_preconditioner( + pirate_index_path: Path | str, + shakespeare_index_path: Path | str, + output_path: Path | str, +) -> GradientProcessor: + """Compute preconditioner from 
summed loss across style contrastive pairs. + + Instead of computing gradients separately and then averaging, this approach + conceptually sums the loss across contrastive pairs before computing gradients. + For paired samples with the same underlying fact but different styles: + - g_summed = g_pirate + g_shakespeare (for same fact) + - R_summed = sum over pairs of outer(g_summed, g_summed) + + This captures the common (semantic) direction by reinforcing what's shared. + + Args: + pirate_index_path: Path to pirate gradient index. + shakespeare_index_path: Path to shakespeare gradient index. + output_path: Path to save the preconditioner. + + Returns: + The computed GradientProcessor. + """ + from datasets import load_from_disk + + output_path = Path(output_path) + + # Check cache first + if (output_path / "preconditioners.pth").exists(): + print(f"Loading cached summed loss preconditioner from {output_path}") + return GradientProcessor.load(output_path) + + print("Computing summed loss preconditioner from style contrastive pairs...") + + pirate_path = Path(pirate_index_path) + shakespeare_path = Path(shakespeare_index_path) + + # Load structured gradients + print(" Loading pirate gradients...") + pirate_grads = load_gradients(pirate_path, structured=True) + print(" Loading shakespeare gradients...") + shakespeare_grads = load_gradients(shakespeare_path, structured=True) + + # Load datasets to match facts + pirate_ds = load_from_disk( + str(pirate_path.parent / "pirate" / "dataset") + if (pirate_path.parent / "pirate" / "dataset").exists() + else "data/facts_dataset_pirate-Qwen3-8B-Base.hf" + ) + shakespeare_ds = load_from_disk( + str(shakespeare_path.parent / "shakespeare" / "dataset") + if (shakespeare_path.parent / "shakespeare" / "dataset").exists() + else "data/facts_dataset_shakespeare-Qwen3-8B-Base.hf" + ) + + if hasattr(pirate_ds, "keys"): + pirate_ds = pirate_ds["train"] + if hasattr(shakespeare_ds, "keys"): + shakespeare_ds = shakespeare_ds["train"] + + # 
Build fact -> index mapping + pirate_facts = pirate_ds["fact"] + shakespeare_facts = shakespeare_ds["fact"] + + pirate_fact_to_idx = {f: i for i, f in enumerate(pirate_facts)} + shakespeare_fact_to_idx = {f: i for i, f in enumerate(shakespeare_facts)} + + # Find common facts (contrastive pairs) and build aligned index arrays + common_facts = list( + set(pirate_fact_to_idx.keys()) & set(shakespeare_fact_to_idx.keys()) + ) + pirate_indices = [pirate_fact_to_idx[f] for f in common_facts] + shakespeare_indices = [shakespeare_fact_to_idx[f] for f in common_facts] + print(f" Found {len(common_facts)} contrastive pairs") + + # Load a processor to get metadata + pirate_proc = GradientProcessor.load(pirate_path) + module_names = list(pirate_proc.preconditioners.keys()) + + # Compute per-module preconditioners from summed gradients (batched) + summed_precs = {} + print(f" Computing per-module preconditioners for {len(module_names)} modules...") + + for name in tqdm(module_names): + pirate_mod = torch.from_numpy(pirate_grads[name].copy()).float() + shakespeare_mod = torch.from_numpy(shakespeare_grads[name].copy()).float() + + # Extract aligned pairs using fancy indexing (batched) + g_pirate_aligned = pirate_mod[pirate_indices] # [n_pairs, d] + g_shakespeare_aligned = shakespeare_mod[shakespeare_indices] # [n_pairs, d] + + # Sum gradients across contrastive pairs + g_summed = g_pirate_aligned + g_shakespeare_aligned # [n_pairs, d] + + # Compute covariance: (1/n) * G^T @ G = sum of outer products / n + R = g_summed.T @ g_summed / len(common_facts) # [d, d] + summed_precs[name] = R + + summed_proc = GradientProcessor( + normalizers=pirate_proc.normalizers, + preconditioners=summed_precs, + preconditioners_eigen={}, + projection_dim=pirate_proc.projection_dim, + projection_type=pirate_proc.projection_type, + include_bias=pirate_proc.include_bias, + ) + output_path.mkdir(parents=True, exist_ok=True) + summed_proc.save(output_path) + print(f"Saved summed loss preconditioner to 
{output_path}") + return summed_proc + + +def compute_pca_style_subspace( + pirate_index_path: Path | str, + shakespeare_index_path: Path | str, + output_path: Path | str, + top_k: int = 10, +) -> dict[str, tuple[torch.Tensor, torch.Tensor]]: + """Compute the style subspace from pairwise gradient differences using PCA. + + For each contrastive pair (same fact, different styles): + - Δg = g_pirate - g_shakespeare + Stacks all Δg into a matrix and computes PCA to find the top-k principal + components that capture the "style direction". + + Args: + pirate_index_path: Path to pirate gradient index. + shakespeare_index_path: Path to shakespeare gradient index. + output_path: Path to save the style subspace. + top_k: Number of top principal components to keep. + + Returns: + Dictionary mapping module names to (eigenvectors, eigenvalues) tuples. + eigenvectors has shape [d, k] where columns are the top-k style directions. + """ + from datasets import load_from_disk + + output_path = Path(output_path) + cache_file = output_path / f"style_subspace_k{top_k}.pth" + + if cache_file.exists(): + print(f"Loading cached style subspace from {cache_file}") + return torch.load(cache_file, weights_only=True) + + print(f"Computing style subspace via PCA (top_k={top_k})...") + + pirate_path = Path(pirate_index_path) + shakespeare_path = Path(shakespeare_index_path) + + # Load structured gradients + print(" Loading pirate gradients...") + pirate_grads = load_gradients(pirate_path, structured=True) + print(" Loading shakespeare gradients...") + shakespeare_grads = load_gradients(shakespeare_path, structured=True) + + # Load datasets to match facts + pirate_ds = load_from_disk("data/facts_dataset_pirate-Qwen3-8B-Base.hf") + shakespeare_ds = load_from_disk("data/facts_dataset_shakespeare-Qwen3-8B-Base.hf") + + if hasattr(pirate_ds, "keys"): + pirate_ds = pirate_ds["train"] + if hasattr(shakespeare_ds, "keys"): + shakespeare_ds = shakespeare_ds["train"] + + # Build fact -> index mapping + 
pirate_facts = pirate_ds["fact"] + shakespeare_facts = shakespeare_ds["fact"] + + pirate_fact_to_idx = {f: i for i, f in enumerate(pirate_facts)} + shakespeare_fact_to_idx = {f: i for i, f in enumerate(shakespeare_facts)} + + # Find common facts and build aligned index arrays + common_facts = list( + set(pirate_fact_to_idx.keys()) & set(shakespeare_fact_to_idx.keys()) + ) + pirate_indices = [pirate_fact_to_idx[f] for f in common_facts] + shakespeare_indices = [shakespeare_fact_to_idx[f] for f in common_facts] + print(f" Found {len(common_facts)} contrastive pairs") + + # Get module names from processor + pirate_proc = GradientProcessor.load(pirate_path) + module_names = list(pirate_proc.preconditioners.keys()) + + style_subspace = {} + print(f" Computing PCA for {len(module_names)} modules...") + + for name in tqdm(module_names): + pirate_mod = torch.from_numpy(pirate_grads[name].copy()).float() + shakespeare_mod = torch.from_numpy(shakespeare_grads[name].copy()).float() + + # Extract aligned pairs using fancy indexing (batched) + g_pirate_aligned = pirate_mod[pirate_indices] # [n_pairs, d] + g_shakespeare_aligned = shakespeare_mod[shakespeare_indices] # [n_pairs, d] + + # Compute gradient differences (batched) + diff_matrix = g_pirate_aligned - g_shakespeare_aligned # [n_pairs, d] + + # Center the differences (mean-subtract) + diff_centered = diff_matrix - diff_matrix.mean(dim=0, keepdim=True) + + # Compute covariance matrix: (1/n) * D^T @ D + n = diff_centered.shape[0] + cov = diff_centered.T @ diff_centered / n # [d, d] + + # Eigendecomposition (sorted ascending) + eigvals, eigvecs = torch.linalg.eigh(cov) + + # Get top-k (largest eigenvalues are at the end) + k = min(top_k, eigvals.shape[0]) + top_eigvals = eigvals[-k:].flip(0) # Descending order + top_eigvecs = eigvecs[:, -k:].flip( + 1 + ) # [d, k], columns are principal components + + style_subspace[name] = (top_eigvecs, top_eigvals) + + output_path.mkdir(parents=True, exist_ok=True) + 
torch.save(style_subspace, cache_file) + print(f"Saved style subspace to {cache_file}") + return style_subspace + + +def project_orthogonal_to_style_subspace( + grads: torch.Tensor, + style_eigenvecs: torch.Tensor, +) -> torch.Tensor: + """Project gradients onto the orthogonal complement of the style subspace. + + Given gradients g and style subspace basis V (columns are principal components), + computes: g_projected = g - V @ V^T @ g + + This removes the component of g that lies in the style subspace. + + Args: + grads: Gradient tensor of shape [n, d]. + style_eigenvecs: Eigenvectors defining style subspace, shape [d, k]. + + Returns: + Projected gradients of shape [n, d]. + """ + # V @ V^T is the projection matrix onto the style subspace + # I - V @ V^T is the projection onto the orthogonal complement + # g_proj = g - V @ (V^T @ g) + style_component = grads @ style_eigenvecs @ style_eigenvecs.T + return grads - style_component + + +def apply_pca_projection_to_eval_grads( + eval_grads: dict[str, torch.Tensor], + style_subspace: dict[str, tuple[torch.Tensor, torch.Tensor]], + device: torch.device | None = None, +) -> dict[str, torch.Tensor]: + """Apply PCA style projection to evaluation gradients. + + Projects eval gradients onto the orthogonal complement of the style subspace, + effectively removing the style direction before computing influence. + + Args: + eval_grads: Dictionary mapping module names to gradient tensors [n, d]. + style_subspace: Dictionary mapping module names to (eigenvecs, eigvals) tuples. + device: Device to use for computation. + + Returns: + Dictionary of projected gradients. 
+ """ + projected = {} + for name, grads in eval_grads.items(): + if name in style_subspace: + eigvecs, _ = style_subspace[name] + if device is not None: + grads = grads.to(device) + eigvecs = eigvecs.to(device) + projected[name] = project_orthogonal_to_style_subspace(grads, eigvecs) + else: + projected[name] = grads + return projected + + +def compute_eval_preconditioner( + eval_grads_path: Path | str, + output_path: Path | str, + reference_proc_path: Path | str | None = None, +) -> GradientProcessor: + """Compute second moment matrix from eval gradients. + + R_eval = (1/n) * G_eval^T @ G_eval + + Args: + eval_grads_path: Path to eval gradients index. + output_path: Path to save the preconditioner. + reference_proc_path: Path to a reference processor for module names (if eval has none). + + Returns: + The computed GradientProcessor. + """ + import json + + output_path = Path(output_path) + + # Check cache first + if (output_path / "preconditioners.pth").exists(): + print(f"Loading cached eval preconditioner from {output_path}") + return GradientProcessor.load(output_path) + + print("Computing eval second moment preconditioner...") + + eval_path = Path(eval_grads_path) + + # Load structured gradients + print(" Loading eval gradients...") + eval_grads = load_gradients(eval_path, structured=True) + + # Get module names from info.json + with open(eval_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # Load a reference processor to get metadata (use reference if eval doesn't have precs) + if reference_proc_path: + base_proc = GradientProcessor.load(Path(reference_proc_path)) + else: + base_proc = GradientProcessor.load(eval_path) + + # Compute per-module second moment matrices + eval_precs = {} + print(f" Computing per-module preconditioners for {len(module_names)} modules...") + + for name in tqdm(module_names): + g = torch.from_numpy(_load_gradients_as_float(eval_grads, name)) + n = g.shape[0] + # Second moment: (1/n) * G^T 
@ G + R = g.T @ g / n + eval_precs[name] = R + + eval_proc = GradientProcessor( + normalizers=base_proc.normalizers, + preconditioners=eval_precs, + preconditioners_eigen={}, + projection_dim=base_proc.projection_dim, + projection_type=base_proc.projection_type, + include_bias=base_proc.include_bias, + ) + output_path.mkdir(parents=True, exist_ok=True) + eval_proc.save(output_path) + print(f"Saved eval preconditioner to {output_path}") + return eval_proc + + +def compute_train_eval_mixed_preconditioner( + train_index_path: Path | str, + eval_grads_path: Path | str, + output_path: Path | str, + train_weight: float = 0.5, +) -> GradientProcessor: + """Compute 50:50 (or custom weighted) mixture of train and eval second moments. + + R_mixed = train_weight * R_train + (1 - train_weight) * R_eval + + Args: + train_index_path: Path to train gradients index. + eval_grads_path: Path to eval gradients index. + output_path: Path to save the preconditioner. + train_weight: Weight for train preconditioner (default 0.5). + + Returns: + The computed GradientProcessor. + """ + import json + + output_path = Path(output_path) + + # Check cache first + if (output_path / "preconditioners.pth").exists(): + print(f"Loading cached train-eval mixed preconditioner from {output_path}") + return GradientProcessor.load(output_path) + + print( + f"Computing train-eval mixed preconditioner ({train_weight:.0%} train, {1-train_weight:.0%} eval)..." 
+ ) + + train_path = Path(train_index_path) + eval_path = Path(eval_grads_path) + + # Load structured gradients + print(" Loading train gradients...") + train_grads = load_gradients(train_path, structured=True) + print(" Loading eval gradients...") + eval_grads = load_gradients(eval_path, structured=True) + + # Load a processor to get metadata and module names + base_proc = GradientProcessor.load(train_path) + + # Get module names from info.json + with open(train_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + + # Compute per-module mixed second moment matrices + mixed_precs = {} + print(f" Computing per-module preconditioners for {len(module_names)} modules...") + + for name in tqdm(module_names): + g_train = torch.from_numpy(_load_gradients_as_float(train_grads, name)) + g_eval = torch.from_numpy(_load_gradients_as_float(eval_grads, name)) + + n_train = g_train.shape[0] + n_eval = g_eval.shape[0] + + # Second moments + R_train = g_train.T @ g_train / n_train + R_eval = g_eval.T @ g_eval / n_eval + + # Weighted mixture + R_mixed = train_weight * R_train + (1 - train_weight) * R_eval + mixed_precs[name] = R_mixed + + mixed_proc = GradientProcessor( + normalizers=base_proc.normalizers, + preconditioners=mixed_precs, + preconditioners_eigen={}, + projection_dim=base_proc.projection_dim, + projection_type=base_proc.projection_type, + include_bias=base_proc.include_bias, + ) + output_path.mkdir(parents=True, exist_ok=True) + mixed_proc.save(output_path) + print(f"Saved train-eval mixed preconditioner to {output_path}") + return mixed_proc diff --git a/examples/semantic/scoring.py b/examples/semantic/scoring.py new file mode 100644 index 00000000..34da210e --- /dev/null +++ b/examples/semantic/scoring.py @@ -0,0 +1,284 @@ +"""Score computation utilities for semantic experiments.""" + +import json +import subprocess +from pathlib import Path + +import numpy as np +import torch +from tqdm import tqdm + +from bergson.data import 
load_gradients +from bergson.gradients import GradientProcessor +from bergson.utils.math import compute_damped_inverse + + +def load_scores_matrix(scores_path: Path | str) -> np.ndarray: + """Load the scores matrix from bergson score output as a dense array. + + Args: + scores_path: Path to the scores directory containing info.json and scores.bin. + + Returns: + Dense (num_items, num_scores) float32 array of scores. + """ + scores_path = Path(scores_path) + + with open(scores_path / "info.json") as f: + info = json.load(f) + + num_items = info["num_items"] + num_scores = info["num_scores"] + + # Handle both tuple format (from bergson) and list format (from JSON serialization) + dtype_spec = info["dtype"] + if ( + isinstance(dtype_spec, list) + and len(dtype_spec) > 0 + and isinstance(dtype_spec[0], list) + ): + # Convert list of lists back to list of tuples + dtype_spec = [tuple(item) for item in dtype_spec] + + scores_mmap = np.memmap( + scores_path / "scores.bin", + dtype=np.dtype(dtype_spec), + mode="r", + shape=(num_items,), + ) + + # Extract score columns into a dense matrix + scores = np.zeros((num_items, num_scores), dtype=np.float32) + for i in range(num_scores): + scores[:, i] = scores_mmap[f"score_{i}"] + + return scores + + +def compute_scores_fast( + index_path: Path | str, + output_path: Path | str, + preconditioner_path: Path | str | None = None, + unit_normalize: bool = True, + batch_size: int = 256, +) -> None: + """Compute pairwise similarities directly from precomputed gradients. + + Much faster than bergson score since it doesn't recompute gradients. + Loads gradients from index, applies preconditioning, and computes G @ G.T. + + Args: + index_path: Path to the gradient index. + output_path: Path to save scores. + preconditioner_path: Optional path to preconditioner for query gradients. + unit_normalize: Whether to unit normalize gradients before scoring. + batch_size: Batch size for score computation. 
+ """ + output_path = Path(output_path) + index_path = Path(index_path) + + if output_path.exists(): + print(f"Scores already exist at {output_path}, skipping...") + return + + output_path.mkdir(parents=True, exist_ok=True) + + # Load gradients + print("Loading gradients from index...") + grads = load_gradients(index_path, structured=True) + + # Get module names + with open(index_path / "info.json") as f: + info = json.load(f) + module_names = info["dtype"]["names"] + n_samples = info["num_grads"] + + print(f" {n_samples} samples, {len(module_names)} modules") + + # Load and apply preconditioner if specified + if preconditioner_path: + preconditioner_path = Path(preconditioner_path) + print(f"Loading preconditioner from {preconditioner_path}...") + proc = GradientProcessor.load(preconditioner_path) + + # Compute H^(-1) for each module using the shared utility + h_inv = {} + device = torch.device("cuda:0") + for name in tqdm(module_names, desc="Computing H^(-1)"): + H = proc.preconditioners[name].to(device=device) + h_inv[name] = compute_damped_inverse(H) + + # Bergson's approach (from score.py): + # 1. Query: precondition with H^(-1), then unit normalize + # 2. Index: unit normalize (no preconditioning) + # 3. 
Score: index @ query.T + print("Loading gradients...") + all_grads_raw = [] + for name in tqdm(module_names, desc="Loading gradients"): + g = torch.from_numpy(grads[name].copy()).float() + all_grads_raw.append(g) + + # Apply H^(-1) to query gradients first (before normalization) + print("Applying H^(-1) to query gradients...") + all_grads_query = [] + for name, g in zip(module_names, all_grads_raw): + g_precond = (g.to(device) @ h_inv[name]).cpu() + all_grads_query.append(g_precond) + all_grads_query = torch.cat(all_grads_query, dim=1) + all_grads_raw = torch.cat(all_grads_raw, dim=1) + print(f"Gradient matrix shape: {all_grads_raw.shape}") + + # Unit normalize after preconditioning (for query) and raw (for index) + if unit_normalize: + print("Unit normalizing gradients...") + # Normalize preconditioned query + query_norms = all_grads_query.norm(dim=1, keepdim=True) + all_grads_query = all_grads_query / (query_norms + 1e-8) + # Normalize raw index + index_norms = all_grads_raw.norm(dim=1, keepdim=True) + all_grads_index = all_grads_raw / (index_norms + 1e-8) + else: + all_grads_index = all_grads_raw + + # Score: index (normalized) @ query (preconditioned then normalized).T + print("Computing pairwise similarities...") + all_grads_index = all_grads_index.cuda() + all_grads_query = all_grads_query.cuda() + + scores = torch.zeros(n_samples, n_samples, dtype=torch.float32) + for i in tqdm(range(0, n_samples, batch_size), desc="Scoring"): + batch = all_grads_index[i : i + batch_size] + scores[i : i + batch_size] = (batch @ all_grads_query.T).cpu() + else: + # No preconditioning - just concatenate modules + print("Concatenating gradients (no preconditioning)...") + all_grads = torch.from_numpy( + load_gradients(index_path, structured=False).copy() + ).float() + + print(f"Gradient matrix shape: {all_grads.shape}") + + # Unit normalize if requested + if unit_normalize: + print("Unit normalizing gradients...") + norms = all_grads.norm(dim=1, keepdim=True) + all_grads = 
all_grads / (norms + 1e-8) + + # Compute pairwise similarities in batches (G @ G.T) + print("Computing pairwise similarities...") + all_grads = all_grads.cuda() + + scores = torch.zeros(n_samples, n_samples, dtype=torch.float32) + for i in tqdm(range(0, n_samples, batch_size), desc="Scoring"): + batch = all_grads[i : i + batch_size] + scores[i : i + batch_size] = (batch @ all_grads.T).cpu() + + # Save in bergson score format + print(f"Saving scores to {output_path}...") + + # Create structured dtype for scores + score_dtype_list = [(f"score_{i}", " None: + """Run bergson score to compute pairwise similarities. + + NOTE: This recomputes gradients, which is slow. For index-vs-index + scoring, use compute_scores_fast() instead. + + Args: + index_path: Path to the gradient index. + output_path: Path to save scores. + query_preconditioner_path: Optional path to query preconditioner. + index_preconditioner_path: Optional path to index preconditioner. + unit_normalize: Whether to unit normalize gradients. 
+ """ + output_path = Path(output_path) + index_path = Path(index_path) + + if output_path.exists(): + print(f"Scores already exist at {output_path}, skipping...") + return + + # Load index config to get model and dataset info + with open(index_path / "index_config.json") as f: + index_cfg = json.load(f) + + # Get dataset and column info from config + data_cfg = index_cfg.get("data", {}) + dataset_path = data_cfg.get("dataset", str(index_path / "data.hf")) + prompt_column = data_cfg.get("prompt_column", "text") + completion_column = data_cfg.get("completion_column", "") + + cmd = [ + "bergson", + "score", + str(output_path), + "--model", + index_cfg["model"], + "--dataset", + dataset_path, + "--query_path", + str(index_path), + "--score", + "individual", + "--projection_dim", + str(index_cfg.get("projection_dim", 0)), + "--fsdp", + "--prompt_column", + prompt_column, + ] + + if completion_column: + cmd.extend(["--completion_column", completion_column]) + + if unit_normalize: + cmd.append("--unit_normalize") + + if query_preconditioner_path: + cmd.extend(["--query_preconditioner_path", query_preconditioner_path]) + + if index_preconditioner_path: + cmd.extend(["--index_preconditioner_path", index_preconditioner_path]) + + print("Running:", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + raise RuntimeError(f"bergson score failed with return code {result.returncode}") + print(result.stdout) diff --git a/examples/semantics_experiment.py b/examples/semantics_experiment.py new file mode 100644 index 00000000..b4b30db1 --- /dev/null +++ b/examples/semantics_experiment.py @@ -0,0 +1,68 @@ +import subprocess +from pathlib import Path + +import torch +from datasets import load_dataset + +from bergson import load_gradient_dataset + +dataset = load_dataset("HuggingFaceH4/MATH-500", split="test") + +# Build Bergson index +run_path = 
Path("runs/math-500/gemma")
+cmd = [
+    "bergson",
+    "build",
+    str(run_path),
+    "--model",
+    "google/gemma-3-4b-it",
+    "--dataset",
+    "HuggingFaceH4/MATH-500",
+    "--drop_columns",
+    "False",
+    "--split",
+    "test",
+    "--prompt_column",
+    "problem",
+    "--completion_column",
+    "answer",
+]
+print(" ".join(cmd))
+
+if not run_path.exists():
+    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
+    print(result.stdout)
+    print(result.stderr)
+
+# Check whether items with the same subject value have a greater cosine similarity score
+# Than items from dissimilar subjects
+
+gradient_ds = load_gradient_dataset(run_path, structured=False)
+
+subjects = gradient_ds["subject"]
+
+# Compute cosine similarity between all items' gradients
+gradients = torch.tensor(gradient_ds["gradients"], device="cuda")
+gradients /= gradients.norm(dim=1, keepdim=True)
+similarities = gradients @ gradients.T
+
+
+# Check whether items with the same subject value have a greater cosine similarity score
+# Than items from dissimilar subjects
+intra_subject_similarities = []
+inter_subject_similarities = []
+
+for i in range(len(gradients)):
+    for j in range(i + 1, len(gradients)):
+        if subjects[i] == subjects[j]:
+            intra_subject_similarities.append(similarities[i, j])
+        else:
+            inter_subject_similarities.append(similarities[i, j])
+
+
+mean_intra_subject_similarity = torch.mean(torch.tensor(intra_subject_similarities))
+mean_inter_subject_similarity = torch.mean(torch.tensor(inter_subject_similarities))
+print(f"Intra-subject similarity mean: {mean_intra_subject_similarity}")
+print(f"Inter-subject similarity mean: {mean_inter_subject_similarity}")
+
+print(f"Intra/inter similarity gap: {mean_intra_subject_similarity - mean_inter_subject_similarity}")
diff --git a/examples/slurm/data_parallel_score.sh b/examples/slurm/data_parallel_score.sh
index 4254d804..9b8a47f8 100644
--- a/examples/slurm/data_parallel_score.sh
+++ b/examples/slurm/data_parallel_score.sh
@@ -16,7 +16,7 @@ hf auth login --token
 NUM_NODES=64
 RUN_NAME="bergson_score"
-TOTAL_EXAMPLES=$(cat 
dataset_size.txt)
+TOTAL_EXAMPLES=100000000
 EXAMPLES_PER_NODE=$((TOTAL_EXAMPLES / NUM_NODES))
 
 # Export variables for the worker script
diff --git a/examples/train_lora.py b/examples/train_lora.py
new file mode 100644
index 00000000..66689240
--- /dev/null
+++ b/examples/train_lora.py
@@ -0,0 +1,311 @@
+import os
+from typing import List, Literal, Optional, Union
+
+import backoff
+import torch
+import torch.distributed as dist
+from datasets import Dataset
+from peft import LoraConfig, prepare_model_for_kbit_training
+from pydantic import BaseModel, Field, field_validator
+from torch.utils.data import SequentialSampler
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from trl import SFTConfig, SFTTrainer
+
+from bergson.config import DataConfig, IndexConfig
+from bergson.utils.worker_utils import setup_data_pipeline
+
+
+class TrainingConfig(BaseModel):
+    class Config:
+        extra = "forbid"  # Prevent extra fields not defined in the model
+
+    # Required model and data paths
+    model: str = Field(..., description="Hugging Face model ID")
+    dataset: str = Field(..., description="Dataset")
+    split: str = Field(..., description="Split")
+
+    prompt_column: str = Field("prompt", description="Prompt column")
+    completion_column: str = Field("completion", description="Completion column")
+
+    # Training type configuration
+    loss: Literal["dpo", "orpo", "sft"] = Field(
+        ..., description="Loss function / training type"
+    )
+
+    # Output model
+    finetuned_model_id: Optional[str] = Field(
+        None, description="File ID of the finetuned model"
+    )
+
+    # Model configuration
+    max_seq_length: int = Field(
+        2048, description="Maximum sequence length for training"
+    )
+    load_in_4bit: bool = Field(
+        False, description="Whether to load model in 4-bit quantization"
+    )
+
+    # PEFT configuration
+    is_peft: bool = Field(True, description="Whether to use PEFT for training")
+    target_modules: Optional[List[str]] = Field(
+        default=[
+            "q_proj",
+            "k_proj",
+            
"v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + description="Target modules for LoRA", + ) + lora_bias: Literal["all", "none"] = Field( + "none", description="Value for FastLanguageModel.get_peft_model(bias=?)" + ) + + # LoRA specific arguments + r: int = Field(16, description="LoRA attention dimension") + lora_alpha: int = Field(16, description="LoRA alpha parameter") + lora_dropout: float = Field(0.0, description="LoRA dropout rate") + use_rslora: bool = Field(True, description="Whether to use RSLoRA") + merge_before_push: bool = Field( + True, + # description="Whether to merge model before pushing to Hub. Only merged models + # can be used as parent models for further finetunes. Only supported for + # bf16 models.", + ) + push_to_private: bool = Field(True, description="Whether to push to private Hub") + + # Training hyperparameters + epochs: int = Field(1, description="Number of training epochs") + max_steps: int = Field(-1, description="Maximum number of training steps") + per_device_train_batch_size: int = Field( + 2, description="Training batch size per device" + ) + gradient_accumulation_steps: int = Field( + 8, description="Number of gradient accumulation steps" + ) + warmup_steps: int = Field(5, description="Number of warmup steps") + learning_rate: Union[float, str] = Field( + 1e-4, description="Learning rate or string expression" + ) + logging_steps: int = Field(1, description="Number of steps between logging") + optim: str = Field("adamw_8bit", description="Optimizer to use for training") + weight_decay: float = Field(0.01, description="Weight decay rate") + lr_scheduler_type: str = Field("linear", description="Learning rate scheduler type") + seed: Optional[int] = Field(None, description="Random seed for reproducibility") + save_steps: int = Field(5000, description="Save checkpoint every X steps") + output_dir: str = Field( + "./tmp", description="Output directory for training checkpoints" + ) + + 
@field_validator("finetuned_model_id") + def validate_finetuned_model_id(cls, v): + # if v and model_exists(v): + # raise ValueError(f"Model {v} already exists") + if len(v.split("/")) != 2: + raise ValueError("Model ID must be in the format 'user/model'") + org, model = v.split("/") + if org in ["datasets", "models", "unsloth", "None"]: + raise ValueError( + f"You have set org={org}, but it must be an org you have access to" + ) + return v + + @field_validator("learning_rate", mode="before") + def validate_learning_rate(cls, v): + if isinstance(v, float) and v <= 0: + raise ValueError("Learning rate must be positive") + return v + + @field_validator("lora_dropout") + def validate_dropout(cls, v): + if not 0 <= v <= 1: + raise ValueError("Dropout rate must be between 0 and 1") + return v + + @field_validator("lr_scheduler_type") + def validate_scheduler(cls, v): + allowed_schedulers = [ + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + "constant", + "constant_with_warmup", + ] + if v not in allowed_schedulers: + raise ValueError(f"Scheduler must be one of {allowed_schedulers}") + return v + + +# def process(df, prompt_column: str = "prompt", completion_column: str = "completion"): +# def format_chat_data(example): +# old_example = example +# example["prompt"] = [{"role": "user", "content": old_example[prompt_column]}] +# example["completion"] = [ +# {"role": "assistant", "content": old_example[completion_column]} +# ] +# return example + +# df = df.map(format_chat_data) +# return df + + +class NoShuffleSFTTrainer(SFTTrainer): + def _get_train_sampler(self, dataset): # <-- Add 'dataset' parameter + sampler = SequentialSampler(dataset) + + return sampler + + +def train(training_cfg: TrainingConfig, dataset: Dataset): + """Prepare lora model, call training function, and push to hub""" + + if rank := os.environ.get("LOCAL_RANK"): + rank = int(rank) + dist.init_process_group("nccl", device_id=torch.device(f"cuda:{rank}")) + else: + rank = 0 + + 
print("Creating new LoRA adapter") + target_modules = training_cfg.target_modules + model = AutoModelForCausalLM.from_pretrained( + training_cfg.model, + device_map={"": f"cuda:{rank}"}, + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + ), + ) + tokenizer = AutoTokenizer.from_pretrained( + training_cfg.model, token=os.environ.get("HF_TOKEN"), max_length=2048 + ) + # Prepare for k-bit training + model = prepare_model_for_kbit_training(model) + + # 3. Define LoRA config + peft_config = LoraConfig( + r=training_cfg.r, + lora_alpha=training_cfg.lora_alpha, + target_modules=target_modules, + lora_dropout=training_cfg.lora_dropout, + use_rslora=training_cfg.use_rslora, + bias=training_cfg.lora_bias, + task_type="CAUSAL_LM", + ) + + # dataset = process( + # dataset, + # prompt_column=training_cfg.prompt_column, + # completion_column=training_cfg.completion_column, + # ) + if training_cfg.seed is not None: + dataset = dataset.shuffle(seed=training_cfg.seed) + + trainer = NoShuffleSFTTrainer( + model=model, + train_dataset=dataset, + args=SFTConfig( + completion_only_loss=True, + ddp_find_unused_parameters=False, + fp16=True, + gradient_accumulation_steps=training_cfg.gradient_accumulation_steps, + learning_rate=training_cfg.learning_rate, + logging_steps=1, + lr_scheduler_type=training_cfg.lr_scheduler_type, + max_length=training_cfg.max_seq_length, + max_steps=training_cfg.max_steps, + num_train_epochs=training_cfg.epochs, + label_names=["labels"], + optim=training_cfg.optim, + output_dir=training_cfg.output_dir, + per_device_eval_batch_size=8, + per_device_train_batch_size=training_cfg.per_device_train_batch_size, + report_to=None, + save_steps=training_cfg.save_steps, + warmup_steps=training_cfg.warmup_steps, + weight_decay=training_cfg.weight_decay, + ), + peft_config=peft_config, + callbacks=[], + ) + trainer.train() + + if rank == 0: + if training_cfg.finetuned_model_id is not None: + push_model(training_cfg, training_cfg.finetuned_model_id, model, 
tokenizer) + + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +@backoff.on_exception(backoff.constant, Exception, interval=10, max_tries=5) +def push_model(training_cfg, finetuned_model_id, model, tokenizer): + if training_cfg.merge_before_push: + model.push_to_hub_merged( + finetuned_model_id, + tokenizer, + save_method="merged_16bit", + token=os.environ["HF_TOKEN"], + private=training_cfg.push_to_private, + ) + else: + model.push_to_hub( + finetuned_model_id, + token=os.environ["HF_TOKEN"], + private=training_cfg.push_to_private, + ) + tokenizer.push_to_hub( + finetuned_model_id, + token=os.environ["HF_TOKEN"], + private=training_cfg.push_to_private, + ) + + +def main(): + from argparse import ArgumentParser + + parser = ArgumentParser() + # model_name = "Qwen/Qwen2.5-7B" + parser.add_argument("--finetuned_model_path", type=str, default="finetuned-model") + parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-4B") + parser.add_argument("--dataset_name", type=str, default="HuggingFaceH4/MATH-500") + parser.add_argument("--split", type=str, default="test") + parser.add_argument("--prompt_column", type=str, default="prompt") + parser.add_argument("--completion_column", type=str, default="completion") + parser.add_argument( + "--no_push_to_private", action="store_false", dest="push_to_private" + ) + + args = parser.parse_args() + + training_config = TrainingConfig( # type: ignore + finetuned_model_id=args.finetuned_model_path, # type: ignore + model=args.model_name, # type: ignore + dataset=args.dataset_name, # type: ignore + split=args.split, # type: ignore + loss="sft", # type: ignore + prompt_column=args.prompt_column, # type: ignore + completion_column=args.completion_column, # type: ignore + merge_before_push=False, + push_to_private=args.push_to_private, + ) # type: ignore + + dataset = setup_data_pipeline( + IndexConfig( + run_path=f"runs/{args.finetuned_model_path}", + model=args.model_name, + data=DataConfig( + 
dataset=args.dataset_name, + split=args.split, + prompt_column=args.prompt_column, + completion_column=args.completion_column, + ), + ) + ) + train(training_config, dataset) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 33fe0f8e..a34608e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ benchmarks = [ "dattri", ] example = [ + "backoff>=2.2.1", + "bitsandbytes>=0.49.0", + "pydantic>=2.12.5", "trl", ] faiss = [ diff --git a/skills/asymmetric-style.md b/skills/asymmetric-style.md new file mode 100644 index 00000000..270bc5d3 --- /dev/null +++ b/skills/asymmetric-style.md @@ -0,0 +1,230 @@ +# Asymmetric Style Suppression Experiment + +**Goal**: Test whether gradient-based data attribution can find semantically matching training examples when the query is in a different style than the training data. + +This simulates a realistic scenario: your training data is mostly in one style (e.g., 95% formal/shakespeare), but users query in a different style (e.g., casual/pirate). Without intervention, gradient similarity is dominated by style—queries match training examples with similar style rather than similar content. This experiment evaluates strategies (preconditioners, PCA, gradient summing) to suppress style and recover semantic matching. + +## Usage + +``` +/asymmetric-style [options] +``` + +Options: +- `--base-path PATH` - Output directory (default: runs/asymmetric_style) +- `--recompute` - Clear cached results and recompute from scratch +- `--inner-product` - Use raw inner product instead of cosine similarity +- `--sweep-pca` - Sweep PCA k values with preconditioner combinations +- `--rewrite-ablation` - Run rewrite ablation (summed rewrites vs summed eval) +- `--summary` - Just print summary of existing results + +## What this does + +1. 
Creates asymmetric train/eval split: + - Train: 95% shakespeare (dominant), 5% pirate (minority) + - Eval: pirate style queries for facts only in shakespeare style in train +2. Tests whether gradient-based attribution can find semantic matches despite style mismatch +3. Compares strategies: baseline, preconditioners (R_between, H_eval, H_train), PCA projection, summed gradients + +## Strategies Tested + +### Baseline +- **no_precond**: Raw cosine similarity between query and training gradients. Expected to fail because style dominates the gradient representation. + +### Preconditioners +Transform gradients by `g' = g @ H^(-1)` before computing similarity, downweighting certain directions. + +- **R_between**: Computed from the difference between style means on **training data**: `delta = mean(shakespeare_train) - mean(pirate_train)`, then `R = delta @ delta.T`. This is a rank-1 matrix that captures the "style direction". + + *Dataset*: Training set (95% shakespeare, 5% pirate). The shakespeare mean is computed over ~950 samples, pirate mean over ~50 samples. + + Hypothesis: inverting this downweights the style axis, exposing semantic signal. + +- **H_eval**: Second moment of eval gradients: `H = (1/n) * G_eval.T @ G_eval`. Hypothesis: directions that vary a lot in the eval set (which is all one style) might be style-related, so downweighting high-variance eval directions could help. + +- **H_train**: Second moment of training gradients: `H = (1/n) * G_train.T @ G_train`. This has theoretical grounding from influence functions: `g_eval @ H^{-1} @ g_train.T` approximates the change in eval loss from upweighting a training point (second-order Taylor expansion). So H_train is the "correct" similarity metric for influence-based attribution. + +### Dimensionality Reduction +- **PCA projection**: Compute pairwise differences between corresponding shakespeare/pirate gradients (same underlying fact, different style), then PCA on those difference vectors. 
Project out the top-k components of this "style difference" subspace. + + Hypothesis: the difference `g_shakespeare(fact) - g_pirate(fact)` isolates pure style variation (content is held constant). The top PCs of these differences capture the dominant style directions. Projecting them out should remove style signal while preserving semantic content. + +### Gradient Averaging +- **summed_eval**: For each query, compute gradients in both styles (pirate + shakespeare), then sum them. Hypothesis: style-specific components cancel out, leaving semantic signal. This requires generating the query in multiple styles. + +- **summed_rewrites**: Sum gradients from two non-training styles (e.g., shakespeare + pirate rewrites of the same fact, when training only has formal). Tests whether style cancellation is general or requires matching training distribution. + +### Controls +- **majority_no_precond**: Query in the majority (shakespeare) style—no style mismatch. This is the upper bound showing what's achievable when styles match. 
+ +## Instructions + +### Run full experiment (using HuggingFace data) + +The easiest way to run the experiment is using pre-generated data from HuggingFace: + +```python +from examples.semantic.asymmetric import run_asymmetric_experiment, AsymmetricConfig + +# Use HF dataset - no local generation needed +config = AsymmetricConfig( + hf_dataset="EleutherAI/bergson-asymmetric-style", +) + +results = run_asymmetric_experiment( + config=config, + base_path="runs/asymmetric_style", + # analysis_model defaults to EleutherAI/bergson-asymmetric-style-qwen3-8b-lora +) +``` + +### Run full experiment (generate locally) + +To generate fresh data locally (requires Qwen model for rewording): + +```python +from examples.semantic.asymmetric import run_asymmetric_experiment, AsymmetricConfig + +config = AsymmetricConfig( + dominant_style="shakespeare", + minority_style="pirate", + dominant_ratio=0.95, +) + +results = run_asymmetric_experiment( + config=config, + base_path="runs/asymmetric_style", +) +``` + +### Run PCA k-value sweep + +```python +from examples.semantic.asymmetric import sweep_pca_k + +results = sweep_pca_k( + base_path="runs/asymmetric_style", + k_values=[1, 5, 10, 20, 50, 100], + preconditioners=[None, "index"], +) +``` + +### Run rewrite ablation + +Tests whether summing two non-training styles helps (it doesn't): + +```python +from examples.semantic.asymmetric import run_rewrite_ablation_experiment + +results = run_rewrite_ablation_experiment(base_path="runs/asymmetric_style") +``` + +### Run inner product comparison + +Compare cosine similarity vs raw inner product: + +```python +from examples.semantic.asymmetric import run_inner_product_comparison + +results = run_inner_product_comparison(base_path="runs/asymmetric_style") +``` + +### Print existing results summary + +```python +import json +import numpy as np +from pathlib import Path +from datasets import load_from_disk + +base_path = Path("runs/asymmetric_style") +with open(base_path / 
"experiment_results.json") as f: + results = json.load(f) + +sorted_results = sorted(results.items(), key=lambda x: -x[1]["top1_semantic"]) +print(f"{'Strategy':<35} {'Top-1 Sem':<12} {'Top-1 Leak':<12} {'Exact':<10}") +print("-" * 70) +for name, m in sorted_results: + print(f"{name:<35} {m['top1_semantic']:<12.2%} {m['top1_leak']:<12.2%} {m['exact']:<10.2%}") +``` + +## Cached Data + +The experiment caches intermediate results to avoid recomputation: + +``` +runs/asymmetric_style/ +├── data/ +│ ├── train.hf # Training set (95% shakespeare, 5% pirate) +│ ├── eval.hf # Eval set (pirate style) +│ ├── eval_majority.hf # Eval in majority style (control) +│ ├── eval_summed.hf # Eval with summed gradients +│ └── rewrites/ # Additional style rewrites for ablations +├── index/ # Training gradients +├── eval_grads/ # Eval gradients (minority style) +├── eval_grads_majority/ # Eval gradients (majority style) +├── preconditioners/ # Various preconditioner matrices +├── scores_*/ # Score matrices for each strategy +└── experiment_results.json # Cached metrics summary +``` + +**What each cache level means:** +- `data/` - Dataset creation and Qwen rewording (~10-20 min) +- `index/` - bergson build for training gradients (~2 min) +- `eval_grads*/` - bergson build for eval gradients (~1 min each) +- `preconditioners/` - Preconditioner computation (~30 sec) +- `scores_*/` - Score computation (~10 sec each) +- `experiment_results.json` - Metrics computed from scores + +If the user specifies `--recompute`, first delete cached data: +```bash +rm -rf runs/asymmetric_style/index runs/asymmetric_style/eval_grads* runs/asymmetric_style/scores_* runs/asymmetric_style/preconditioners +``` + +To recompute everything including data: +```bash +rm -rf runs/asymmetric_style/ +``` + +## Key Metrics + +- **Top-1 Semantic Accuracy**: Top match has same underlying fact (higher is better) +- **Top-1 Style Leakage**: Top match is minority style (lower is better - means not style matching) +- **Exact 
Match**: Same fact AND dominant style (higher is better) + +## Datasets & Models + +The datasets and fine-tuned model for this experiment are available on Hugging Face: + +- **Dataset**: [EleutherAI/bergson-asymmetric-style](https://huggingface.co/datasets/EleutherAI/bergson-asymmetric-style) + - `train`: 13,500 samples (95% shakespeare, 5% pirate) + - `eval`: 4,500 samples (pirate style queries) + - `eval_majority_style`: 4,500 samples (shakespeare style control) + - `eval_original_style`: 4,500 samples (unstyled) + - `eval_pirate_style`: 4,500 samples (pirate style variant) + +- **Model**: [EleutherAI/bergson-asymmetric-style-qwen3-8b-lora](https://huggingface.co/EleutherAI/bergson-asymmetric-style-qwen3-8b-lora) + - LoRA adapter for Qwen/Qwen3-8B-Base + - Used as the `analysis_model` for gradient collection + +## Key Findings + +| Strategy | Top-1 Semantic | Notes | +|----------|---------------|-------| +| majority_no_precond | 100.00% | Control: no style mismatch | +| summed_eval | 92.71% | Sum minority + majority style eval grads | +| summed_rewrites | 0.87% | Sum shakespeare + pirate (both non-training) | +| no_precond (baseline) | 0.87% | Pure style matching dominates | +| preconditioners | ~1-1.4% | Marginal improvement | + +**Main insight**: summed_eval works because one component matches training distribution, not because of general style cancellation. + +## Similarity Metric Comparison + +With cosine similarity (my experiments): +- summed_eval: 92.71% + +With raw inner product (bergson default): +- summed_eval: 76.91% + +Cosine similarity helps by removing gradient magnitude as a confounding factor. 
diff --git a/skills/attribute-preservation.md b/skills/attribute-preservation.md new file mode 100644 index 00000000..4877cba4 --- /dev/null +++ b/skills/attribute-preservation.md @@ -0,0 +1,168 @@ +# Attribute Preservation Experiment + +**Goal**: Test whether style suppression preconditioners can remove stylistic signal from gradient embeddings while preserving the ability to match on semantic attributes (occupation, employer type, etc.). + +The core challenge is that gradient-based data attribution tends to match based on surface-level features like writing style rather than underlying content. This experiment creates synthetic data with correlated attributes (e.g., scientists work at research labs, business people work at banks) and tests whether we can surgically remove style signal without damaging these attribute-based matching capabilities. + +## Usage + +``` +/attribute-preservation [options] +``` + +Options: +- `--base-path PATH` - Output directory (default: runs/attribute_preservation) +- `--no-h-eval` - Skip H_eval preconditioner comparison +- `--no-majority` - Skip majority style control +- `--recompute` - Clear cached results and recompute from scratch + +## What this does + +1. Creates a synthetic dataset with occupational clusters (scientists, business, creative) +2. Each cluster has correlated attributes (employers, universities, degrees, titles) +3. Styles are assigned by occupation (scientists→shakespeare, business→pirate, creative→shakespeare) +4. Eval set: scientists in "wrong" style (pirate) to test style suppression +5. Compares preconditioner strategies: none, R_between, H_eval +6. Majority control: scientists in matching style (shakespeare) as upper bound + +## Strategies Tested + +### Baseline +- **no_precond**: Raw cosine similarity between query and training gradients. Expected to mostly match based on style (pirate queries → pirate training examples) rather than occupation. 
+ +### Preconditioners +Transform gradients by `g' = g @ H^(-1)` before computing similarity. + +- **R_between**: Computed from training data style means: `delta = mean(shakespeare_grads) - mean(pirate_grads)`, then `R = delta @ delta.T`. This rank-1 matrix captures the "style direction" in gradient space. + + *Dataset*: Training set with style-occupation mapping: + - shakespeare mean = scientists (400) + creatives (400) = 800 samples + - pirate mean = business (400) = 400 samples + + *Caveat*: Because shakespeare mixes two occupations, the "style direction" is actually `(scientists + creatives)/2 - business`, which conflates style with occupation. This is meant to represent a situation where you can't rewrite scientist data in different styles, and have to work with different styles that already exist in the data. If you can rewrite, majority_no_precond may be the best option. + + Hypothesis: preconditioning with R^(-1) shrinks the style axis, allowing occupation signal to dominate. + +- **H_eval**: Second moment of eval gradients: `H = (1/n) * G_eval.T @ G_eval`. Hypothesis: the eval set is all scientists in pirate style, so directions with high variance in eval might capture style-independent scientist features. Downweighting these could paradoxically help by normalizing the representation. + +### Controls +- **majority_no_precond**: Scientists queried in shakespeare (their training style)—no style mismatch. This shows the upper bound: how well can we match occupation when style isn't a confounder? The gap between this and preconditioned results shows how much room for improvement remains. + +## Why This Experiment Matters + +Previous experiments showed preconditioners have minimal effect on fact-level retrieval. But maybe they work for coarser attribute matching? This tests whether style suppression preserves the ability to match "scientists to scientists" even if it can't match "Alice's employer fact to Alice's employer fact". 
+ +## Instructions + +### Run full experiment (using HuggingFace data) + +The easiest way to run the experiment is using pre-generated data from HuggingFace: + +```python +from examples.semantic.attribute_preservation import ( + run_attribute_preservation_experiment, + AttributePreservationConfig, +) + +# Use HF dataset - no local generation needed +config = AttributePreservationConfig( + hf_dataset="EleutherAI/bergson-attribute-preservation", +) + +results = run_attribute_preservation_experiment( + config=config, + base_path='runs/attribute_preservation', + # analysis_model defaults to EleutherAI/bergson-asymmetric-style-qwen3-8b-lora + include_h_eval=True, + include_majority_control=True +) +``` + +### Run full experiment (generate locally) + +To generate fresh data locally (requires Qwen model for rewording): + +```python +from examples.semantic.attribute_preservation import run_attribute_preservation_experiment + +results = run_attribute_preservation_experiment( + base_path='runs/attribute_preservation', + reword_model='Qwen/Qwen3-8B-Base', + include_h_eval=True, + include_majority_control=True +) +``` + +## Cached Data + +The experiment caches intermediate results to avoid recomputation: + +``` +runs/attribute_preservation/ +├── data/ +│ ├── base_train.hf # Raw facts (no style) +│ ├── base_eval.hf # Raw eval facts +│ ├── train_shakespeare.hf # Reworded train (shakespeare) +│ ├── train_pirate.hf # Reworded train (pirate) +│ ├── train.hf # Combined styled training set +│ ├── eval_pirate.hf # Eval in minority style +│ ├── eval.hf # Final eval set +│ └── eval_majority.hf # Eval in majority style (control) +├── index/ # Training gradients (bergson build) +├── eval_grads/ # Eval gradients (minority style) +├── eval_grads_majority/ # Eval gradients (majority style) +├── r_between/ # R_between preconditioner +├── h_eval/ # H_eval preconditioner +├── scores_no_precond/ # Score matrix (no preconditioner) +├── scores_r_between/ # Score matrix (R_between) +├── 
scores_h_eval/ # Score matrix (H_eval) +└── scores_majority_no_precond/ # Score matrix (majority control) +``` + +**What each cache level means:** +- `data/` - Regenerating requires re-running Qwen rewording (~10 min) +- `index/` - Regenerating requires re-running bergson build (~2 min) +- `eval_grads*/` - Regenerating requires re-running bergson build (~1 min each) +- `r_between/`, `h_eval/` - Preconditioner computation (~30 sec each) +- `scores_*/` - Score computation (~10 sec each) + +If the user specifies `--recompute`, first delete cached data: +```bash +rm -rf runs/attribute_preservation/index runs/attribute_preservation/eval_grads* runs/attribute_preservation/scores_* runs/attribute_preservation/r_between runs/attribute_preservation/h_eval +``` + +To recompute everything including data rewording: +```bash +rm -rf runs/attribute_preservation/ +``` + +## Key Metrics + +- **Occupation Accuracy**: How often top-1 match has same occupation cluster (higher is better) +- **Style-Only Match**: Style matches but occupation doesn't (lower is better) +- **Trade-off**: Occ Acc - Style Only (higher is better) + +## Datasets & Models + +The dataset and fine-tuned model for this experiment are available on Hugging Face: + +- **Dataset**: [EleutherAI/bergson-attribute-preservation](https://huggingface.co/datasets/EleutherAI/bergson-attribute-preservation) + - `train`: 1,200 samples (scientists + business + creative occupations with correlated styles) + - `eval`: 400 samples (scientists in pirate style - "wrong" style) + - `eval_majority`: 400 samples (scientists in shakespeare style - control) + +- **Model**: [EleutherAI/bergson-asymmetric-style-qwen3-8b-lora](https://huggingface.co/EleutherAI/bergson-asymmetric-style-qwen3-8b-lora) + - LoRA adapter for Qwen/Qwen3-8B-Base + - Used as the `analysis_model` for gradient collection + +## Expected Output + +A summary table comparing strategies: +``` +Strategy Fact Acc Occ Acc Style Only Trade-off 
+--------------------------------------------------------------------------- +no_precond 0.25% 7.75% 89.75% -82.00% +r_between 0.50% 12.25% 84.00% -71.75% +h_eval 3.25% 16.25% 80.50% -64.25% +majority_no_precond 6.75% 76.00% 23.25% +52.75% +``` diff --git a/skills/gradient-debug.md b/skills/gradient-debug.md new file mode 100644 index 00000000..3f61501d --- /dev/null +++ b/skills/gradient-debug.md @@ -0,0 +1,92 @@ +# Gradient Debug + +Inspect and debug gradient data from bergson experiments. Use this when experiments produce unexpected results (e.g., all scores identical, 0% accuracy) to check whether gradients were computed correctly. + +## Usage + +``` +/gradient-debug [path] [options] +``` + +Arguments: +- `path` - Path to gradient store (index or eval_grads directory) + +Options: +- `--check-zeros` - Check for zero/corrupted gradients +- `--similarity` - Compute pairwise similarity between samples +- `--compare PATH` - Compare gradients between two stores + +## Instructions + +### Check gradient health: + +```python +import json +import numpy as np + +path = 'runs/attribute_preservation/index' # or user-specified path + +with open(f'{path}/info.json') as f: + info = json.load(f) + +grads = np.fromfile(f'{path}/gradients.bin', dtype=np.float16) +grads = grads.reshape(info['num_grads'], -1).astype(np.float32) + +print(f'Shape: {grads.shape}') +print(f'Min: {grads.min():.6f}') +print(f'Max: {grads.max():.6f}') +print(f'Mean: {grads.mean():.6f}') +print(f'Std: {grads.std():.6f}') +print(f'Non-zero entries: {np.count_nonzero(grads):,} / {grads.size:,}') + +# Check for corruption +if grads.std() == 0: + print('WARNING: All gradients are identical (likely corrupted)') +if np.count_nonzero(grads) == 0: + print('ERROR: All gradients are zero!') +``` + +### Check pairwise similarity: + +```python +# Sample a few gradients and compute cosine similarity +for i in range(min(5, len(grads))): + for j in range(i+1, min(5, len(grads))): + cos_sim = np.dot(grads[i], grads[j]) / 
( + np.linalg.norm(grads[i]) * np.linalg.norm(grads[j]) + 1e-8 + ) + print(f'Cosine sim [{i}] vs [{j}]: {cos_sim:.4f}') +``` + +### Check score matrix: + +```python +scores_path = 'runs/attribute_preservation/scores_no_precond/scores.npy' +scores = np.load(scores_path) + +print(f'Shape: {scores.shape}') +print(f'Min: {scores.min():.4f}') +print(f'Max: {scores.max():.4f}') +print(f'Unique values: {len(np.unique(scores))}') + +if len(np.unique(scores)) == 1: + print('ERROR: All scores identical (gradients likely corrupted)') +if np.isnan(scores).any(): + print(f'WARNING: {np.isnan(scores).sum()} NaN values in scores') +``` + +## Common Issues + +1. **All zeros**: Gradient computation failed silently. Re-run bergson build. +2. **All identical**: Normalization issue or corrupted storage. +3. **NaN scores**: Division by zero in preconditioner inverse. Check preconditioner matrix. +4. **Same top-k for all queries**: Eval gradients are identical (corrupted). + +## Fix corrupted data + +Delete and rebuild: +```bash +rm -rf runs/attribute_preservation/index runs/attribute_preservation/eval_grads* +rm -rf runs/attribute_preservation/scores_* +# Then re-run the experiment +``` diff --git a/skills/preconditioner-analysis.md b/skills/preconditioner-analysis.md new file mode 100644 index 00000000..bb3e180d --- /dev/null +++ b/skills/preconditioner-analysis.md @@ -0,0 +1,103 @@ +# Preconditioner Analysis + +Analyze and compare preconditioner strategies for style suppression. Preconditioners transform the gradient space to downweight certain directions (e.g., style) before computing similarity. This skill helps understand what each preconditioner is doing and how much variance it captures. 
+ +## Usage + +``` +/preconditioner-analysis [options] +``` + +Options: +- `--base-path PATH` - Experiment directory (default: runs/attribute_preservation) +- `--compare A B` - Compare two preconditioners directly +- `--eigenspectrum` - Analyze eigenvalue distribution + +## Available Preconditioners + +1. **R_between**: Computed from style mean differences + - `delta = mean(style_A) - mean(style_B)` + - `R_between = delta @ delta.T` (rank-1) + - Downweights the "style direction" + +2. **H_eval**: Second moment of eval gradients + - `H_eval = (1/n) * G_eval.T @ G_eval` + - Downweights directions that vary in eval set + - Doesn't require style labels + +3. **R_mixed**: Combination strategies (from preconditioners.py) + +## Instructions + +### Inspect R_between computation: + +```python +from datasets import load_from_disk +from bergson.data import load_gradients +import numpy as np +import json + +base_path = 'runs/attribute_preservation' +train_ds = load_from_disk(f'{base_path}/data/train.hf') +train_styles = np.array(train_ds['style']) + +with open(f'{base_path}/index/info.json') as f: + info = json.load(f) + +# Load one module's gradients to inspect +grads = np.fromfile(f'{base_path}/index/gradients.bin', dtype=np.float16) +grads = grads.reshape(info['num_grads'], -1).astype(np.float32) + +# Compute style means +styles = list(set(train_styles)) +means = {} +for style in styles: + idx = np.where(train_styles == style)[0] + means[style] = grads[idx].mean(axis=0) + print(f'{style}: {len(idx)} samples, mean norm: {np.linalg.norm(means[style]):.4f}') + +# Style direction +delta = means[styles[0]] - means[styles[1]] +print(f'Style delta norm: {np.linalg.norm(delta):.4f}') + +# How much variance is along style direction? 
+delta_unit = delta / (np.linalg.norm(delta) + 1e-8) +projections = grads @ delta_unit +print(f'Projection variance: {projections.var():.4f}') +print(f'Total variance: {grads.var():.4f}') +print(f'Style direction explains: {projections.var() / grads.var() * 100:.1f}% of variance') +``` + +### Compare preconditioner effects: + +```python +import torch +from bergson.gradients import GradientProcessor + +# Load preconditioners +r_between = GradientProcessor.load(f'{base_path}/r_between') +h_eval = GradientProcessor.load(f'{base_path}/h_eval') + +# Compare a specific module +module = list(r_between.preconditioners.keys())[0] +R = r_between.preconditioners[module] +H = h_eval.preconditioners[module] + +print(f'Module: {module}') +print(f'R_between rank: {torch.linalg.matrix_rank(R).item()}') +print(f'H_eval rank: {torch.linalg.matrix_rank(H).item()}') + +# Eigenvalue analysis +eig_R = torch.linalg.eigvalsh(R) +eig_H = torch.linalg.eigvalsh(H) +print(f'R_between top 5 eigenvalues: {eig_R[-5:].tolist()}') +print(f'H_eval top 5 eigenvalues: {eig_H[-5:].tolist()}') +``` + +## Key Insight + +R_between is computed on **training data**, mixing scientists and creatives in the shakespeare mean: +- `mean_shakespeare` = avg(scientists + creatives) +- `mean_pirate` = avg(business) + +This means the "style direction" might also capture some occupation signal. A cleaner approach would compute style means within occupation to isolate pure style variation. diff --git a/skills/semantic-metrics.md b/skills/semantic-metrics.md new file mode 100644 index 00000000..753ea0df --- /dev/null +++ b/skills/semantic-metrics.md @@ -0,0 +1,74 @@ +# Semantic Metrics + +Compute and display metrics for semantic attribution experiments. Use this to quickly check results from existing experiment runs without re-running the full pipeline. 
+ +## Usage + +``` +/semantic-metrics [experiment] [options] +``` + +Arguments: +- `experiment` - Which experiment: `attribute` (default), `asymmetric`, or path to custom run + +Options: +- `--preconditioner NAME` - Specific preconditioner to evaluate (default: all available) +- `--base-path PATH` - Base path for experiment outputs + +## Instructions + +### For attribute preservation experiment: + +```python +from examples.semantic.attribute_preservation import ( + AttributePreservationConfig, + compute_attribute_metrics, + print_attribute_metrics, +) + +config = AttributePreservationConfig() +base_path = 'runs/attribute_preservation' + +# Compute for each available preconditioner +for name in ['no_precond', 'r_between', 'h_eval']: + precond = name if name != 'no_precond' else None + metrics = compute_attribute_metrics(config, base_path, precond) + print_attribute_metrics(metrics, name) + +# Also compute majority control if available +from pathlib import Path +if (Path(base_path) / 'data' / 'eval_majority.hf').exists(): + from examples.semantic.attribute_preservation import compute_majority_style_metrics + majority = compute_majority_style_metrics(config, base_path, None) + print_attribute_metrics(majority, 'majority_no_precond') +``` + +### For asymmetric experiment: + +```python +from examples.semantic.asymmetric import ( + AsymmetricConfig, + compute_asymmetric_metrics, + print_metrics, +) + +config = AsymmetricConfig() +base_path = 'runs/asymmetric' + +for name in ['no_precond', 'r_between']: + precond = name if name != 'no_precond' else None + metrics = compute_asymmetric_metrics(config, base_path, precond) + print_metrics(metrics, name) +``` + +## Key Metrics Explained + +### Attribute Preservation +- **Fact Accuracy**: Exact semantic match (same person, same fact) +- **Occupation Accuracy**: Same occupation cluster (attribute preservation) +- **Style-Only Match**: Style matches but occupation doesn't (style leakage) +- **Trade-off**: Occ Acc - Style Only 
 (higher is better; net occupation attribution after subtracting style leakage)
+
+### Asymmetric
+- **Fact Accuracy**: Correct fact retrieval
+- **Style Leakage**: Matching driven by style rather than content