Skip to content

Commit 780aae9

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f82f265 commit 780aae9

File tree

8 files changed

+519
-214
lines changed

8 files changed

+519
-214
lines changed

examples/semantic/asymmetric.py

Lines changed: 302 additions & 134 deletions
Large diffs are not rendered by default.

examples/semantic/attribute_preservation.py

Lines changed: 158 additions & 53 deletions
Large diffs are not rendered by default.

examples/semantic/data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from transformers import AutoModelForCausalLM, AutoTokenizer
99

1010

11-
def reword(dataset: Dataset, model_name: str, prompt_template: str, batch_size: int = 8) -> Dataset:
11+
def reword(
12+
dataset: Dataset, model_name: str, prompt_template: str, batch_size: int = 8
13+
) -> Dataset:
1214
"""Reword facts in a dataset using a language model.
1315
1416
Args:

examples/semantic/experiment.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import subprocess
44
from pathlib import Path
5-
from typing import Any
65

76
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
87

@@ -51,7 +50,9 @@ def create_index(dataset_name: str, analysis_model_name: str) -> None:
5150
print(result.stderr)
5251

5352

54-
def finetune(dataset_path: str, analysis_model_name: str, finetuned_model_path: str) -> None:
53+
def finetune(
54+
dataset_path: str, analysis_model_name: str, finetuned_model_path: str
55+
) -> None:
5556
"""Finetune a model on a dataset using LoRA.
5657
5758
Args:
@@ -145,9 +146,7 @@ def run_preconditioner_comparison() -> dict[str, dict[str, float]]:
145146
compute_scores_fast(
146147
base_path / "combined", # Use precomputed gradients from combined index
147148
output_path,
148-
preconditioner_path=(
149-
base_path / prec_path if prec_path else None
150-
),
149+
preconditioner_path=(base_path / prec_path if prec_path else None),
151150
)
152151

153152
# 4. Compare metrics across strategies
@@ -179,7 +178,9 @@ def run_preconditioner_comparison() -> dict[str, dict[str, float]]:
179178
style_diff = s.get("intra_style", 0) - s.get("inter_style", 0)
180179
fact_diff = s.get("intra_fact", 0) - s.get("inter_fact_same_subject", 0)
181180
subj_diff = s.get("intra_subject", 0) - s.get("inter_subject", 0)
182-
print(f"{name:<15} {style_diff:<12.4f} {fact_diff:<12.4f} {subj_diff:<12.4f}")
181+
print(
182+
f"{name:<15} {style_diff:<12.4f} {fact_diff:<12.4f} {subj_diff:<12.4f}"
183+
)
183184

184185
return all_stats
185186

examples/semantic/metrics.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
import torch
9-
from datasets import Dataset, DatasetDict, load_from_disk
9+
from datasets import DatasetDict, load_from_disk
1010

1111
from bergson import load_gradient_dataset
1212
from bergson.data import load_gradients
@@ -29,10 +29,15 @@ def build_style_lookup(include_llama: bool = False) -> dict[tuple[str, str], str
2929
("data/facts_dataset_pirate-Qwen3-8B-Base.hf", "pirate"),
3030
]
3131
if include_llama:
32-
style_datasets.extend([
33-
("data/facts_dataset_shakespeare-Meta-Llama-3-8B.hf", "shakespeare-llama"),
34-
("data/facts_dataset_pirate-Meta-Llama-3-8B.hf", "pirate-llama"),
35-
])
32+
style_datasets.extend(
33+
[
34+
(
35+
"data/facts_dataset_shakespeare-Meta-Llama-3-8B.hf",
36+
"shakespeare-llama",
37+
),
38+
("data/facts_dataset_pirate-Meta-Llama-3-8B.hf", "pirate-llama"),
39+
]
40+
)
3641
for path, style_name in style_datasets:
3742
ds = load_from_disk(path)
3843
if isinstance(ds, DatasetDict):
@@ -319,27 +324,29 @@ def compute_mean(mask: torch.Tensor) -> float:
319324
print("SEMANTIC SIMILARITY RESULTS")
320325
print("=" * 60)
321326

322-
print(f"\nSubject (same person vs different person):")
327+
print("\nSubject (same person vs different person):")
323328
print(f" Intra-subject mean: {stats['intra_subject']:.4f}")
324329
print(f" Inter-subject mean: {stats['inter_subject']:.4f}")
325330
print(f" Difference: {stats['intra_subject'] - stats['inter_subject']:.4f}")
326331

327-
print(f"\nFact (same person+field = same underlying fact):")
332+
print("\nFact (same person+field = same underlying fact):")
328333
print(f" Intra-fact mean: {stats['intra_fact']:.4f}")
329-
print(f" Inter-fact (same person, diff field): {stats['inter_fact_same_subject']:.4f}")
334+
print(
335+
f" Inter-fact (same person, diff field): {stats['inter_fact_same_subject']:.4f}"
336+
)
330337
print(f" Difference: {stats['intra_fact'] - stats['inter_fact_same_subject']:.4f}")
331338

332-
print(f"\nField (same field type, e.g. birthdate, employer):")
339+
print("\nField (same field type, e.g. birthdate, employer):")
333340
print(f" Intra-field mean: {stats['intra_field']:.4f}")
334341
print(f" Inter-field mean: {stats['inter_field']:.4f}")
335342
print(f" Difference: {stats['intra_field'] - stats['inter_field']:.4f}")
336343

337-
print(f"\nTemplate (same original phrasing template):")
344+
print("\nTemplate (same original phrasing template):")
338345
print(f" Intra-template mean: {stats['intra_template']:.4f}")
339346
print(f" Inter-template mean: {stats['inter_template']:.4f}")
340347
print(f" Difference: {stats['intra_template'] - stats['inter_template']:.4f}")
341348

342-
print(f"\nStyle (same rewording style):")
349+
print("\nStyle (same rewording style):")
343350
print(f" Intra-style mean: {stats['intra_style']:.4f}")
344351
print(f" Inter-style mean: {stats['inter_style']:.4f}")
345352
print(f" Difference: {stats['intra_style'] - stats['inter_style']:.4f}")

examples/semantic/preconditioners.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,16 @@ def compute_summed_loss_preconditioner(
351351
shakespeare_grads = load_gradients(shakespeare_path, structured=True)
352352

353353
# Load datasets to match facts
354-
pirate_ds = load_from_disk(str(pirate_path.parent / "pirate" / "dataset") if (pirate_path.parent / "pirate" / "dataset").exists() else "data/facts_dataset_pirate-Qwen3-8B-Base.hf")
355-
shakespeare_ds = load_from_disk(str(shakespeare_path.parent / "shakespeare" / "dataset") if (shakespeare_path.parent / "shakespeare" / "dataset").exists() else "data/facts_dataset_shakespeare-Qwen3-8B-Base.hf")
354+
pirate_ds = load_from_disk(
355+
str(pirate_path.parent / "pirate" / "dataset")
356+
if (pirate_path.parent / "pirate" / "dataset").exists()
357+
else "data/facts_dataset_pirate-Qwen3-8B-Base.hf"
358+
)
359+
shakespeare_ds = load_from_disk(
360+
str(shakespeare_path.parent / "shakespeare" / "dataset")
361+
if (shakespeare_path.parent / "shakespeare" / "dataset").exists()
362+
else "data/facts_dataset_shakespeare-Qwen3-8B-Base.hf"
363+
)
356364

357365
if hasattr(pirate_ds, "keys"):
358366
pirate_ds = pirate_ds["train"]
@@ -367,7 +375,9 @@ def compute_summed_loss_preconditioner(
367375
shakespeare_fact_to_idx = {f: i for i, f in enumerate(shakespeare_facts)}
368376

369377
# Find common facts (contrastive pairs) and build aligned index arrays
370-
common_facts = list(set(pirate_fact_to_idx.keys()) & set(shakespeare_fact_to_idx.keys()))
378+
common_facts = list(
379+
set(pirate_fact_to_idx.keys()) & set(shakespeare_fact_to_idx.keys())
380+
)
371381
pirate_indices = [pirate_fact_to_idx[f] for f in common_facts]
372382
shakespeare_indices = [shakespeare_fact_to_idx[f] for f in common_facts]
373383
print(f" Found {len(common_facts)} contrastive pairs")
@@ -469,7 +479,9 @@ def compute_pca_style_subspace(
469479
shakespeare_fact_to_idx = {f: i for i, f in enumerate(shakespeare_facts)}
470480

471481
# Find common facts and build aligned index arrays
472-
common_facts = list(set(pirate_fact_to_idx.keys()) & set(shakespeare_fact_to_idx.keys()))
482+
common_facts = list(
483+
set(pirate_fact_to_idx.keys()) & set(shakespeare_fact_to_idx.keys())
484+
)
473485
pirate_indices = [pirate_fact_to_idx[f] for f in common_facts]
474486
shakespeare_indices = [shakespeare_fact_to_idx[f] for f in common_facts]
475487
print(f" Found {len(common_facts)} contrastive pairs")
@@ -505,7 +517,9 @@ def compute_pca_style_subspace(
505517
# Get top-k (largest eigenvalues are at the end)
506518
k = min(top_k, eigvals.shape[0])
507519
top_eigvals = eigvals[-k:].flip(0) # Descending order
508-
top_eigvecs = eigvecs[:, -k:].flip(1) # [d, k], columns are principal components
520+
top_eigvecs = eigvecs[:, -k:].flip(
521+
1
522+
) # [d, k], columns are principal components
509523

510524
style_subspace[name] = (top_eigvecs, top_eigvals)
511525

@@ -669,7 +683,9 @@ def compute_train_eval_mixed_preconditioner(
669683
print(f"Loading cached train-eval mixed preconditioner from {output_path}")
670684
return GradientProcessor.load(output_path)
671685

672-
print(f"Computing train-eval mixed preconditioner ({train_weight:.0%} train, {1-train_weight:.0%} eval)...")
686+
print(
687+
f"Computing train-eval mixed preconditioner ({train_weight:.0%} train, {1-train_weight:.0%} eval)..."
688+
)
673689

674690
train_path = Path(train_index_path)
675691
eval_path = Path(eval_grads_path)

examples/semantic/scoring.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ def load_scores_matrix(scores_path: Path | str) -> np.ndarray:
3232

3333
# Handle both tuple format (from bergson) and list format (from JSON serialization)
3434
dtype_spec = info["dtype"]
35-
if isinstance(dtype_spec, list) and len(dtype_spec) > 0 and isinstance(dtype_spec[0], list):
35+
if (
36+
isinstance(dtype_spec, list)
37+
and len(dtype_spec) > 0
38+
and isinstance(dtype_spec[0], list)
39+
):
3640
# Convert list of lists back to list of tuples
3741
dtype_spec = [tuple(item) for item in dtype_spec]
3842

examples/train_lora.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,10 @@ def main():
274274
parser.add_argument("--split", type=str, default="test")
275275
parser.add_argument("--prompt_column", type=str, default="prompt")
276276
parser.add_argument("--completion_column", type=str, default="completion")
277-
parser.add_argument("--no_push_to_private", action="store_false", dest="push_to_private")
278-
277+
parser.add_argument(
278+
"--no_push_to_private", action="store_false", dest="push_to_private"
279+
)
280+
279281
args = parser.parse_args()
280282

281283
training_config = TrainingConfig( # type: ignore

0 commit comments

Comments
 (0)