Commit dee399c

feat: Resume training from checkpoints (#561)
* setting up saving / resuming from checkpoints in training
* adding some tests for checkpointing
* more tests
* adding docs
* reverting notebooks, not sure why they changed
* try busting CI cache
* fixing tests
* fixing tests
* changes from CR
1 parent 5118ef6 commit dee399c

19 files changed: +614, -37 lines

docs/generate_sae_table.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@

 import pandas as pd
 import yaml
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from sae_lens import SAEConfig
 from sae_lens.loading.pretrained_sae_loaders import (

docs/training_saes.md

Lines changed: 14 additions & 1 deletion

@@ -324,7 +324,20 @@ Some general performance tips:

 ## Checkpoints

-Checkpoints allow you to save a snapshot of the SAE and sparsity statistics during training. To enable checkpointing, set `n_checkpoints` to a value larger than 0. If WandB logging is enabled, checkpoints will be uploaded as WandB artifacts. To save checkpoints locally, the `checkpoint_path` parameter can be set to a local directory.
+Checkpoints allow you to save a snapshot of the SAE and sparsity statistics during training. To enable checkpointing, set `n_checkpoints` to a value larger than 0. If WandB logging is enabled, checkpoints will be uploaded as WandB artifacts. To save checkpoints locally, the `checkpoint_path` parameter can be set to a local directory. You can also set `save_final_checkpoint=True` to save a final checkpoint after training is finished.
+
+To resume training from a saved checkpoint, set `resume_from_checkpoint` to the path of the checkpoint when creating a `LanguageModelSAETrainingRunner`, or set `--resume_from_checkpoint` when running the CLI.
+
+```python
+
+cfg = LanguageModelSAERunnerConfig(
+    # ... other LanguageModelSAERunnerConfig parameters ...
+    resume_from_checkpoint="path/to/checkpoint"
+)
+runner = LanguageModelSAETrainingRunner(cfg)
+runner.run()
+
+```

 ## Optimizers and Schedulers

sae_lens/cache_activations_runner.py

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
 from jaxtyping import Float, Int
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule

 from sae_lens import logger

sae_lens/config.py

Lines changed: 2 additions & 0 deletions

@@ -171,6 +171,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
         checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
         save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+        resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
         output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
         verbose (bool): Whether to print verbose output. (default is True)
         model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +262,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     checkpoint_path: str | None = "checkpoints"
     save_final_checkpoint: bool = False
     output_path: str | None = "output"
+    resume_from_checkpoint: str | None = None

     # Misc
     verbose: bool = True
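
Taken together with the docs change above, the checkpoint-related fields can be combined roughly as follows. This is a sketch only: the values are illustrative and the model, dataset, and SAE parameters are elided, as in the docs example.

```python
from sae_lens.config import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    # ... model, dataset, and SAE parameters elided, as in the docs example ...
    n_checkpoints=5,                # save 5 checkpoints over the course of training
    checkpoint_path="checkpoints",  # a unique ID is appended to this path
    save_final_checkpoint=True,     # also write a checkpoint when training finishes
    resume_from_checkpoint=None,    # or the path of a previously saved checkpoint
)
```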

sae_lens/constants.py

Lines changed: 1 addition & 0 deletions

@@ -17,5 +17,6 @@
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
 SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+TRAINER_STATE_FILENAME = "trainer_state.pt"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
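
For orientation, here is how these filenames appear to map onto the state saved and restored elsewhere in this commit. Only `runner_cfg.json` and `activations_store_state.safetensors` are explicitly written in the hunks shown below; the other roles are inferred from the constant names and the load calls, so treat this as a sketch rather than a definitive checkpoint layout.

```python
# Inferred contents of a checkpoint directory (assumptions marked per entry).
CHECKPOINT_FILE_ROLES = {
    "runner_cfg.json": "runner config, written by save_checkpoint",
    "activations_store_state.safetensors": "ActivationsStore position, e.g. n_dataset_processed",
    "trainer_state.pt": "trainer state restored by trainer.load_trainer_state (assumed)",
    "activation_scaler.json": "scaling_factor restored by ActivationScaler.load (assumed)",
    "cfg.json": "SAE config (assumed)",
}
```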

sae_lens/llm_sae_training_runner.py

Lines changed: 9 additions & 4 deletions

@@ -16,7 +16,6 @@
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import (
-    ACTIVATIONS_STORE_STATE_FILENAME,
     RUNNER_CFG_FILENAME,
     SPARSITY_FILENAME,
 )
@@ -112,6 +111,7 @@ def __init__(
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
         override_sae: TrainingSAE[Any] | None = None,
+        resume_from_checkpoint: Path | str | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
@@ -153,6 +153,7 @@ def __init__(
             )
         else:
             self.sae = override_sae
+
         self.sae.to(self.cfg.device)

     def run(self):
@@ -185,6 +186,12 @@ def run(self):
             cfg=self.cfg.to_sae_trainer_config(),
         )

+        if self.cfg.resume_from_checkpoint is not None:
+            logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+            trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+            self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+            self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)

@@ -304,9 +311,7 @@ def save_checkpoint(
         if checkpoint_path is None:
             return

-        self.activations_store.save(
-            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
-        )
+        self.activations_store.save_to_checkpoint(checkpoint_path)

         runner_config = self.cfg.to_dict()
         with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
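
Because a unique ID is appended to `checkpoint_path`, picking which directory to pass as `resume_from_checkpoint` is left to the caller. A hypothetical helper (not part of sae_lens) that selects the most recently written checkpoint, using the `runner_cfg.json` emitted by `save_checkpoint` above as a marker, might look like this:

```python
from pathlib import Path


def latest_checkpoint(checkpoint_root: str = "checkpoints") -> str | None:
    """Hypothetical helper: return the most recently modified directory under
    checkpoint_root that contains a runner_cfg.json (i.e. was written by
    save_checkpoint), or None if no checkpoint exists yet."""
    root = Path(checkpoint_root)
    if not root.exists():
        return None
    candidates = [p.parent for p in root.rglob("runner_cfg.json")]
    if not candidates:
        return None
    return str(max(candidates, key=lambda p: p.stat().st_mtime))
```

One would then set `cfg.resume_from_checkpoint = latest_checkpoint()` before constructing `LanguageModelSAETrainingRunner`.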

sae_lens/saes/sae.py

Lines changed: 7 additions & 1 deletion

@@ -21,7 +21,7 @@
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -1018,6 +1018,12 @@ def get_sae_config_class_for_architecture(
     ) -> type[TrainingSAEConfig]:
         return get_sae_training_class(architecture)[1]

+    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+        self.process_state_dict_for_loading(state_dict)
+        self.load_state_dict(state_dict)
+

 _blank_hook = nn.Identity()

sae_lens/training/activation_scaler.py

Lines changed: 7 additions & 0 deletions

@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean

 import torch
@@ -51,3 +52,9 @@ def save(self, file_path: str):

         with open(file_path, "w") as f:
             json.dump({"scaling_factor": self.scaling_factor}, f)
+
+    def load(self, file_path: str | Path):
+        """load the state dict from a file in json format"""
+        with open(file_path) as f:
+            data = json.load(f)
+        self.scaling_factor = data["scaling_factor"]
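
The new `load` simply mirrors the existing JSON `save`. A round-trip sketch, assuming the dataclass in this module is named `ActivationScaler` and can be constructed with no arguments (neither detail is shown in this hunk):

```python
from sae_lens.training.activation_scaler import ActivationScaler

scaler = ActivationScaler()  # assumed no-arg construction
scaler.scaling_factor = 2.0
scaler.save("activation_scaler.json")

restored = ActivationScaler()
restored.load("activation_scaler.json")
assert restored.scaling_factor == 2.0
```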

sae_lens/training/activations_store.py

Lines changed: 46 additions & 3 deletions

@@ -4,6 +4,7 @@
 import os
 import warnings
 from collections.abc import Generator, Iterator, Sequence
+from pathlib import Path
 from typing import Any, Literal, cast

 import datasets
@@ -13,8 +14,8 @@
 from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
-from safetensors.torch import save_file
-from tqdm import tqdm
+from safetensors.torch import load_file, save_file
+from tqdm.auto import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

@@ -24,7 +25,7 @@
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -729,6 +730,48 @@ def save(self, file_path: str):
         """save the state dict to a file in safetensors format"""
         save_file(self.state_dict(), file_path)

+    def save_to_checkpoint(self, checkpoint_path: str | Path):
+        """Save the state dict to a checkpoint path"""
+        self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load_from_checkpoint(self, checkpoint_path: str | Path):
+        """Load the state dict from a checkpoint path"""
+        self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load(self, file_path: str):
+        """Load the state dict from a file in safetensors format"""
+
+        state_dict = load_file(file_path)
+
+        if "n_dataset_processed" in state_dict:
+            target_n_dataset_processed = state_dict["n_dataset_processed"].item()
+
+            # Only fast-forward if needed
+
+            if target_n_dataset_processed > self.n_dataset_processed:
+                logger.info(
+                    "Fast-forwarding through dataset samples to match checkpoint position"
+                )
+                samples_to_skip = target_n_dataset_processed - self.n_dataset_processed
+
+                pbar = tqdm(
+                    total=samples_to_skip,
+                    desc="Fast-forwarding through dataset",
+                    leave=False,
+                )
+                while target_n_dataset_processed > self.n_dataset_processed:
+                    start = self.n_dataset_processed
+                    try:
+                        # Just consume and ignore the values to fast-forward
+                        next(self.iterable_sequences)
+                    except StopIteration:
+                        logger.warning(
+                            "Dataset exhausted during fast-forward. Resetting dataset."
+                        )
+                        self.iterable_sequences = self._iterate_tokenized_sequences()
+                    pbar.update(self.n_dataset_processed - start)
+                pbar.close()
+

 def validate_pretokenized_dataset_tokenizer(
     dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
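
The interesting part of `load` is the fast-forward: instead of trying to serialize the position of a streaming dataset, the store replays and discards sequences until its live counter catches up with the checkpointed `n_dataset_processed`, restarting the iterator if the dataset runs out first. A self-contained sketch of the same pattern (here the counter is advanced explicitly, whereas in the store it is updated by `_iterate_tokenized_sequences` itself):

```python
from collections.abc import Callable, Iterator


def fast_forward(
    sequences: Iterator[str],
    processed: int,
    target: int,
    reset: Callable[[], Iterator[str]],
) -> tuple[Iterator[str], int]:
    """Consume and discard items until `processed` reaches `target`,
    restarting the iterator via `reset()` if it is exhausted early."""
    while processed < target:
        try:
            next(sequences)
        except StopIteration:
            sequences = reset()  # dataset exhausted; begin a fresh pass
            continue
        processed += 1
    return sequences, processed


def make_data() -> Iterator[str]:
    return iter(["a", "b", "c"])


# Toy usage: skip 5 items of a 3-item "dataset" that wraps around.
it, n = fast_forward(make_data(), processed=0, target=5, reset=make_data)
assert n == 5 and next(it) == "c"
```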

sae_lens/training/optim.py

Lines changed: 11 additions & 0 deletions

@@ -2,6 +2,8 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """

+from typing import Any
+
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler

@@ -150,3 +152,12 @@ def step(self) -> float:
     def value(self) -> float:
         """Returns the current scalar value."""
         return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "current_step": self.current_step,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        for k in state_dict:
+            setattr(self, k, state_dict[k])
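
These two methods give the scalar scheduler the same `state_dict` / `load_state_dict` interface as torch optimizers and LR schedulers, so its position can be bundled into the trainer's checkpoint and restored on resume. A minimal stand-in (the real class name and schedule are not shown in this hunk) illustrating the round trip:

```python
from typing import Any


class LinearWarmupScalar:
    """Stand-in for the scalar scheduler above: ramps a value linearly over
    warm_up_steps, then holds it. Only the checkpointing interface matters here."""

    def __init__(self, final_value: float, warm_up_steps: int):
        self.final_value = final_value
        self.warm_up_steps = warm_up_steps
        self.current_step = 0

    @property
    def current_value(self) -> float:
        frac = min(self.current_step / max(self.warm_up_steps, 1), 1.0)
        return frac * self.final_value

    def step(self) -> float:
        self.current_step += 1
        return self.current_value

    def state_dict(self) -> dict[str, Any]:
        return {"current_step": self.current_step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        for k in state_dict:
            setattr(self, k, state_dict[k])


# Resuming reproduces the schedule position exactly.
sched = LinearWarmupScalar(final_value=5.0, warm_up_steps=100)
for _ in range(40):
    sched.step()

resumed = LinearWarmupScalar(final_value=5.0, warm_up_steps=100)
resumed.load_state_dict(sched.state_dict())
assert resumed.current_value == sched.current_value
```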
