Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions src/saev/framework/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ class Config:
"""Where to log Slurm job stdout/stderr."""


def _get_sparsity_coeff(sae: nn.SparseAutoencoder) -> float | None:
sparsity = getattr(sae.activation.cfg, "sparsity", None)
if sparsity is None or not hasattr(sparsity, "coeff"):
return None
return float(sparsity.coeff)


def _set_sparsity_coeff(sae: nn.SparseAutoencoder, value: float) -> None:
sparsity = getattr(sae.activation.cfg, "sparsity", None)
if sparsity is None or not hasattr(sparsity, "coeff"):
return
object.__setattr__(sparsity, "coeff", float(value))


@beartype.beartype
def make_saes(
cfgs: list[tuple[nn.SparseAutoencoderConfig, nn.ObjectiveConfig]],
Expand Down Expand Up @@ -288,6 +302,7 @@ def train(
grouped_pgs: list[list[dict[str, object]]] = []
optimizers: list[list[torch.optim.Optimizer]] = []
lr_schedulers: list[list[saev.utils.scheduling.WarmupCosine]] = []
sparsity_schedulers: list[saev.utils.scheduling.Warmup | None] = []

for i, (sae, cfg, param_group) in enumerate(zip(saes, cfgs, param_groups)):
if cfg.optim == "adam":
Expand Down Expand Up @@ -318,6 +333,18 @@ def train(
optimizers.append(opts)
grouped_pgs.append(pgs)
lr_schedulers.append(scheds)
current_sparsity_coeff = _get_sparsity_coeff(sae)
if current_sparsity_coeff is None:
sparsity_schedulers.append(None)
else:
_set_sparsity_coeff(sae, 0.0)
sparsity_schedulers.append(
saev.utils.scheduling.Warmup(
0.0,
current_sparsity_coeff,
cfg.n_sparsity_warmup,
)
)

param_groups = grouped_pgs

Expand Down Expand Up @@ -420,6 +447,7 @@ def train(
**{f"loss/{key}": val for key, val in loss.metrics().items()},
"progress/n_patches_seen": n_patches_seen,
"progress/learning_rate": current_lr,
"progress/sparsity_coeff": _get_sparsity_coeff(sae) or 0.0,
"metrics/explained_variance": explained_var.item(),
"metrics/dead_unit_pct": dead_pct.item(),
"metrics/dictionary_coherence": coherence.item(),
Expand Down Expand Up @@ -450,8 +478,9 @@ def train(
for pg, sched in zip(pgs, scheds):
pg["lr"] = sched.step()

# for objective, scheduler in zip(objectives, sparsity_schedulers):
# objective.sparsity_coeff = scheduler.step()
for sae, scheduler in zip(saes, sparsity_schedulers):
if scheduler is not None:
_set_sparsity_coeff(sae, scheduler.step())

for opts in optimizers:
for opt in opts:
Expand Down