Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,9 @@ influence_results/
tmp/
.idea/
uv.lock
data/*.hf
zeki_requirements.txt
.python-version
*package-lock.json
*package.json
david_wips/
2 changes: 1 addition & 1 deletion bergson/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def build_worker(
init_method=f"tcp://{addr}:{port}",
device_id=torch.device(f"cuda:{local_rank}"),
rank=rank,
timeout=timedelta(hours=1),
timeout=timedelta(minutes=2),
world_size=world_size,
)

Expand Down
1 change: 1 addition & 0 deletions bergson/collector/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ def fwd_bwd(model, x: Tensor, y: Tensor, batch: dict):
logits.reshape(-1, logits.size(-1)),
y[:, 1:].flatten(),
reduction="none",
label_smoothing=cfg.label_smoothing,
).reshape_as(y[:, 1:])
losses = losses.sum(1) / denoms
if "advantage" in batch:
Expand Down
5 changes: 5 additions & 0 deletions bergson/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ class IndexConfig:
loss_reduction: Literal["mean", "sum"] = "mean"
"""Reduction method for the loss function."""

label_smoothing: float = 0.0
"""Label smoothing coefficient for cross-entropy loss. When > 0, prevents
near-zero gradients for high-confidence predictions that can cause numerical
instability. Recommended value: 0.005-0.01."""

stream_shard_size: int = 400_000
"""Shard size for streaming the dataset into Dataset objects."""

Expand Down
15 changes: 5 additions & 10 deletions bergson/score/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from bergson.gradients import GradientProcessor
from bergson.score.score_writer import MemmapScoreWriter
from bergson.score.scorer import Scorer
from bergson.utils.math import compute_damped_inverse
from bergson.utils.utils import (
assert_type,
convert_precision_to_torch,
Expand Down Expand Up @@ -177,16 +178,10 @@ def precondition_ds(
)

# Compute H^(-1) via eigendecomposition and apply to query gradients
h_inv = {}
for name, H in mixed_preconditioner.items():
H = H.to(device=device, dtype=torch.float64)
damping_val = 0.1 * H.abs().mean()
H = H + damping_val * torch.eye(H.shape[0], device=H.device, dtype=H.dtype)

eigval, eigvec = torch.linalg.eigh(H)
h_inv[name] = (eigvec * (1.0 / eigval) @ eigvec.mT).to(
mixed_preconditioner[name].dtype
)
h_inv = {
name: compute_damped_inverse(H.to(device=device))
for name, H in mixed_preconditioner.items()
}

def precondition(batch):
for name in target_modules:
Expand Down
34 changes: 34 additions & 0 deletions bergson/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,40 @@ def psd_rsqrt(A: Tensor) -> Tensor:
return rsqrt


def compute_damped_inverse(
H: Tensor,
damping_factor: float = 0.1,
dtype: torch.dtype = torch.float64,
regularizer: Tensor | None = None,
) -> Tensor:
"""Compute H^(-1) with damping for numerical stability.

Uses eigendecomposition to compute the inverse of a positive semi-definite
matrix with adaptive damping based on the matrix's mean absolute value.

Args:
H: Positive semi-definite matrix to invert.
damping_factor: Multiplier for the damping term (default: 0.1).
dtype: Dtype for intermediate computation (default: float64 for stability).
regularizer: Optional matrix to use as regularizer instead of identity.
If provided, computes inv(H + damping_factor * regularizer).
If None (default), uses scaled identity: inv(H + damping_factor * |H|_mean * I).

Returns:
The damped inverse H^(-1) in the original dtype of H.
"""
original_dtype = H.dtype
H = H.to(dtype=dtype)
if regularizer is not None:
regularizer = regularizer.to(dtype=dtype, device=H.device)
H = H + damping_factor * regularizer
else:
damping_val = damping_factor * H.abs().mean()
H = H + damping_val * torch.eye(H.shape[0], device=H.device, dtype=H.dtype)
eigval, eigvec = torch.linalg.eigh(H)
return (eigvec * (1.0 / eigval) @ eigvec.mT).to(original_dtype)


def trace(matrices: Tensor) -> Tensor:
"""Version of `torch.trace` that works for batches of matrices."""
diag = torch.linalg.diagonal(matrices)
Expand Down
1 change: 1 addition & 0 deletions bergson/utils/worker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def setup_model_and_peft(
try:
peft_config = PeftConfig.from_pretrained(cfg.model)
except ValueError:
print(f"PEFT config not found for model {cfg.model}")
peft_config = None

if peft_config is None:
Expand Down
17 changes: 17 additions & 0 deletions data/generate_facts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Generate a synthetic facts dataset and save it to disk as a HF Dataset.

Run as a module (the relative import requires package context):
    python -m data.generate_facts --num_facts 1000
"""

from argparse import ArgumentParser

from datasets import Dataset

from .dataset import fact_generator

if __name__ == "__main__":
    # Fix: the original re-imported ArgumentParser and Dataset here,
    # duplicating the module-level imports above.
    parser = ArgumentParser()
    parser.add_argument("--num_facts", type=int, default=1000)
    args = parser.parse_args()

    # fact_generator yields records lazily; materialize before building the Dataset.
    dataset = fact_generator(args.num_facts)
    Dataset.from_list(list(dataset)).save_to_disk("data/facts_dataset.hf")
56 changes: 56 additions & 0 deletions examples/semantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Backward-compatible wrapper for semantic experiments.

This module re-exports all functions from the semantic package for backward
compatibility. New code should import directly from examples.semantic instead.
"""

# Re-export everything from the semantic package
from semantic import (
build_style_indices,
build_style_lookup,
compute_between_preconditioner,
compute_between_preconditioner_covariance,
compute_between_preconditioner_means,
compute_metrics,
compute_metrics_groupwise,
compute_mixed_preconditioner,
compute_scores_fast,
compute_scores_with_bergson,
create_data,
create_index,
create_qwen_only_dataset,
finetune,
load_scores_matrix,
main,
reword,
run_preconditioner_comparison,
)

__all__ = [
# Data creation
"reword",
"create_data",
"create_qwen_only_dataset",
# Scoring
"load_scores_matrix",
"compute_scores_fast",
"compute_scores_with_bergson",
# Metrics
"build_style_lookup",
"compute_metrics_groupwise",
"compute_metrics",
# Preconditioners
"build_style_indices",
"compute_between_preconditioner_covariance",
"compute_between_preconditioner_means",
"compute_between_preconditioner",
"compute_mixed_preconditioner",
# Experiment
"create_index",
"finetune",
"run_preconditioner_comparison",
"main",
]

if __name__ == "__main__":
main()
98 changes: 98 additions & 0 deletions examples/semantic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Semantic experiments for analyzing gradient-based embeddings.

This package provides tools for:
- Creating reworded datasets in different styles (Shakespeare, Pirate)
- Computing pairwise similarity scores from gradient embeddings
- Analyzing semantic vs stylistic similarity patterns
- Comparing different preconditioning strategies
- Asymmetric style distribution experiments for style suppression validation
"""

from .asymmetric import (
AsymmetricConfig,
AsymmetricMetrics,
compute_asymmetric_metrics,
compute_style_preconditioner,
create_asymmetric_dataset,
create_asymmetric_index,
print_metrics,
run_asymmetric_experiment,
score_asymmetric_eval,
)
from .attribute_preservation import (
AttributePreservationConfig,
AttributePreservationMetrics,
compute_attribute_metrics,
create_attribute_dataset,
create_attribute_index,
create_styled_datasets,
print_attribute_metrics,
run_attribute_preservation_experiment,
score_attribute_eval,
)
from .data import create_data, create_qwen_only_dataset, reword
from .experiment import (
create_index,
finetune,
main,
run_preconditioner_comparison,
)
from .metrics import build_style_lookup, compute_metrics, compute_metrics_groupwise
from .preconditioners import (
build_style_indices,
compute_between_preconditioner,
compute_between_preconditioner_covariance,
compute_between_preconditioner_means,
compute_mixed_preconditioner,
)
from .scoring import (
compute_scores_fast,
compute_scores_with_bergson,
load_scores_matrix,
)

__all__ = [
# Data creation
"reword",
"create_data",
"create_qwen_only_dataset",
# Scoring
"load_scores_matrix",
"compute_scores_fast",
"compute_scores_with_bergson",
# Metrics
"build_style_lookup",
"compute_metrics_groupwise",
"compute_metrics",
# Preconditioners
"build_style_indices",
"compute_between_preconditioner_covariance",
"compute_between_preconditioner_means",
"compute_between_preconditioner",
"compute_mixed_preconditioner",
# Experiment
"create_index",
"finetune",
"run_preconditioner_comparison",
"main",
# Asymmetric style experiment
"AsymmetricConfig",
"AsymmetricMetrics",
"create_asymmetric_dataset",
"create_asymmetric_index",
"score_asymmetric_eval",
"compute_asymmetric_metrics",
"compute_style_preconditioner",
"print_metrics",
"run_asymmetric_experiment",
# Attribute preservation experiment
"AttributePreservationConfig",
"AttributePreservationMetrics",
"create_attribute_dataset",
"create_attribute_index",
"create_styled_datasets",
"score_attribute_eval",
"compute_attribute_metrics",
"print_attribute_metrics",
"run_attribute_preservation_experiment",
]
Loading
Loading