
Commit 28ee687

Move from manual sharding to HF dataset builder.
Depends on #389. Inspired by: https://opensourcemechanistic.slack.com/archives/C07EHMK3XC7/p1732413633220709

Instead of manually writing single Arrow shards, we can create a dataset builder that does this more efficiently. This speeds up saving quite a lot: the old method spent some time calculating the fingerprint of each shard, which was unnecessary and required a hack to get around.

Along with this change, I also switched to a 1D activation scheme:

- Previously the dataset was stored as a `(seq_len d_in)` array.
- Now it is stored as a flat `d_in` array.

The primary reason for this change is shuffling activations. I found that when activations are stored per sequence, they are not properly shuffled. This is a problem with `ActivationCache` too, but there is no great solution for it there. You can observe this in the SAE loss when using small buffer sizes, with either the cache or `ActivationStore`.
1 parent dd09264 commit 28ee687
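For intuition on the 1D switch described above, here is a small numpy sketch (not from the commit; array names and sizes are made up) of why shuffling rows of a `(seq_len d_in)` dataset only reorders whole sequences, while the flat `d_in` layout shuffles individual token activations:

```python
import numpy as np

rng = np.random.default_rng(0)
n_seqs, seq_len, d_in = 3, 4, 2

# Old scheme: one dataset row per sequence, shaped (seq_len, d_in).
# Shuffling rows permutes whole sequences; tokens from the same sequence
# stay adjacent, so consecutive activations remain correlated.
rows_2d = rng.normal(size=(n_seqs, seq_len, d_in))
shuffled_2d = rows_2d[rng.permutation(n_seqs)]

# New scheme: one dataset row per token activation, shaped (d_in,).
# Shuffling rows now mixes tokens across sequences and positions.
rows_1d = rows_2d.reshape(-1, d_in)
shuffled_1d = rows_1d[rng.permutation(n_seqs * seq_len)]

print(shuffled_2d.shape)  # (3, 4, 2) -- sequence-level shuffle only
print(shuffled_1d.shape)  # (12, 2)   -- token-level shuffle
```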

File tree

4 files changed (+146 −272 lines)

sae_lens/cache_activations_runner.py

+105 −205
@@ -1,23 +1,115 @@
 import io
 import json
-import shutil
 from dataclasses import asdict
-from pathlib import Path
+from typing import Generator

+import datasets
 import einops
+import numpy as np
+import pyarrow as pa
 import torch
-from datasets import Array2D, Dataset, Features
-from datasets.fingerprint import generate_fingerprint
+from datasets import Dataset, Features, Sequence, Value
 from huggingface_hub import HfApi
 from jaxtyping import Float
-from tqdm import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule

 from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
 from sae_lens.load_model import load_model
 from sae_lens.training.activations_store import ActivationsStore


+class CacheActivationDataset(datasets.ArrowBasedBuilder):
+    cfg: CacheActivationsRunnerConfig
+    activation_store: ActivationsStore
+    # info: datasets.DatasetInfo  # By DatasetBuilder
+
+    pa_dtype: pa.DataType
+    schema: pa.Schema
+
+    hook_names: list[str]  # while we can only use one hook
+
+    def __init__(
+        self,
+        cfg: CacheActivationsRunnerConfig,
+        activation_store: ActivationsStore,
+    ):
+        self.cfg = cfg
+        self.activation_store = activation_store
+        self.hook_names = [cfg.hook_name]
+
+        if cfg.dtype == "float32":
+            self.pa_dtype = pa.float32()
+        elif cfg.dtype == "float16":
+            self.pa_dtype = pa.float16()
+        else:
+            raise ValueError(f"dtype {cfg.dtype} not supported")
+
+        self.schema = pa.schema(
+            [
+                pa.field(hook_name, pa.list_(self.pa_dtype, list_size=cfg.d_in))
+                for hook_name in self.hook_names
+            ]
+        )
+
+        features = Features(
+            {
+                hook_name: Sequence(Value(dtype=cfg.dtype), length=cfg.d_in)
+                for hook_name in [cfg.hook_name]
+            }
+        )
+        cfg.activation_save_path.mkdir(parents=True, exist_ok=True)
+        assert cfg.activation_save_path.is_dir()
+        if any(cfg.activation_save_path.iterdir()):
+            raise ValueError(
+                f"Activation save path {cfg.activation_save_path} is not empty. Please delete it or specify a different path"
+            )
+        cache_dir = cfg.activation_save_path.parent
+        dataset_name = cfg.activation_save_path.name
+        super().__init__(
+            cache_dir=str(cache_dir),
+            dataset_name=dataset_name,
+            info=datasets.DatasetInfo(features=features),
+        )
+
+    def _split_generators(
+        self, dl_manager: datasets.DownloadManager | datasets.StreamingDownloadManager
+    ) -> list[datasets.SplitGenerator]:
+        return [
+            datasets.SplitGenerator(name=str(datasets.Split.TRAIN)),
+        ]
+
+    def _generate_tables(self) -> Generator[tuple[int, pa.Table], None, None]:  # type: ignore
+        for i in range(self.cfg.n_buffers):
+            buffer = self.activation_store.get_buffer(
+                self.cfg.batches_in_buffer, shuffle=False
+            )
+            assert buffer.device.type == "cpu"
+            buffer = einops.rearrange(
+                buffer, "batch hook d_in -> hook batch d_in"
+            ).numpy()
+            table = pa.Table.from_pydict(
+                {
+                    hn: self.np2pa_2d(buf, d_in=self.cfg.d_in)
+                    for hn, buf in zip(self.hook_names, buffer)
+                },
+                schema=self.schema,
+            )
+            yield i, table
+
+    @staticmethod
+    def np2pa_2d(data: Float[np.ndarray, "batch d_in"], d_in: int) -> pa.Array:  # type: ignore
+        """
+        Convert a 2D numpy array to a PyArrow FixedSizeListArray.
+        """
+        assert data.ndim == 2, "Input array must be 2-dimensional."
+        _, d_in_found = data.shape
+        if d_in_found != d_in:
+            raise RuntimeError(f"d_in {d_in_found} does not match expected d_in {d_in}")
+        flat = data.ravel()  # no copy if possible
+        pa_data = pa.array(flat)
+        return pa.FixedSizeListArray.from_arrays(pa_data, d_in)
+
+
 class CacheActivationsRunner:
     def __init__(self, cfg: CacheActivationsRunnerConfig):
         self.cfg = cfg
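To make the new np2pa_2d helper concrete, here is a standalone demo of the same FixedSizeListArray conversion on a toy array (the shapes and the column name are made up, not taken from the repo):

```python
import numpy as np
import pyarrow as pa

d_in = 4
acts = np.arange(12, dtype=np.float32).reshape(3, d_in)  # (batch, d_in)

# Same idea as np2pa_2d: flatten, wrap as an Arrow array, then regroup
# every d_in values into one fixed-size list (one row per activation vector).
flat = pa.array(acts.ravel())
col = pa.FixedSizeListArray.from_arrays(flat, d_in)

table = pa.Table.from_pydict({"hook": col})
print(table.schema)    # hook: fixed_size_list<item: float>[4]
print(col[0].as_py())  # [0.0, 1.0, 2.0, 3.0]
```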
@@ -33,19 +125,8 @@ def __init__(self, cfg: CacheActivationsRunnerConfig):
             self.model,
             self.cfg,
         )
-        self.context_size = self._get_sliced_context_size(
-            self.cfg.context_size, self.cfg.seqpos_slice
-        )
-        self.features = Features(
-            {
-                hook_name: Array2D(
-                    shape=(self.context_size, self.cfg.d_in), dtype=self.cfg.dtype
-                )
-                for hook_name in [self.cfg.hook_name]
-            }
-        )

-    def __str__(self):
+    def summary(self):
         """
         Print the number of tokens to be cached.
         Print the number of buffers, and the number of tokens per buffer.
@@ -58,10 +139,10 @@ def __str__(self):
             if isinstance(self.cfg.dtype, torch.dtype)
             else DTYPE_MAP[self.cfg.dtype].itemsize
         )
-        total_training_tokens = self.cfg.dataset_num_rows * self.context_size
+        total_training_tokens = self.cfg.dataset_num_rows * self.cfg.sliced_context_size
         total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

-        return (
+        print(
             f"Activation Cache Runner:\n"
             f"Total training tokens: {total_training_tokens}\n"
             f"Number of buffers: {self.cfg.n_buffers}\n"
@@ -71,168 +152,15 @@ def __str__(self):
             f"{self.cfg}"
         )

-    @staticmethod
-    def _consolidate_shards(
-        source_dir: Path, output_dir: Path, copy_files: bool = True
-    ) -> Dataset:
-        """Consolidate sharded datasets into a single directory without rewriting data.
-
-        Each of the shards must be of the same format, aka the full dataset must be able to
-        be recreated like so:
-
-        ```
-        ds = concatenate_datasets(
-            [Dataset.load_from_disk(str(shard_dir)) for shard_dir in sorted(source_dir.iterdir())]
-        )
-
-        ```
-
-        Sharded dataset format:
-        ```
-        source_dir/
-            shard_00000/
-                dataset_info.json
-                state.json
-                data-00000-of-00002.arrow
-                data-00001-of-00002.arrow
-            shard_00001/
-                dataset_info.json
-                state.json
-                data-00000-of-00001.arrow
-        ```
-
-        And flattens them into the format:
-
-        ```
-        output_dir/
-            dataset_info.json
-            state.json
-            data-00000-of-00003.arrow
-            data-00001-of-00003.arrow
-            data-00002-of-00003.arrow
-        ```
-
-        allowing the dataset to be loaded like so:
-
-        ```
-        ds = datasets.load_from_disk(output_dir)
-        ```
-
-        Args:
-            source_dir: Directory containing the sharded datasets
-            output_dir: Directory to consolidate the shards into
-            copy_files: If True, copy files; if False, move them and delete source_dir
-        """
-        first_shard_dir_name = "shard_00000"  # shard_{i:05d}
-
-        assert source_dir.exists() and source_dir.is_dir()
-        assert (
-            output_dir.exists()
-            and output_dir.is_dir()
-            and not any(p for p in output_dir.iterdir() if not p.name == ".tmp_shards")
-        )
-        if not (source_dir / first_shard_dir_name).exists():
-            raise Exception(f"No shards in {source_dir} exist!")
-
-        transfer_fn = shutil.copy2 if copy_files else shutil.move
-
-        # Move dataset_info.json from any shard (all the same)
-        transfer_fn(
-            source_dir / first_shard_dir_name / "dataset_info.json",
-            output_dir / "dataset_info.json",
-        )
-
-        arrow_files = []
-        file_count = 0
-
-        for shard_dir in sorted(source_dir.iterdir()):
-            if not shard_dir.name.startswith("shard_"):
-                continue
-
-            # state.json contains arrow filenames
-            state = json.loads((shard_dir / "state.json").read_text())
-
-            for data_file in state["_data_files"]:
-                src = shard_dir / data_file["filename"]
-                new_name = f"data-{file_count:05d}-of-{len(list(source_dir.iterdir())):05d}.arrow"
-                dst = output_dir / new_name
-                transfer_fn(src, dst)
-                arrow_files.append({"filename": new_name})
-                file_count += 1
-
-        new_state = {
-            "_data_files": arrow_files,
-            "_fingerprint": None,  # temporary
-            "_format_columns": None,
-            "_format_kwargs": {},
-            "_format_type": None,
-            "_output_all_columns": False,
-            "_split": None,
-        }
-
-        # fingerprint is generated from dataset.__getstate__ (not includeing _fingerprint)
-        with open(output_dir / "state.json", "w") as f:
-            json.dump(new_state, f, indent=2)
-
-        ds = Dataset.load_from_disk(str(output_dir))
-        fingerprint = generate_fingerprint(ds)
-        del ds
-
-        with open(output_dir / "state.json", "r+") as f:
-            state = json.loads(f.read())
-            state["_fingerprint"] = fingerprint
-            f.seek(0)
-            json.dump(state, f, indent=2)
-            f.truncate()
-
-        if not copy_files:  # cleanup source dir
-            shutil.rmtree(source_dir)
-
-        return Dataset.load_from_disk(output_dir)
-
     @torch.no_grad()
     def run(self) -> Dataset:
-        activation_save_path = self.cfg.activation_save_path
-        assert activation_save_path is not None
-
-        ### Paths setup
-        final_cached_activation_path = Path(activation_save_path)
-        final_cached_activation_path.mkdir(exist_ok=True, parents=True)
-        if any(final_cached_activation_path.iterdir()):
-            raise Exception(
-                f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
-            )
-
-        tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/"
-        tmp_cached_activation_path.mkdir(exist_ok=False, parents=False)
-
-        ### Create temporary sharded datasets
-
-        print(f"Started caching activations for {self.cfg.hf_dataset_path}")
-
-        for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
-            try:
-                buffer = self.activations_store.get_buffer(
-                    self.cfg.batches_in_buffer, shuffle=False
-                )
-                shard = self._create_shard(buffer)
-                shard.save_to_disk(
-                    f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
-                )
-                del buffer, shard
-
-            except StopIteration:
-                print(
-                    f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
-                )
-                break
+        builder = CacheActivationDataset(self.cfg, self.activations_store)
+        builder.download_and_prepare()
+        dataset = builder.as_dataset(split="train")  # type: ignore
+        assert isinstance(dataset, Dataset)

         ### Concatenate shards and push to Huggingface Hub

-        dataset = self._consolidate_shards(
-            tmp_cached_activation_path, final_cached_activation_path, copy_files=False
-        )
-
         if self.cfg.shuffle:
             print("Shuffling...")
             dataset = dataset.shuffle(seed=self.cfg.seed)
@@ -241,7 +169,7 @@ def run(self) -> Dataset:
             print("Pushing to Huggingface Hub...")
             dataset.push_to_hub(
                 repo_id=self.cfg.hf_repo_id,
-                num_shards=self.cfg.hf_num_shards or self.cfg.n_buffers,
+                num_shards=self.cfg.hf_num_shards,
                 private=self.cfg.hf_is_private_repo,
                 revision=self.cfg.hf_revision,
             )
@@ -263,31 +191,3 @@ def run(self) -> Dataset:
         )

         return dataset
-
-    def _create_shard(
-        self,
-        buffer: Float[torch.Tensor, "(bs context_size) num_layers d_in"],
-    ) -> Dataset:
-        hook_names = [self.cfg.hook_name]
-
-        buffer = einops.rearrange(
-            buffer,
-            "(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
-            bs=self.cfg.rows_in_buffer,
-            context_size=self.context_size,
-            d_in=self.cfg.d_in,
-            num_layers=len(hook_names),
-        )
-        shard = Dataset.from_dict(
-            {hook_name: act for hook_name, act in zip(hook_names, buffer)},
-            features=self.features,
-        )
-        return shard
-
-    @staticmethod
-    def _get_sliced_context_size(
-        context_size: int, seqpos_slice: tuple[int | None, ...] | None
-    ) -> int:
-        if seqpos_slice is not None:
-            context_size = len(range(context_size)[slice(*seqpos_slice)])
-        return context_size
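The new run() leans on the standard datasets builder flow: instantiate a builder, call download_and_prepare(), then materialize it with as_dataset(). For readers unfamiliar with that API, here is a self-contained toy version of the pattern. It is not the repo's class; the names and values are made up, and it is written against the public `datasets` builder API as I understand it:

```python
import datasets
import pyarrow as pa

D_IN = 4  # toy activation width


class ToyActivationBuilder(datasets.ArrowBasedBuilder):
    """Minimal stand-in for CacheActivationDataset: yields a few Arrow tables."""

    def _info(self) -> datasets.DatasetInfo:
        return datasets.DatasetInfo(
            features=datasets.Features(
                {"act": datasets.Sequence(datasets.Value("float32"), length=D_IN)}
            )
        )

    def _split_generators(self, dl_manager):
        return [datasets.SplitGenerator(name=str(datasets.Split.TRAIN))]

    def _generate_tables(self):
        arrow_type = pa.list_(pa.float32(), D_IN)
        for i in range(2):  # pretend we have two buffers
            rows = [[float(i)] * D_IN for _ in range(3)]
            yield i, pa.table({"act": pa.array(rows, type=arrow_type)})


builder = ToyActivationBuilder()
builder.download_and_prepare()
ds = builder.as_dataset(split="train")
print(ds)  # roughly: Dataset({features: ['act'], num_rows: 6})
```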
