diff --git a/bergson/process_preconditioners.py b/bergson/process_preconditioners.py index da53ea3d..4c57ab94 100644 --- a/bergson/process_preconditioners.py +++ b/bergson/process_preconditioners.py @@ -45,7 +45,7 @@ def process_preconditioners( if rank == 0: print("Gathering preconditioners...") - cpu_group = dist.new_group(backend="gloo") + cpu_group: dist.ProcessGroup = dist.new_group(backend="gloo") # type: ignore[assignment] for name, grad_size in grad_sizes.items(): if name in preconditioners: diff --git a/bergson/utils/math.py b/bergson/utils/math.py index b0280ce1..33faf31d 100644 --- a/bergson/utils/math.py +++ b/bergson/utils/math.py @@ -84,7 +84,8 @@ def compute_damped_inverse( dtype: Dtype for intermediate computation (default: float64 for stability). regularizer: Optional matrix to use as regularizer instead of identity. If provided, computes inv(H + damping_factor * regularizer). - If None (default), uses scaled identity: inv(H + damping_factor * |H|_mean * I). + If None (default), uses scaled identity: + inv(H + damping_factor * |H|_mean * I). Returns: The damped inverse H^(-1) in the original dtype of H. diff --git a/docs/experiments.md b/docs/experiments.md new file mode 100644 index 00000000..273d378c --- /dev/null +++ b/docs/experiments.md @@ -0,0 +1,137 @@ +# Experiment Walkthroughs + +This page provides walkthroughs for running bergson experiments. Skill files for reproducing the results are available in `skills/`. + +## Asymmetric Style Suppression + +This experiment evaluates various influence functions on a fact retrieval task where the query is in a different writing style to the training data. + +This is analogous to a common use case where your query differs stylistically from the training corpus you want to query; for example, because the query comes from an evaluation set written in a multi-choice question format while the training data is sampled from a more general distribution. We are often only interested in examining how the "content" was learned, and not the style. + +### Requirements + +**Using existing HuggingFace artifacts**: +- GPU with ~24GB VRAM for the analysis model (Qwen3-8B with LoRA) +This option will pull [EleutherAI/bergson-asymmetric-style](https://huggingface.co/datasets/EleutherAI/bergson-asymmetric-style) and [EleutherAI/bergson-asymmetric-style-qwen3-8b-lora](https://huggingface.co/EleutherAI/bergson-asymmetric-style-qwen3-8b-lora) from Hugging Face. + +**Using a regenerated dataset**: +- Qwen3-8B-Base for style rewording +- Additional disk space for intermediate datasets + +### Quickstart + +**Reproduce results with an AI agent**: Point Claude Code or another AI agent at `skills/asymmetric-style.md` for detailed instructions and options (`--recompute`, `--sweep-pca`, `--rewrite-ablation`, `--summary`). + +**Reproduce results manually**: + +```python +from examples.semantic.asymmetric import run_asymmetric_experiment, AsymmetricConfig + +config = AsymmetricConfig( + # Use pre-computed data + hf_dataset="EleutherAI/bergson-asymmetric-style", +) + +results = run_asymmetric_experiment( + config=config, + base_path="runs/asymmetric_style", +) + +sorted_results = sorted(results.items(), key=lambda x: -x[1]["top1_semantic"]) +print(f"{'Strategy':<35} {'Top-1 Sem':<12} {'Top-5 Recall':<13} {'Top-1 Leak':<12}") +print("-" * 72) +for name, m in sorted_results: + print(f"{name:<35} {m['top1_semantic']:<12.2%} {m['top5_semantic_recall']:<13.2%} {m['top1_leak']:<12.2%}") +``` + +### Dataset Structure +The experiment creates train/eval splits with disjoint fact-style combinations: + +- **Training set**: Each fact appears in exactly one style (95% shakespeare, 5% pirate) +- **Eval set**: Queries use the *opposite* style from training—facts that were trained in shakespeare are queried in pirate + +This design means style leakage and semantic accuracy are mutually exclusive: if attribution finds a training example with matching style, it necessarily has the wrong fact (since that fact-style combo doesn't exist in training). The exception is the `majority_no_precond` control, which queries in the majority (shakespeare) style—here style and semantic matches align. + +### Pipeline + +The experiment (`run_asymmetric_experiment`) runs these steps: + +1. **Create dataset** - Downloads from HuggingFace or generates locally with style rewording +2. **Build gradient index** - Collects gradients for all training samples using `bergson build` +3. **Collect eval gradients** - Computes gradients for eval queries +4. **Compute preconditioners** - Builds various preconditioner matrices (R_between, H_train, H_eval, PCA projection) +5. **Score and evaluate** - Computes similarity scores and metrics for each strategy + +### Output Structure + +``` +runs/asymmetric_style/ +├── data/ +│ ├── train.hf # Training set (HuggingFace Dataset) +│ ├── eval.hf # Eval set (HuggingFace Dataset) +│ └── eval_majority.hf # Eval in majority style (control) +├── index/ # Training gradients (bergson index format) +├── eval_grads/ # Eval gradients +├── preconditioners/ # Preconditioner matrices (.pt files) +├── scores_*/ # Score matrices for each strategy +└── experiment_results.json # Metrics summary +``` + +### Key Metrics + +- **Top-1 Semantic Accuracy**: Top match has same underlying fact (higher is better) +- **Top-5 Semantic Recall**: Any of top-5 matches has same underlying fact (higher is better) +- **Top-1 Style Leakage**: Top match is minority style (lower is better). Due to the disjoint partitioning, high leakage implies low semantic accuracy and vice versa. + +### Strategies + +The experiment compares multiple strategies for suppressing style and recovering semantic matching. + +#### Baseline + +- **no_precond**: Bare gradient cosine similarity. Expected to fail because style dominates. + +#### Controls (alternative evaluation sets) + +- **majority_no_precond**: Query in shakespeare style (the majority/dominant style). No style mismatch, so this is the upper bound—style and semantic matches align. +- **original_style_no_precond**: Eval set uses original (unstyled) facts instead of pirate style. +- **summed_majority_minority**: Eval gradients are the sum of pirate and shakespeare style gradients for each fact. Hypothesis: style-specific components cancel out. + +#### Preconditioners + +Without preconditioning, similarity is computed as cosine similarity of gradients: + +```python +score(q, t) = cos(g_q, g_t) + = (g_q · g_t) / (||g_q|| ||g_t||) +``` + +where `g_q` is the eval gradient and `g_t` is a training gradient (row vectors). + +With a preconditioner matrix `H`, we transform the eval gradient before computing similarity: + +```python +H_inv = (H + λI)^(-1) # damped inverse +g_eval_precond = g_eval @ H_inv +g_eval_norm = g_eval_precond / ||g_eval_precond|| +g_train_norm = g_train / ||g_train|| +score(q, t) = g_eval_norm · g_train_norm +``` + +The unnormalized inner product `g_eval @ H^(-1) @ g_train.T` is the classic influence function formula. Preconditioning downweights directions where `H` has large eigenvalues. + +- **index**: `H = G_train.T @ G_train` (training set second moment). This is the classic [influence function](https://arxiv.org/abs/1703.04730) formulation: a second-order Taylor approximation shows that the change in loss from upweighting a training point is proportional to `g_eval @ H^(-1) @ g_train.T`. Intuitively, H^(-1) gives each training point less credit for similarity in "common" directions (where many training points contribute) and more credit in rare/specific directions. + +- **eval_second_moment**: `H = G_eval.T @ G_eval`. Since training gradients average to ~0 at convergence, directions where eval gradients deviate from zero will dominate `H_eval`. Preconditioning downweights these directions. If style causes systematic deviation in eval gradients (e.g., pirate queries all shift in a similar direction), this suppresses the style signal. + +- **train_eval_mixed**: `H = α * H_train + (1-α) * H_eval`. Combines intuitions from both. + +- **r_between**: `H = (μ_pirate - μ_shakespeare)(μ_pirate - μ_shakespeare)^T + λI` (computed on train set). A rank-1 matrix capturing the "style direction" directly. Preconditioning projects out this direction. + +#### Dimensionality Reduction + +- **pca_k{n}_index**: Compute PCA on the matrix of paired pirate/shakespeare style differences (differences between matched facts, train set only). Project gradients onto the orthogonal complement of the top-n principal components, then precondition with the train set second moment matrix. This removes the dominant style directions while preserving semantic signal. + +#### Semantic-only Eval + +- **semantic_index**, **semantic_no_precond**, etc.: Transform eval data into Q&A format like `"Where does Paul Tilmouth work? Siemens"` and mask all gradients up to the `?`. This isolates the semantic content (answer tokens) from any style in the query. Combined with preconditioning (`semantic_index`), this method achieves the best results by a significant margin. diff --git a/docs/index.rst b/docs/index.rst index f972e8fe..305f4810 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,6 +52,14 @@ API Reference api +Experiments +----------- + +.. toctree:: + :maxdepth: 2 + + experiments + Content Index ------------------ diff --git a/examples/semantic/asymmetric.py b/examples/semantic/asymmetric.py index ae6d3178..524d4d35 100644 --- a/examples/semantic/asymmetric.py +++ b/examples/semantic/asymmetric.py @@ -46,7 +46,7 @@ class AsymmetricConfig: # HuggingFace dataset repo. If set, skips local generation and downloads from HF. hf_dataset: str | None = None # Template split for train/test segregation (only used for local generation) - # Train uses templates < train_template_cutoff, eval majority uses templates >= cutoff + # Train uses templates < cutoff, eval majority uses templates >= cutoff train_template_cutoff: int = 5 @@ -81,13 +81,19 @@ def create_asymmetric_dataset( # Return cached if exists if train_path.exists() and eval_path.exists(): print(f"Loading cached datasets from {output_dir}") - return load_from_disk(str(train_path)), load_from_disk(str(eval_path)) + train_cached = load_from_disk(str(train_path)) + eval_cached = load_from_disk(str(eval_path)) + if isinstance(train_cached, DatasetDict): + train_cached = train_cached["train"] + if isinstance(eval_cached, DatasetDict): + eval_cached = eval_cached["train"] + return train_cached, eval_cached # Load original facts to get metadata columns original = load_from_disk("data/facts_dataset.hf") if isinstance(original, DatasetDict): original = original["train"] - fact_to_meta = {row["fact"]: row for row in original} + fact_to_meta = {row["fact"]: row for row in original} # type: ignore[index] # Load style-specific datasets (Qwen only for consistency) style_datasets = { @@ -98,14 +104,14 @@ def create_asymmetric_dataset( } for name in style_datasets: if isinstance(style_datasets[name], DatasetDict): - style_datasets[name] = style_datasets[name]["train"] + style_datasets[name] = style_datasets[name]["train"] # type: ignore[index] # Add back metadata columns from original ds = style_datasets[name] for col in original.column_names: if col not in ds.column_names: - restored_col = [fact_to_meta[row["fact"]][col] for row in ds] - ds = ds.add_column(col, restored_col) + restored_col = [fact_to_meta[row["fact"]][col] for row in ds] # type: ignore[index] + ds = ds.add_column(col, restored_col) # type: ignore[union-attr] style_datasets[name] = ds dominant_ds = style_datasets[config.dominant_style] @@ -113,7 +119,7 @@ def create_asymmetric_dataset( # Get unique (identifier, field) pairs - these represent underlying semantic facts # Each pair has multiple templates (different surface forms of the same fact) - semantic_facts = list({(row["identifier"], row["field"]) for row in original}) + semantic_facts = list({(row["identifier"], row["field"]) for row in original}) # type: ignore[index] n_semantic_facts = len(semantic_facts) # Split into exclusive (dominant-only) and shared by semantic fact @@ -130,22 +136,24 @@ def create_asymmetric_dataset( print(f"Template cutoff for train/eval split: {config.train_template_cutoff}") # Build training set with template filtering - # 1. Dominant style: only templates < cutoff (to reserve rest for eval majority control) + # 1. Dominant style: only templates < cutoff + # (to reserve rest for eval majority control) train_dominant_indices = [ i for i, row in enumerate(dominant_ds) - if row["template"] < config.train_template_cutoff + if row["template"] < config.train_template_cutoff # type: ignore[index] ] - train_dominant = dominant_ds.select(train_dominant_indices) + train_dominant = dominant_ds.select(train_dominant_indices) # type: ignore[union-attr] - # 2. Minority style only for shared facts (any template since minority eval is different) + # 2. Minority style only for shared facts + # (any template since minority eval is different) minority_shared_indices = [ i for i, row in enumerate(minority_ds) - if (row["identifier"], row["field"]) in shared_semantic_facts - and row["template"] < config.train_template_cutoff + if (row["identifier"], row["field"]) in shared_semantic_facts # type: ignore[index] + and row["template"] < config.train_template_cutoff # type: ignore[index] ] - train_minority = minority_ds.select(minority_shared_indices) + train_minority = minority_ds.select(minority_shared_indices) # type: ignore[union-attr] # Add style column train_dominant = train_dominant.add_column( @@ -172,10 +180,10 @@ def create_asymmetric_dataset( eval_minority_indices = [ i for i, row in enumerate(minority_ds) - if (row["identifier"], row["field"]) in exclusive_semantic_facts - and row["template"] >= config.train_template_cutoff + if (row["identifier"], row["field"]) in exclusive_semantic_facts # type: ignore[index] + and row["template"] >= config.train_template_cutoff # type: ignore[index] ] - eval_ds = minority_ds.select(eval_minority_indices) + eval_ds = minority_ds.select(eval_minority_indices) # type: ignore[union-attr] eval_ds = eval_ds.add_column("style", [config.minority_style] * len(eval_ds)) # Add expected_match_style to indicate where the ground truth is @@ -208,7 +216,8 @@ def create_asymmetric_index( Args: config: Experiment configuration. base_path: Base path for experiment outputs. - analysis_model: Model to use for gradient collection. Defaults to HF_ANALYSIS_MODEL. + analysis_model: Model to use for gradient collection. + Defaults to HF_ANALYSIS_MODEL. Returns: Path to the created index. @@ -287,16 +296,22 @@ def score_asymmetric_eval( Args: config: Experiment configuration. base_path: Base path for experiment outputs. - preconditioner_name: Name of preconditioner subdirectory (None for no precond). - damping_factor: Damping factor for matrix inversion (default: 0.1). - regularizer_name: Name of preconditioner to use as regularizer instead of identity. - If provided, computes inv(H + damping_factor * H_regularizer). - Useful for regularizing rank-deficient preconditioners like r_between - with a well-conditioned matrix like H_train or H_eval. - eval_prompt_column: Column to use as prompt for eval gradients (default: "fact"). - eval_completion_column: Column to use as completion for eval gradients (default: "reworded"). - Set to "question"/"answer" for semantic-only attribution where gradients - only come from the answer tokens. + preconditioner_name: Name of preconditioner subdirectory + (None for no precond). + damping_factor: Damping factor for matrix inversion + (default: 0.1). + regularizer_name: Name of preconditioner to use as + regularizer instead of identity. If provided, computes + inv(H + damping_factor * H_regularizer). Useful for + regularizing rank-deficient preconditioners like + r_between with a well-conditioned matrix like + H_train or H_eval. + eval_prompt_column: Column to use as prompt for eval + gradients (default: "fact"). + eval_completion_column: Column to use as completion for + eval gradients (default: "reworded"). Set to + "question"/"answer" for semantic-only attribution + where gradients only come from the answer tokens. Returns: Score matrix of shape (n_eval, n_train). @@ -315,7 +330,8 @@ def score_asymmetric_eval( index_path = base_path / "index" data_path = base_path / "data" - # Determine output path (include damping factor, regularizer, and eval columns in cache key) + # Determine output path + # (include damping factor, regularizer, and eval columns in cache key) damping_suffix = f"_d{damping_factor:.0e}" if damping_factor != 0.1 else "" reg_suffix = f"_reg_{regularizer_name}" if regularizer_name else "" # Add eval column suffix if not using default columns @@ -324,8 +340,8 @@ def score_asymmetric_eval( eval_col_suffix = f"_{eval_prompt_column}_{eval_completion_column}" if preconditioner_name: scores_path = ( - base_path - / f"scores_{preconditioner_name}{damping_suffix}{reg_suffix}{eval_col_suffix}" + base_path / f"scores_{preconditioner_name}" + f"{damping_suffix}{reg_suffix}{eval_col_suffix}" ) precond_path = base_path / preconditioner_name else: @@ -510,11 +526,15 @@ def compute_asymmetric_metrics( config: Experiment configuration. base_path: Base path for experiment outputs. preconditioner_name: Name of preconditioner to use. - damping_factor: Damping factor for matrix inversion (default: 0.1). - regularizer_name: Name of preconditioner to use as regularizer instead of identity. - eval_prompt_column: Column to use as prompt for eval gradients (default: "fact"). - eval_completion_column: Column to use as completion for eval gradients (default: "reworded"). - Set to "question"/"answer" for semantic-only attribution. + damping_factor: Damping factor for matrix inversion + (default: 0.1). + regularizer_name: Name of preconditioner to use as + regularizer instead of identity. + eval_prompt_column: Column to use as prompt for eval + gradients (default: "fact"). + eval_completion_column: Column to use as completion for + eval gradients (default: "reworded"). Set to + "question"/"answer" for semantic-only attribution. Returns: AsymmetricMetrics dataclass. @@ -545,12 +565,12 @@ def compute_asymmetric_metrics( n_eval = len(eval_ds) # Extract metadata - train_styles = train_ds["style"] - train_identifiers = train_ds["identifier"] - train_fields = train_ds["field"] + train_styles = train_ds["style"] # type: ignore[index] + train_identifiers = train_ds["identifier"] # type: ignore[index] + train_fields = train_ds["field"] # type: ignore[index] - eval_identifiers = eval_ds["identifier"] - eval_fields = eval_ds["field"] + eval_identifiers = eval_ds["identifier"] # type: ignore[index] + eval_fields = eval_ds["field"] # type: ignore[index] # Get top-k indices for each query top_k = 10 @@ -682,7 +702,7 @@ def compute_style_preconditioner( if isinstance(train_ds, DatasetDict): train_ds = train_ds["train"] - train_styles = train_ds["style"] + train_styles = train_ds["style"] # type: ignore[index] train_grads = load_gradients(index_path, structured=True) with open(index_path / "info.json") as f: @@ -762,10 +782,13 @@ def score_asymmetric_eval_with_pca_projection( style_subspace: Dictionary from compute_pca_style_subspace(). top_k: Number of principal components used (for cache naming). preconditioner_name: Optional preconditioner to apply after projection. - damping_factor: Damping factor for matrix inversion (default: 0.1). - eval_prompt_column: Column to use as prompt for eval gradients (default: "fact"). - eval_completion_column: Column to use as completion for eval gradients (default: "reworded"). - Set to "question"/"answer" for semantic-only attribution. + damping_factor: Damping factor for matrix inversion + (default: 0.1). + eval_prompt_column: Column to use as prompt for eval + gradients (default: "fact"). + eval_completion_column: Column to use as completion for + eval gradients (default: "reworded"). Set to + "question"/"answer" for semantic-only attribution. Returns: Score matrix of shape (n_eval, n_train). @@ -793,8 +816,8 @@ def score_asymmetric_eval_with_pca_projection( eval_col_suffix = f"_{eval_prompt_column}_{eval_completion_column}" if preconditioner_name: scores_path = ( - base_path - / f"scores_pca_k{top_k}_{preconditioner_name}{damping_suffix}{eval_col_suffix}" + base_path / f"scores_pca_k{top_k}_{preconditioner_name}" + f"{damping_suffix}{eval_col_suffix}" ) precond_path = base_path / preconditioner_name else: @@ -823,7 +846,8 @@ def score_asymmetric_eval_with_pca_projection( n_eval = len(eval_ds) print( - f"Scoring {n_eval} eval queries against {n_train} train samples (PCA projection)" + f"Scoring {n_eval} eval queries against " + f"{n_train} train samples (PCA projection)" ) # Load train gradients @@ -958,11 +982,15 @@ def compute_asymmetric_metrics_with_pca( base_path: Base path for experiment outputs. style_subspace: Dictionary from compute_pca_style_subspace(). top_k: Number of principal components. - preconditioner_name: Optional preconditioner to combine with PCA. - damping_factor: Damping factor for matrix inversion (default: 0.1). - eval_prompt_column: Column to use as prompt for eval gradients (default: "fact"). - eval_completion_column: Column to use as completion for eval gradients (default: "reworded"). - Set to "question"/"answer" for semantic-only attribution. + preconditioner_name: Optional preconditioner to combine + with PCA. + damping_factor: Damping factor for matrix inversion + (default: 0.1). + eval_prompt_column: Column to use as prompt for eval + gradients (default: "fact"). + eval_completion_column: Column to use as completion for + eval gradients (default: "reworded"). Set to + "question"/"answer" for semantic-only attribution. Returns: AsymmetricMetrics dataclass. @@ -994,12 +1022,12 @@ def compute_asymmetric_metrics_with_pca( n_eval = len(eval_ds) # Extract metadata - train_styles = train_ds["style"] - train_identifiers = train_ds["identifier"] - train_fields = train_ds["field"] + train_styles = train_ds["style"] # type: ignore[index] + train_identifiers = train_ds["identifier"] # type: ignore[index] + train_fields = train_ds["field"] # type: ignore[index] - eval_identifiers = eval_ds["identifier"] - eval_fields = eval_ds["field"] + eval_identifiers = eval_ds["identifier"] # type: ignore[index] + eval_fields = eval_ds["field"] # type: ignore[index] # Get top-k indices for each query top_k_results = 10 @@ -1110,8 +1138,8 @@ def create_majority_style_eval( if isinstance(majority_eval_ds, DatasetDict): majority_eval_ds = majority_eval_ds["train"] - train_reworded = set(train_ds["reworded"]) - eval_reworded = set(majority_eval_ds["reworded"]) + train_reworded = set(train_ds["reworded"]) # type: ignore[index] + eval_reworded = set(majority_eval_ds["reworded"]) # type: ignore[index] overlap = train_reworded & eval_reworded has_leakage = len(overlap) > 0 @@ -1142,7 +1170,7 @@ def create_majority_style_eval( eval_ds = eval_ds["train"] # Get semantic facts from eval (identifier, field pairs) - eval_semantic_facts = {(row["identifier"], row["field"]) for row in eval_ds} + eval_semantic_facts = {(row["identifier"], row["field"]) for row in eval_ds} # type: ignore[index] # Load dominant style dataset dominant_ds = load_from_disk(str(local_styled_path)) @@ -1153,11 +1181,11 @@ def create_majority_style_eval( original = load_from_disk("data/facts_dataset.hf") if isinstance(original, DatasetDict): original = original["train"] - fact_to_meta = {row["fact"]: row for row in original} + fact_to_meta = {row["fact"]: row for row in original} # type: ignore[index] for col in original.column_names: if col not in dominant_ds.column_names: - restored_col = [fact_to_meta[row["fact"]][col] for row in dominant_ds] + restored_col = [fact_to_meta[row["fact"]][col] for row in dominant_ds] # type: ignore[index] dominant_ds = dominant_ds.add_column(col, restored_col) # Select dominant style versions of eval semantic facts @@ -1165,8 +1193,8 @@ def create_majority_style_eval( dominant_eval_indices = [ i for i, row in enumerate(dominant_ds) - if (row["identifier"], row["field"]) in eval_semantic_facts - and row["template"] >= config.train_template_cutoff + if (row["identifier"], row["field"]) in eval_semantic_facts # type: ignore[index] + and row["template"] >= config.train_template_cutoff # type: ignore[index] ] majority_eval_ds = dominant_ds.select(dominant_eval_indices) @@ -1224,7 +1252,8 @@ def score_majority_style_eval( _, has_leakage = create_majority_style_eval(config, base_path) if has_leakage: print( - " Note: Majority control may show inflated accuracy due to train/test leakage" + " Note: Majority control may show inflated " + "accuracy due to train/test leakage" ) # Determine output path @@ -1387,12 +1416,12 @@ def compute_majority_style_metrics( n_eval = len(eval_ds) # Extract metadata - train_styles = train_ds["style"] - train_identifiers = train_ds["identifier"] - train_fields = train_ds["field"] + train_styles = train_ds["style"] # type: ignore[index] + train_identifiers = train_ds["identifier"] # type: ignore[index] + train_fields = train_ds["field"] # type: ignore[index] - eval_identifiers = eval_ds["identifier"] - eval_fields = eval_ds["field"] + eval_identifiers = eval_ds["identifier"] # type: ignore[index] + eval_fields = eval_ds["field"] # type: ignore[index] # Get top-k indices for each query top_k = 10 @@ -1429,7 +1458,8 @@ def compute_majority_style_metrics( semantic_top10 += 1 break - # Check style leakage - for majority style, "leakage" means NOT matching dominant + # Check style leakage - for majority style, + # "leakage" means NOT matching dominant # We flip the interpretation: matching minority style would be leakage top1_style = train_styles[top_k_idx[0]] if top1_style == config.minority_style: @@ -1531,16 +1561,17 @@ def score_summed_eval( n_eval = len(eval_minority_ds) print( - f"Scoring {n_eval} summed eval queries (minority + majority) against {n_train} train samples" + f"Scoring {n_eval} summed eval queries " + f"(minority + majority) against {n_train} train samples" ) # Build semantic fact mapping for alignment (identifier, field pairs) # This works even when templates differ between minority and majority eval minority_semantic_facts = [ - (row["identifier"], row["field"]) for row in eval_minority_ds + (row["identifier"], row["field"]) for row in eval_minority_ds # type: ignore[index] ] majority_semantic_to_idx = { - (row["identifier"], row["field"]): i for i, row in enumerate(eval_majority_ds) + (row["identifier"], row["field"]): i for i, row in enumerate(eval_majority_ds) # type: ignore[index] } # Verify alignment by semantic fact @@ -1732,12 +1763,12 @@ def compute_summed_eval_metrics( n_eval = len(eval_ds) # Extract metadata - train_styles = train_ds["style"] - train_identifiers = train_ds["identifier"] - train_fields = train_ds["field"] + train_styles = train_ds["style"] # type: ignore[index] + train_identifiers = train_ds["identifier"] # type: ignore[index] + train_fields = train_ds["field"] # type: ignore[index] - eval_identifiers = eval_ds["identifier"] - eval_fields = eval_ds["field"] + eval_identifiers = eval_ds["identifier"] # type: ignore[index] + eval_fields = eval_ds["field"] # type: ignore[index] # Get top-k indices for each query top_k = 10 @@ -1890,7 +1921,8 @@ def sweep_pca_k( for name, m in sorted(all_metrics.items()): print( - f"{name:<30} {m.top1_semantic_accuracy:<15.2%} {m.top1_style_leakage:<17.2%}" + f"{name:<30} {m.top1_semantic_accuracy:<15.2%} " + f"{m.top1_style_leakage:<17.2%}" ) return all_metrics @@ -1917,17 +1949,24 @@ def run_asymmetric_experiment( config: Experiment configuration (uses defaults if None). Set config.hf_dataset to load data from HuggingFace instead of generating locally. base_path: Base path for experiment outputs. - analysis_model: Model to use for gradient collection. Defaults to HF_ANALYSIS_MODEL. + analysis_model: Model to use for gradient collection. + Defaults to HF_ANALYSIS_MODEL. include_pca: Whether to include PCA projection strategy. - pca_top_k: Number of principal components for PCA projection. - include_summed_loss: Whether to include summed loss preconditioner strategy. - include_second_moments: Whether to include train/eval/mixed second moment strategies. - include_majority_control: Whether to include majority style eval as control. - include_summed_eval: Whether to include summed eval gradient approach (minority + majority). - include_semantic_eval: Whether to include semantic-only eval using question/answer columns. - This tests attribution when gradients only come from the semantic content (answer tokens), - ignoring style in the eval query entirely. - damping_factor: Damping factor for matrix inversion (default: 0.1). + pca_top_k: Number of principal components for PCA. + include_summed_loss: Whether to include summed loss + preconditioner strategy. + include_second_moments: Whether to include train/eval/mixed + second moment strategies. + include_majority_control: Whether to include majority style + eval as control. + include_summed_eval: Whether to include summed eval gradient + approach (minority + majority). + include_semantic_eval: Whether to include semantic-only eval + using question/answer columns. This tests attribution + when gradients only come from the semantic content + (answer tokens), ignoring style in the eval query. + damping_factor: Damping factor for matrix inversion + (default: 0.1). Returns: Dictionary mapping preconditioner names to their metrics. @@ -1965,7 +2004,6 @@ def run_asymmetric_experiment( compute_style_preconditioner(base_path, config) # Step 2b: Compute summed loss preconditioner if requested - summed_loss_proc = None if include_summed_loss: print("\n" + "-" * 60) print("STEP 2b: Computing summed loss preconditioner") @@ -1981,7 +2019,8 @@ def run_asymmetric_experiment( ) else: print( - " Style-specific indices not found, skipping summed loss preconditioner" + " Style-specific indices not found, " + "skipping summed loss preconditioner" ) print(f" (Expected: {pirate_idx} and {shakespeare_idx})") include_summed_loss = False @@ -2101,7 +2140,8 @@ def run_asymmetric_experiment( print_metrics(metrics, "summed_eval") all_metrics["summed_eval"] = metrics - # Evaluate semantic-only approach (question/answer columns - gradients only from answer) + # Evaluate semantic-only approach + # (question/answer columns - gradients only from answer) if include_semantic_eval: print("\n" + "-" * 60) print("SEMANTIC-ONLY EVAL (gradients only from answer tokens)") @@ -2156,7 +2196,8 @@ def run_asymmetric_experiment( for name, m in all_metrics.items(): print( - f"{name:<25} {m.top1_semantic_accuracy:<15.2%} {m.top1_style_leakage:<17.2%}" + f"{name:<25} {m.top1_semantic_accuracy:<15.2%} " + f"{m.top1_style_leakage:<17.2%}" ) print("\n" + "=" * 70) @@ -2178,15 +2219,18 @@ def run_asymmetric_experiment( print(f" - pca_projection_k{pca_top_k}: Project out top-{pca_top_k} style PCs") if include_majority_control: print( - " - majority_no_precond: Control using majority style for eval (no mismatch)" + " - majority_no_precond: Control using majority " + "style for eval (no mismatch)" ) if include_summed_eval: print( - " - summed_eval: Sum minority + majority style eval gradients (style-neutral query)" + " - summed_eval: Sum minority + majority style eval gradients " + "(style-neutral query)" ) if include_semantic_eval: print( - " - semantic_*: Eval gradients only from answer tokens (question/answer format)" + " - semantic_*: Eval gradients only from answer tokens " + "(question/answer format)" ) print( " Tests if attribution works when query has no style information at all" @@ -2297,10 +2341,10 @@ def score_with_inner_product( # Use semantic fact alignment (identifier, field) since templates may differ minority_semantic_facts = [ - (row["identifier"], row["field"]) for row in eval_minority_ds + (row["identifier"], row["field"]) for row in eval_minority_ds # type: ignore[index] ] majority_semantic_to_idx = { - (row["identifier"], row["field"]): i + (row["identifier"], row["field"]): i # type: ignore[index] for i, row in enumerate(eval_majority_ds) } @@ -2394,11 +2438,11 @@ def run_inner_product_comparison( eval_ds = eval_ds["train"] n_eval = len(eval_ds) - train_styles = train_ds["style"] - train_identifiers = train_ds["identifier"] - train_fields = train_ds["field"] - eval_identifiers = eval_ds["identifier"] - eval_fields = eval_ds["field"] + train_styles = train_ds["style"] # type: ignore[index] + train_identifiers = train_ds["identifier"] # type: ignore[index] + train_fields = train_ds["field"] # type: ignore[index] + eval_identifiers = eval_ds["identifier"] # type: ignore[index] + eval_fields = eval_ds["field"] # type: ignore[index] def compute_metrics_from_scores(scores): top_indices = np.argsort(-scores, axis=1)[:, :10] @@ -2514,13 +2558,13 @@ def create_original_style_eval( if isinstance(eval_ds, DatasetDict): eval_ds = eval_ds["train"] - eval_facts = list(eval_ds["fact"]) + eval_facts = list(eval_ds["fact"]) # type: ignore[index] # Load original facts dataset to get metadata original = load_from_disk("data/facts_dataset.hf") if isinstance(original, DatasetDict): original = original["train"] - fact_to_row = {row["fact"]: row for row in original} + fact_to_row = {row["fact"]: row for row in original} # type: ignore[index] # Build original style eval dataset (fact = reworded = original text) rows = [] @@ -2537,7 +2581,8 @@ def create_original_style_eval( original_eval_ds = Dataset.from_list(rows) original_eval_ds.save_to_disk(str(original_eval_path)) print( - f"Saved original style eval ({len(original_eval_ds)} samples) to {original_eval_path}" + f"Saved original style eval ({len(original_eval_ds)} samples)" + f" to {original_eval_path}" ) return original_eval_path @@ -2571,7 +2616,7 @@ def create_pirate_style_eval( if isinstance(eval_ds, DatasetDict): eval_ds = eval_ds["train"] - eval_facts = set(eval_ds["fact"]) + eval_facts = set(eval_ds["fact"]) # type: ignore[index] # Load pirate dataset pirate_ds = load_from_disk("data/facts_dataset_pirate-Qwen3-8B-Base.hf") @@ -2582,16 +2627,16 @@ def create_pirate_style_eval( original = load_from_disk("data/facts_dataset.hf") if isinstance(original, DatasetDict): original = original["train"] - fact_to_meta = {row["fact"]: row for row in original} + fact_to_meta = {row["fact"]: row for row in original} # type: ignore[index] for col in original.column_names: if col not in pirate_ds.column_names: - restored_col = [fact_to_meta[row["fact"]][col] for row in pirate_ds] + restored_col = [fact_to_meta[row["fact"]][col] for row in pirate_ds] # type: ignore[index] pirate_ds = pirate_ds.add_column(col, restored_col) # Select only the exclusive facts (same facts as in minority eval) pirate_eval_indices = [ - i for i, row in enumerate(pirate_ds) if row["fact"] in eval_facts + i for i, row in enumerate(pirate_ds) if row["fact"] in eval_facts # type: ignore[index] ] pirate_eval_ds = pirate_ds.select(pirate_eval_indices) @@ -2679,7 +2724,8 @@ def score_summed_rewrites( n_eval = len(shakespeare_eval_ds) print( - f"Scoring {n_eval} summed rewrite queries (shakespeare + pirate) against {n_train} train" + f"Scoring {n_eval} summed rewrite queries (shakespeare + pirate)" + f" against {n_train} train" ) # Build fact-to-index mapping for alignment @@ -3010,7 +3056,8 @@ def compute_rewrite_ablation_metrics( Args: config: Experiment configuration. base_path: Base path for experiment outputs. - strategy: One of "original", "summed_rewrites", "shakespeare_only", "pirate_only". + strategy: One of "original", "summed_rewrites", + "shakespeare_only", "pirate_only". preconditioner_name: Name of preconditioner subdirectory (None for no precond). Returns: @@ -3048,12 +3095,12 @@ def compute_rewrite_ablation_metrics( n_eval = len(eval_ds) # Extract metadata - train_styles = train_ds["style"] - train_identifiers = train_ds["identifier"] - train_fields = train_ds["field"] + train_styles = train_ds["style"] # type: ignore[index] + train_identifiers = train_ds["identifier"] # type: ignore[index] + train_fields = train_ds["field"] # type: ignore[index] - eval_identifiers = eval_ds["identifier"] - eval_fields = eval_ds["field"] + eval_identifiers = eval_ds["identifier"] # type: ignore[index] + eval_fields = eval_ds["field"] # type: ignore[index] # Get top-k indices for each query top_k = 10 @@ -3340,7 +3387,8 @@ def run_rewrite_ablation_experiment( for name, m in all_metrics.items(): print( - f"{name:<25} {m.top1_semantic_accuracy:<15.2%} {m.top1_style_leakage:<17.2%}" + f"{name:<25} {m.top1_semantic_accuracy:<15.2%} " + f"{m.top1_style_leakage:<17.2%}" ) return all_metrics diff --git a/examples/semantic/attribute_preservation.py b/examples/semantic/attribute_preservation.py index e8d06a87..21782370 100644 --- a/examples/semantic/attribute_preservation.py +++ b/examples/semantic/attribute_preservation.py @@ -237,7 +237,8 @@ class AttributePreservationConfig: default_factory=lambda: { "scientist": "shakespeare", # Scientists in Shakespeare style "business": "pirate", # Business in Pirate style - "creative": "shakespeare", # Creative in Shakespeare style (same as scientist) + # Creative in Shakespeare style (same as scientist) + "creative": "shakespeare", } ) @@ -413,14 +414,17 @@ def reword_dataset_with_style( style_prompts = { "shakespeare": ( - "Reword the following fact in a Shakespearean style, adding flair and poetry.\n" - "Do not include other text in your response, just the contents of the reworded fact.\n" + "Reword the following fact in a Shakespearean style, adding flair" + " and poetry.\n" + "Do not include other text in your response, just the contents of " + "the reworded fact.\n" "Fact: {fact}\n" "Your rewrite:" ), "pirate": ( "Reword the following fact like it's coming from a pirate. Be creative!\n" - "Do not include any other text in your response, just the contents of the reworded fact.\n" + "Do not include any other text in your response, just the contents of " + "the reworded fact.\n" "Fact: {fact}\n" "Your rewrite:" ), @@ -449,7 +453,7 @@ def reword_dataset_with_style( for i in tqdm(range(0, len(data_list), batch_size)): batch_items = data_list[i : i + batch_size] - prompts = [prompt_template.format(fact=item["fact"]) for item in batch_items] + prompts = [prompt_template.format(fact=item["fact"]) for item in batch_items] # type: ignore[index] inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) input_len = inputs.input_ids.shape[1] @@ -470,7 +474,7 @@ def reword_dataset_with_style( ) for item, output_text in zip(batch_items, decoded_batch): - new_facts.append(item["fact"]) + new_facts.append(item["fact"]) # type: ignore[index] new_reworded.append(output_text.strip()) # Build new dataset with all original columns plus 'reworded' @@ -502,7 +506,7 @@ def create_styled_datasets( if train_path.exists() and eval_path.exists(): print(f"Loading cached styled datasets from {output_dir}") - return load_from_disk(str(train_path)), load_from_disk(str(eval_path)) + return load_from_disk(str(train_path)), load_from_disk(str(eval_path)) # type: ignore[index] # Get base facts base_train, base_eval = create_attribute_dataset(config, output_dir) @@ -552,7 +556,7 @@ def create_styled_datasets( print(f" Train: {len(styled_train)} samples") print(f" Eval: {len(styled_eval)} samples") - return styled_train, styled_eval + return styled_train, styled_eval # type: ignore[index] def create_attribute_index( @@ -950,7 +954,7 @@ def compute_attribute_metrics( # Attribute-level matching (for top-1) top1_idx = top_k_idx[0] - top1_occ = train_occupations[top1_idx] + # top1_occ = train_occupations[top1_idx] top1_field = train_fields[top1_idx] top1_value = train_values[top1_idx] @@ -1022,8 +1026,8 @@ def print_attribute_metrics(metrics: AttributePreservationMetrics, name: str) -> print(f" Top-10: {metrics.top10_style_only_match:.2%}") print("\nPer-Field Top-1 Accuracy:") - for field, acc in sorted(metrics.top1_by_field.items()): - print(f" {field}: {acc:.2%}") + for field_name, acc in sorted(metrics.top1_by_field.items()): + print(f" {field_name}: {acc:.2%}") def compute_style_preconditioner_from_data( @@ -1640,7 +1644,8 @@ def run_attribute_preservation_experiment( for occ, style in config.style_occupation_map.items(): print(f" {occ}: {style}") print( - f" Eval occupation: {config.eval_occupation} (queried in {config.eval_style} style)" + f" Eval occupation: {config.eval_occupation} " + f"(queried in {config.eval_style} style)" ) print(f" People per occupation: {config.people_per_occupation}") @@ -1705,7 +1710,8 @@ def run_attribute_preservation_experiment( print("=" * 70) print( - f"\n{'Strategy':<25} {'Fact Acc':<12} {'Occ Acc':<12} {'Style Only':<12} {'Trade-off':<12}" + f"\n{'Strategy':<25} {'Fact Acc':<12} {'Occ Acc':<12} " + f"{'Style Only':<12} {'Trade-off':<12}" ) print("-" * 73) @@ -1714,21 +1720,25 @@ def run_attribute_preservation_experiment( # A good trade-off is when occ_acc is high and style_only is low trade_off = m.top1_occupation_accuracy - m.top1_style_only_match print( - f"{name:<25} {m.top1_fact_accuracy:<12.2%} {m.top1_occupation_accuracy:<12.2%} " + f"{name:<25} {m.top1_fact_accuracy:<12.2%} " + f"{m.top1_occupation_accuracy:<12.2%} " f"{m.top1_style_only_match:<12.2%} {trade_off:<12.2%}" ) print("\nInterpretation:") print(" - Fact Accuracy: How well we match exact facts (semantic matching)") print( - " - Occupation Accuracy: How well we match occupation cluster (attribute preservation)" + " - Occupation Accuracy: How well we match occupation cluster " + "(attribute preservation)" ) print( - " - Style Only: Matches where style matches but occupation doesn't (should be LOW)" + " - Style Only: Matches where style matches but occupation doesn't " + "(should be LOW)" ) print(" - Trade-off: Occ Acc - Style Only (higher is better)") print( - " - majority_no_precond: Control showing baseline when eval style matches training" + " - majority_no_precond: Control showing baseline when eval " + "style matches training" ) baseline = all_metrics.get("no_precond") @@ -1769,9 +1779,16 @@ def run_attribute_preservation_experiment( print(f" Occupation Accuracy: {majority.top1_occupation_accuracy:.2%}") if baseline and r_between: + style_reduction = ( + baseline.top1_style_only_match - r_between.top1_style_only_match + ) + occ_change = ( + r_between.top1_occupation_accuracy - baseline.top1_occupation_accuracy + ) if style_reduction > 0 and occ_change >= -0.05: print( - "\n✓ SUCCESS: Style suppression works without damaging attribute preservation!" + "\n✓ SUCCESS: Style suppression works without " + "damaging attribute preservation!" ) elif style_reduction > 0 and occ_change < -0.05: print("\n⚠ PARTIAL: Style suppressed but attribute preservation damaged") diff --git a/examples/semantic/data.py b/examples/semantic/data.py index 35493e61..c85b51b8 100644 --- a/examples/semantic/data.py +++ b/examples/semantic/data.py @@ -1,6 +1,7 @@ """Dataset creation and rewording utilities for semantic experiments.""" from pathlib import Path +from typing import Any, cast import torch from datasets import ( @@ -27,8 +28,10 @@ def load_experiment_data( """Load experiment data from HuggingFace or local disk. Args: - base_path: Local path containing data/*.hf directories. Required if hf_repo is None. - hf_repo: HuggingFace dataset repo ID (e.g., "EleutherAI/bergson-asymmetric-style"). + base_path: Local path containing data/*.hf directories. + Required if hf_repo is None. + hf_repo: HuggingFace dataset repo ID + (e.g., "EleutherAI/bergson-asymmetric-style"). If provided, downloads from HF and ignores base_path. splits: Optional list of splits to load. If None, loads all available splits. @@ -46,11 +49,13 @@ def load_experiment_data( data = load_experiment_data(hf_repo="...", splits=["train", "eval"]) """ if hf_repo: - dataset_dict = load_dataset(hf_repo) + loaded = load_dataset(hf_repo) + if not isinstance(loaded, DatasetDict): + raise TypeError(f"Expected DatasetDict from HF, got {type(loaded)}") + dataset_dict: DatasetDict = loaded if splits: - dataset_dict = DatasetDict( - {k: dataset_dict[k] for k in splits if k in dataset_dict} - ) + filtered = {k: dataset_dict[k] for k in splits if k in dataset_dict} + dataset_dict = DatasetDict(cast(Any, filtered)) return dataset_dict if base_path is None: @@ -71,12 +76,13 @@ def load_experiment_data( if not available_splits: raise FileNotFoundError(f"No .hf datasets found in {data_path}") - return DatasetDict( - { - split: load_from_disk(str(data_path / f"{split}.hf")) - for split in available_splits - } - ) + result: dict[str, Dataset] = {} + for split in available_splits: + ds = load_from_disk(str(data_path / f"{split}.hf")) + if isinstance(ds, DatasetDict): + ds = ds["train"] + result[split] = ds + return DatasetDict(cast(Any, result)) def reword( @@ -119,7 +125,7 @@ def reword( for i in tqdm(range(0, len(data_list), batch_size)): # 1. Prepare the batch batch_items = data_list[i : i + batch_size] - prompts = [prompt_template.format(fact=item["fact"]) for item in batch_items] + prompts = [prompt_template.format(fact=item["fact"]) for item in batch_items] # type: ignore[index] # 2. Tokenize (Batch mode) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) @@ -148,7 +154,7 @@ def reword( # 6. Store results for item, output_text in zip(batch_items, decoded_batch): - new_facts.append(item["fact"]) + new_facts.append(item["fact"]) # type: ignore[index] new_reworded.append(output_text.strip()) # Reconstruct dataset @@ -184,8 +190,10 @@ def create_data() -> None: pirate_path = f"data/facts_dataset_pirate-{model_short}.hf" if not Path(pirate_path).exists(): prompt_pirate = ( - "Reword the following fact like it's coming from a pirate. Be creative!\n" - "Do not include any other text in your response, just the contents of the " + "Reword the following fact like it's coming from a pirate. " + "Be creative!\n" + "Do not include any other text in your response, " + "just the contents of the " "reworded fact.\n" "Fact: {fact}\n" "Your rewrite:" @@ -227,8 +235,8 @@ def create_qwen_only_dataset() -> Path: # Add back any dropped columns from original for col in original.column_names: if col not in ds.column_names: - orig_map = {row["fact"]: row for row in original} - restored_col = [orig_map[row["fact"]][col] for row in ds] + orig_map = {row["fact"]: row for row in original} # type: ignore[index] + restored_col = [orig_map[row["fact"]][col] for row in ds] # type: ignore[index] ds = ds.add_column(col, restored_col) merged_datasets.append(ds) diff --git a/examples/semantic/experiment.py b/examples/semantic/experiment.py index 2d2c4582..b6995026 100644 --- a/examples/semantic/experiment.py +++ b/examples/semantic/experiment.py @@ -214,10 +214,10 @@ def main() -> None: if col not in ds.column_names: # Align ds length with original by matching on "fact" # Create a mapping from fact -> row - orig_map = {row["fact"]: row for row in original} + orig_map = {row["fact"]: row for row in original} # type: ignore[index] # Build list for restored column - restored_col = [orig_map[row["fact"]][col] for row in ds] + restored_col = [orig_map[row["fact"]][col] for row in ds] # type: ignore[index] ds = ds.add_column(col, restored_col) diff --git a/examples/semantic/metrics.py b/examples/semantic/metrics.py index 7563aa50..c89481af 100644 --- a/examples/semantic/metrics.py +++ b/examples/semantic/metrics.py @@ -43,7 +43,7 @@ def build_style_lookup(include_llama: bool = False) -> dict[tuple[str, str], str if isinstance(ds, DatasetDict): ds = ds["train"] for row in ds: - style_lookup[(row["fact"], row["reworded"])] = style_name + style_lookup[(row["fact"], row["reworded"])] = style_name # type: ignore[index] return style_lookup @@ -241,7 +241,7 @@ def compute_metrics( if isinstance(ds, DatasetDict): ds = ds["train"] for row in ds: - style_lookup[(row["fact"], row["reworded"])] = style_name + style_lookup[(row["fact"], row["reworded"])] = style_name # type: ignore[index] # Extract metadata identifiers = meta_ds["identifier"] @@ -288,7 +288,8 @@ def compute_metrics( fields_t = torch.tensor([field_to_idx[f] for f in fields]) styles_t = torch.tensor([style_to_idx[s] for s in styles]) - # Build masks for upper triangle (i < j to avoid double counting and self-similarity) + # Build masks for upper triangle + # (i < j to avoid double counting and self-similarity) row_idx, col_idx = torch.triu_indices(n, n, offset=1) # Get similarities for upper triangle pairs @@ -332,7 +333,8 @@ def compute_mean(mask: torch.Tensor) -> float: print("\nFact (same person+field = same underlying fact):") print(f" Intra-fact mean: {stats['intra_fact']:.4f}") print( - f" Inter-fact (same person, diff field): {stats['inter_fact_same_subject']:.4f}" + f" Inter-fact (same person, diff field): " + f"{stats['inter_fact_same_subject']:.4f}" ) print(f" Difference: {stats['intra_fact'] - stats['inter_fact_same_subject']:.4f}") diff --git a/examples/semantic/preconditioners.py b/examples/semantic/preconditioners.py index 6cf1677b..24aa8932 100644 --- a/examples/semantic/preconditioners.py +++ b/examples/semantic/preconditioners.py @@ -32,7 +32,8 @@ def _load_gradients_as_float(grads: np.memmap, name: str) -> np.ndarray: def build_style_indices(analysis_model: str = "tmp/checkpoint-282") -> None: - """Build separate indices for pirate and shakespeare to get separate preconditioners. + """Build separate indices for pirate and shakespeare to + get separate preconditioners. Args: analysis_model: Model to use for gradient collection. @@ -187,7 +188,8 @@ def compute_between_preconditioner_means( shakespeare_index_path: Path | str, output_path: Path | str, ) -> GradientProcessor: - """Compute R_between = (mu_pirate - mu_shakespeare)(mu_pirate - mu_shakespeare)^T per module. + """Compute R_between = + (mu_pirate - mu_shakespeare)(mu_pirate - mu_shakespeare)^T per module. This creates a rank-1 preconditioner from the difference in class means. More targeted than the covariance method - captures exactly the "style direction". @@ -368,8 +370,8 @@ def compute_summed_loss_preconditioner( shakespeare_ds = shakespeare_ds["train"] # Build fact -> index mapping - pirate_facts = pirate_ds["fact"] - shakespeare_facts = shakespeare_ds["fact"] + pirate_facts = pirate_ds["fact"] # type: ignore[index] + shakespeare_facts = shakespeare_ds["fact"] # type: ignore[index] pirate_fact_to_idx = {f: i for i, f in enumerate(pirate_facts)} shakespeare_fact_to_idx = {f: i for i, f in enumerate(shakespeare_facts)} @@ -472,8 +474,8 @@ def compute_pca_style_subspace( shakespeare_ds = shakespeare_ds["train"] # Build fact -> index mapping - pirate_facts = pirate_ds["fact"] - shakespeare_facts = shakespeare_ds["fact"] + pirate_facts = pirate_ds["fact"] # type: ignore[index] + shakespeare_facts = shakespeare_ds["fact"] # type: ignore[index] pirate_fact_to_idx = {f: i for i, f in enumerate(pirate_facts)} shakespeare_fact_to_idx = {f: i for i, f in enumerate(shakespeare_facts)} @@ -597,7 +599,8 @@ def compute_eval_preconditioner( Args: eval_grads_path: Path to eval gradients index. output_path: Path to save the preconditioner. - reference_proc_path: Path to a reference processor for module names (if eval has none). + reference_proc_path: Path to a reference processor + for module names (if eval has none). Returns: The computed GradientProcessor. @@ -624,7 +627,8 @@ def compute_eval_preconditioner( info = json.load(f) module_names = info["dtype"]["names"] - # Load a reference processor to get metadata (use reference if eval doesn't have precs) + # Load a reference processor to get metadata + # (use reference if eval doesn't have precs) if reference_proc_path: base_proc = GradientProcessor.load(Path(reference_proc_path)) else: @@ -684,7 +688,8 @@ def compute_train_eval_mixed_preconditioner( return GradientProcessor.load(output_path) print( - f"Computing train-eval mixed preconditioner ({train_weight:.0%} train, {1-train_weight:.0%} eval)..." + f"Computing train-eval mixed preconditioner ({train_weight:.0%} " + f"train, {1-train_weight:.0%} eval)..." ) train_path = Path(train_index_path) diff --git a/examples/train_lora.py b/examples/train_lora.py index 66689240..1060ce46 100644 --- a/examples/train_lora.py +++ b/examples/train_lora.py @@ -153,8 +153,8 @@ def validate_scheduler(cls, v): class NoShuffleSFTTrainer(SFTTrainer): - def _get_train_sampler(self, dataset): # <-- Add 'dataset' parameter - sampler = SequentialSampler(dataset) + def _get_train_sampler(self, train_dataset): # type: ignore[override] + sampler = SequentialSampler(train_dataset) # type: ignore[arg-type] return sampler @@ -210,7 +210,7 @@ def train(training_cfg: TrainingConfig, dataset: Dataset): ddp_find_unused_parameters=False, fp16=True, gradient_accumulation_steps=training_cfg.gradient_accumulation_steps, - learning_rate=training_cfg.learning_rate, + learning_rate=float(training_cfg.learning_rate), logging_steps=1, lr_scheduler_type=training_cfg.lr_scheduler_type, max_length=training_cfg.max_seq_length, @@ -304,6 +304,8 @@ def main(): ), ) ) + if not isinstance(dataset, Dataset): + raise TypeError(f"Expected Dataset, got {type(dataset)}") train(training_config, dataset) diff --git a/skills/asymmetric-style.md b/skills/asymmetric-style.md index 270bc5d3..8798fd5a 100644 --- a/skills/asymmetric-style.md +++ b/skills/asymmetric-style.md @@ -24,7 +24,7 @@ Options: - Train: 95% shakespeare (dominant), 5% pirate (minority) - Eval: pirate style queries for facts only in shakespeare style in train 2. Tests whether gradient-based attribution can find semantic matches despite style mismatch -3. Compares strategies: baseline, preconditioners (R_between, H_eval, H_train), PCA projection, summed gradients +3. Compares strategies: baseline, preconditioners (R_between, H_eval, H_train), PCA projection, summed gradients, semantic-only eval ## Strategies Tested @@ -57,6 +57,9 @@ Transform gradients by `g' = g @ H^(-1)` before computing similarity, downweight ### Controls - **majority_no_precond**: Query in the majority (shakespeare) style—no style mismatch. This is the upper bound showing what's achievable when styles match. +### Semantic-only Eval (Best Performing) +- **semantic_index**, **semantic_no_precond**: Transform eval data into Q&A format like `"Where does Paul Tilmouth work? Siemens"` and mask all gradients up to the `?`. This isolates the semantic content (answer tokens) from any style in the query. Combined with H_train preconditioning (`semantic_index`), this achieves the best results by a significant margin. + ## Instructions ### Run full experiment (using HuggingFace data) @@ -142,10 +145,10 @@ with open(base_path / "experiment_results.json") as f: results = json.load(f) sorted_results = sorted(results.items(), key=lambda x: -x[1]["top1_semantic"]) -print(f"{'Strategy':<35} {'Top-1 Sem':<12} {'Top-1 Leak':<12} {'Exact':<10}") -print("-" * 70) +print(f"{'Strategy':<35} {'Top-1 Sem':<12} {'Top-5 Recall':<13} {'Top-1 Leak':<12}") +print("-" * 72) for name, m in sorted_results: - print(f"{name:<35} {m['top1_semantic']:<12.2%} {m['top1_leak']:<12.2%} {m['exact']:<10.2%}") + print(f"{name:<35} {m['top1_semantic']:<12.2%} {m['top5_semantic_recall']:<13.2%} {m['top1_leak']:<12.2%}") ``` ## Cached Data @@ -189,8 +192,8 @@ rm -rf runs/asymmetric_style/ ## Key Metrics - **Top-1 Semantic Accuracy**: Top match has same underlying fact (higher is better) +- **Top-5 Semantic Recall**: Any of top-5 matches has same underlying fact (higher is better) - **Top-1 Style Leakage**: Top match is minority style (lower is better - means not style matching) -- **Exact Match**: Same fact AND dominant style (higher is better) ## Datasets & Models @@ -211,13 +214,17 @@ The datasets and fine-tuned model for this experiment are available on Hugging F | Strategy | Top-1 Semantic | Notes | |----------|---------------|-------| -| majority_no_precond | 100.00% | Control: no style mismatch | -| summed_eval | 92.71% | Sum minority + majority style eval grads | -| summed_rewrites | 0.87% | Sum shakespeare + pirate (both non-training) | -| no_precond (baseline) | 0.87% | Pure style matching dominates | -| preconditioners | ~1-1.4% | Marginal improvement | - -**Main insight**: summed_eval works because one component matches training distribution, not because of general style cancellation. +| majority_no_precond | ~100% | Control: no style mismatch | +| semantic_index | ~95%+ | **Best**: Q&A format + H_train preconditioning | +| semantic_no_precond | ~90%+ | Q&A format without preconditioning | +| summed_eval | ~93% | Sum minority + majority style eval grads | +| summed_rewrites | <1% | Sum shakespeare + pirate (both non-training) | +| no_precond (baseline) | <1% | Pure style matching dominates | +| preconditioners alone | ~1-2% | Marginal improvement without semantic masking | + +**Main insights**: +- The semantic Q&A approach (masking style tokens, keeping only answer gradients) combined with H_train preconditioning achieves the best results +- summed_eval works because one component matches training distribution, not because of general style cancellation ## Similarity Metric Comparison