4 | 4 | import os
5 | 5 | import warnings
6 | 6 | from collections.abc import Generator, Iterator, Sequence
| 7 | +from pathlib import Path
7 | 8 | from typing import Any, Literal, cast
8 | 9 |
9 | 10 | import datasets
13 | 14 | from huggingface_hub.utils import HfHubHTTPError
14 | 15 | from jaxtyping import Float, Int
15 | 16 | from requests import HTTPError
16 | | -from safetensors.torch import save_file
17 | | -from tqdm import tqdm
| 17 | +from safetensors.torch import load_file, save_file
| 18 | +from tqdm.auto import tqdm
18 | 19 | from transformer_lens.hook_points import HookedRootModule
19 | 20 | from transformers import AutoTokenizer, PreTrainedTokenizerBase
20 | 21 |
24 | 25 | HfDataset,
25 | 26 | LanguageModelSAERunnerConfig,
26 | 27 | )
27 | | -from sae_lens.constants import DTYPE_MAP
| 28 | +from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
28 | 29 | from sae_lens.pretokenize_runner import get_special_token_from_cfg
29 | 30 | from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
30 | 31 | from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -729,6 +730,48 @@ def save(self, file_path: str):
729 | 730 | """save the state dict to a file in safetensors format"""
730 | 731 | save_file(self.state_dict(), file_path)
731 | 732 |
| 733 | + def save_to_checkpoint(self, checkpoint_path: str | Path): |
| 734 | + """Save the state dict to a checkpoint path""" |
| 735 | + self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME)) |
| 736 | + |
| 737 | + def load_from_checkpoint(self, checkpoint_path: str | Path): |
| 738 | + """Load the state dict from a checkpoint path""" |
| 739 | + self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME)) |
| 740 | + |
| 741 | + def load(self, file_path: str): |
| 742 | + """Load the state dict from a file in safetensors format""" |
| 743 | + |
| 744 | + state_dict = load_file(file_path) |
| 745 | + |
| 746 | + if "n_dataset_processed" in state_dict: |
| 747 | + target_n_dataset_processed = state_dict["n_dataset_processed"].item() |
| 748 | + |
| 749 | + # Only fast-forward if needed |
| 750 | + |
| 751 | + if target_n_dataset_processed > self.n_dataset_processed: |
| 752 | + logger.info( |
| 753 | + "Fast-forwarding through dataset samples to match checkpoint position" |
| 754 | + ) |
| 755 | + samples_to_skip = target_n_dataset_processed - self.n_dataset_processed |
| 756 | + |
| 757 | + pbar = tqdm( |
| 758 | + total=samples_to_skip, |
| 759 | + desc="Fast-forwarding through dataset", |
| 760 | + leave=False, |
| 761 | + ) |
| 762 | + while target_n_dataset_processed > self.n_dataset_processed: |
| 763 | + start = self.n_dataset_processed |
| 764 | + try: |
| 765 | + # Just consume and ignore the values to fast-forward |
| 766 | + next(self.iterable_sequences) |
| 767 | + except StopIteration: |
| 768 | + logger.warning( |
| 769 | + "Dataset exhausted during fast-forward. Resetting dataset." |
| 770 | + ) |
| 771 | + self.iterable_sequences = self._iterate_tokenized_sequences() |
| 772 | + pbar.update(self.n_dataset_processed - start) |
| 773 | + pbar.close() |
| 774 | + |
732 | 775 |
733 | 776 | def validate_pretokenized_dataset_tokenizer(
734 | 777 | dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
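
For readers skimming the commit, a rough usage sketch of the new checkpoint helpers follows. The `store`, `fresh_store`, and `checkpoint_dir` names are illustrative assumptions, not part of this commit; only `save_to_checkpoint`, `load_from_checkpoint`, and the fast-forward behaviour of `load` are taken from the diff above.

```python
# Illustrative sketch, not code from this commit. Assumes `store` and
# `fresh_store` are already-constructed ActivationsStore instances and that
# the checkpoint directory is writable.
from pathlib import Path

checkpoint_dir = Path("checkpoints/step_1000")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# During training: write the store's state dict (including n_dataset_processed)
# to ACTIVATIONS_STORE_STATE_FILENAME inside the checkpoint directory.
store.save_to_checkpoint(checkpoint_dir)

# On resume, with a freshly constructed store: restore the saved state.
# If the checkpoint's n_dataset_processed is ahead of the new store's counter,
# load() fast-forwards the tokenized-sequence iterator until the counts match,
# resetting the dataset iterator if it is exhausted part-way through.
fresh_store.load_from_checkpoint(checkpoint_dir)
```

The design appears to store only lightweight counters (at minimum `n_dataset_processed`) and replay the iterator on load, which keeps the saved state small at the cost of re-reading the skipped samples when a run is resumed.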