import io
import json
- import shutil
from dataclasses import asdict
- from pathlib import Path
+ from typing import Generator

+ import datasets
import einops
+ import numpy as np
+ import pyarrow as pa
import torch
- from datasets import Array2D, Dataset, Features
- from datasets.fingerprint import generate_fingerprint
+ from datasets import Dataset, Features, Sequence, Value
from huggingface_hub import HfApi
from jaxtyping import Float
- from tqdm import tqdm
from transformer_lens.HookedTransformer import HookedRootModule

from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
from sae_lens.load_model import load_model
from sae_lens.training.activations_store import ActivationsStore


+ class CacheActivationDataset(datasets.ArrowBasedBuilder):
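+     """`datasets.ArrowBasedBuilder` that caches model activations to disk by
+     streaming buffers from an `ActivationsStore` into Arrow tables."""
+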
+     cfg: CacheActivationsRunnerConfig
+     activation_store: ActivationsStore
+     # info: datasets.DatasetInfo  # By DatasetBuilder
+
+     pa_dtype: pa.DataType
+     schema: pa.Schema
+
+     hook_names: list[str]  # while we can only use one hook
+
+     def __init__(
+         self,
+         cfg: CacheActivationsRunnerConfig,
+         activation_store: ActivationsStore,
+     ):
+         self.cfg = cfg
+         self.activation_store = activation_store
+         self.hook_names = [cfg.hook_name]
+
+         if cfg.dtype == "float32":
+             self.pa_dtype = pa.float32()
+         elif cfg.dtype == "float16":
+             self.pa_dtype = pa.float16()
+         else:
+             raise ValueError(f"dtype {cfg.dtype} not supported")
+
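+         # One column per hook; each row stores a single token's activation vector as a
+         # fixed-size list of length d_in.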
+         self.schema = pa.schema(
+             [
+                 pa.field(hook_name, pa.list_(self.pa_dtype, list_size=cfg.d_in))
+                 for hook_name in self.hook_names
+             ]
+         )
+
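+         # `features` is the datasets-level view of the same layout, attached to DatasetInfo below.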
+         features = Features(
+             {
+                 hook_name: Sequence(Value(dtype=cfg.dtype), length=cfg.d_in)
+                 for hook_name in [cfg.hook_name]
+             }
+         )
+         cfg.activation_save_path.mkdir(parents=True, exist_ok=True)
+         assert cfg.activation_save_path.is_dir()
+         if any(cfg.activation_save_path.iterdir()):
+             raise ValueError(
+                 f"Activation save path {cfg.activation_save_path} is not empty. Please delete it or specify a different path"
+             )
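+         # The DatasetBuilder writes under {cache_dir}/{dataset_name}, i.e. inside cfg.activation_save_path.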
+         cache_dir = cfg.activation_save_path.parent
+         dataset_name = cfg.activation_save_path.name
+         super().__init__(
+             cache_dir=str(cache_dir),
+             dataset_name=dataset_name,
+             info=datasets.DatasetInfo(features=features),
+         )
+
+     def _split_generators(
+         self, dl_manager: datasets.DownloadManager | datasets.StreamingDownloadManager
+     ) -> list[datasets.SplitGenerator]:
+         return [
+             datasets.SplitGenerator(name=str(datasets.Split.TRAIN)),
+         ]
+
+     def _generate_tables(self) -> Generator[tuple[int, pa.Table], None, None]:  # type: ignore
+         for i in range(self.cfg.n_buffers):
+             buffer = self.activation_store.get_buffer(
+                 self.cfg.batches_in_buffer, shuffle=False
+             )
+             assert buffer.device.type == "cpu"
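+             # (batch, hook, d_in) -> (hook, batch, d_in), so zip() below pairs each hook
+             # name with its block of activations.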
+             buffer = einops.rearrange(
+                 buffer, "batch hook d_in -> hook batch d_in"
+             ).numpy()
+             table = pa.Table.from_pydict(
+                 {
+                     hn: self.np2pa_2d(buf, d_in=self.cfg.d_in)
+                     for hn, buf in zip(self.hook_names, buffer)
+                 },
+                 schema=self.schema,
+             )
+             yield i, table
+
+     @staticmethod
+     def np2pa_2d(data: Float[np.ndarray, "batch d_in"], d_in: int) -> pa.Array:  # type: ignore
+         """
+         Convert a 2D numpy array to a PyArrow FixedSizeListArray.
+         """
+         assert data.ndim == 2, "Input array must be 2-dimensional."
+         _, d_in_found = data.shape
+         if d_in_found != d_in:
+             raise RuntimeError(f"d_in {d_in_found} does not match expected d_in {d_in}")
+         flat = data.ravel()  # no copy if possible
+         pa_data = pa.array(flat)
+         return pa.FixedSizeListArray.from_arrays(pa_data, d_in)
+
+
class CacheActivationsRunner:
    def __init__(self, cfg: CacheActivationsRunnerConfig):
        self.cfg = cfg
@@ -33,19 +125,8 @@ def __init__(self, cfg: CacheActivationsRunnerConfig):
            self.model,
            self.cfg,
        )
-         self.context_size = self._get_sliced_context_size(
-             self.cfg.context_size, self.cfg.seqpos_slice
-         )
-         self.features = Features(
-             {
-                 hook_name: Array2D(
-                     shape=(self.context_size, self.cfg.d_in), dtype=self.cfg.dtype
-                 )
-                 for hook_name in [self.cfg.hook_name]
-             }
-         )

-     def __str__(self):
+     def summary(self):
        """
        Print the number of tokens to be cached.
        Print the number of buffers, and the number of tokens per buffer.
@@ -58,10 +139,10 @@ def __str__(self):
            if isinstance(self.cfg.dtype, torch.dtype)
            else DTYPE_MAP[self.cfg.dtype].itemsize
        )
-         total_training_tokens = self.cfg.dataset_num_rows * self.context_size
+         total_training_tokens = self.cfg.dataset_num_rows * self.cfg.sliced_context_size
        total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

-         return (
+         print(
            f"Activation Cache Runner:\n"
            f"Total training tokens: {total_training_tokens}\n"
            f"Number of buffers: {self.cfg.n_buffers}\n"
@@ -71,168 +152,15 @@ def __str__(self):
            f"{self.cfg}"
        )

-     @staticmethod
-     def _consolidate_shards(
-         source_dir: Path, output_dir: Path, copy_files: bool = True
-     ) -> Dataset:
-         """Consolidate sharded datasets into a single directory without rewriting data.
-
-         Each of the shards must be of the same format, aka the full dataset must be able to
-         be recreated like so:
-
-         ```
-         ds = concatenate_datasets(
-             [Dataset.load_from_disk(str(shard_dir)) for shard_dir in sorted(source_dir.iterdir())]
-         )
-
-         ```
-
-         Sharded dataset format:
-         ```
-         source_dir/
-             shard_00000/
-                 dataset_info.json
-                 state.json
-                 data-00000-of-00002.arrow
-                 data-00001-of-00002.arrow
-             shard_00001/
-                 dataset_info.json
-                 state.json
-                 data-00000-of-00001.arrow
-         ```
-
-         And flattens them into the format:
-
-         ```
-         output_dir/
-             dataset_info.json
-             state.json
-             data-00000-of-00003.arrow
-             data-00001-of-00003.arrow
-             data-00002-of-00003.arrow
-         ```
-
-         allowing the dataset to be loaded like so:
-
-         ```
-         ds = datasets.load_from_disk(output_dir)
-         ```
-
-         Args:
-             source_dir: Directory containing the sharded datasets
-             output_dir: Directory to consolidate the shards into
-             copy_files: If True, copy files; if False, move them and delete source_dir
-         """
-         first_shard_dir_name = "shard_00000"  # shard_{i:05d}
-
-         assert source_dir.exists() and source_dir.is_dir()
-         assert (
-             output_dir.exists()
-             and output_dir.is_dir()
-             and not any(p for p in output_dir.iterdir() if not p.name == ".tmp_shards")
-         )
-         if not (source_dir / first_shard_dir_name).exists():
-             raise Exception(f"No shards in {source_dir} exist!")
-
-         transfer_fn = shutil.copy2 if copy_files else shutil.move
-
-         # Move dataset_info.json from any shard (all the same)
-         transfer_fn(
-             source_dir / first_shard_dir_name / "dataset_info.json",
-             output_dir / "dataset_info.json",
-         )
-
-         arrow_files = []
-         file_count = 0
-
-         for shard_dir in sorted(source_dir.iterdir()):
-             if not shard_dir.name.startswith("shard_"):
-                 continue
-
-             # state.json contains arrow filenames
-             state = json.loads((shard_dir / "state.json").read_text())
-
-             for data_file in state["_data_files"]:
-                 src = shard_dir / data_file["filename"]
-                 new_name = f"data-{file_count:05d}-of-{len(list(source_dir.iterdir())):05d}.arrow"
-                 dst = output_dir / new_name
-                 transfer_fn(src, dst)
-                 arrow_files.append({"filename": new_name})
-                 file_count += 1
-
-         new_state = {
-             "_data_files": arrow_files,
-             "_fingerprint": None,  # temporary
-             "_format_columns": None,
-             "_format_kwargs": {},
-             "_format_type": None,
-             "_output_all_columns": False,
-             "_split": None,
-         }
-
-         # fingerprint is generated from dataset.__getstate__ (not includeing _fingerprint)
-         with open(output_dir / "state.json", "w") as f:
-             json.dump(new_state, f, indent=2)
-
-         ds = Dataset.load_from_disk(str(output_dir))
-         fingerprint = generate_fingerprint(ds)
-         del ds
-
-         with open(output_dir / "state.json", "r+") as f:
-             state = json.loads(f.read())
-             state["_fingerprint"] = fingerprint
-             f.seek(0)
-             json.dump(state, f, indent=2)
-             f.truncate()
-
-         if not copy_files:  # cleanup source dir
-             shutil.rmtree(source_dir)
-
-         return Dataset.load_from_disk(output_dir)
-
    @torch.no_grad()
    def run(self) -> Dataset:
-         activation_save_path = self.cfg.activation_save_path
-         assert activation_save_path is not None
-
-         ### Paths setup
-         final_cached_activation_path = Path(activation_save_path)
-         final_cached_activation_path.mkdir(exist_ok=True, parents=True)
-         if any(final_cached_activation_path.iterdir()):
-             raise Exception(
-                 f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
-             )
-
-         tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/"
-         tmp_cached_activation_path.mkdir(exist_ok=False, parents=False)
-
-         ### Create temporary sharded datasets
-
-         print(f"Started caching activations for {self.cfg.hf_dataset_path}")
-
-         for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
-             try:
-                 buffer = self.activations_store.get_buffer(
-                     self.cfg.batches_in_buffer, shuffle=False
-                 )
-                 shard = self._create_shard(buffer)
-                 shard.save_to_disk(
-                     f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
-                 )
-                 del buffer, shard
-
-             except StopIteration:
-                 print(
-                     f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
-                 )
-                 break
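+         # Stream buffers to Arrow files on disk via the builder, then load the train split back as a Dataset.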
+         builder = CacheActivationDataset(self.cfg, self.activations_store)
+         builder.download_and_prepare()
+         dataset = builder.as_dataset(split="train")  # type: ignore
+         assert isinstance(dataset, Dataset)

        ### Concatenate shards and push to Huggingface Hub

-         dataset = self._consolidate_shards(
-             tmp_cached_activation_path, final_cached_activation_path, copy_files=False
-         )
-
        if self.cfg.shuffle:
            print("Shuffling...")
            dataset = dataset.shuffle(seed=self.cfg.seed)
@@ -241,7 +169,7 @@ def run(self) -> Dataset:
            print("Pushing to Huggingface Hub...")
            dataset.push_to_hub(
                repo_id=self.cfg.hf_repo_id,
-                 num_shards=self.cfg.hf_num_shards or self.cfg.n_buffers,
+                 num_shards=self.cfg.hf_num_shards,
                private=self.cfg.hf_is_private_repo,
                revision=self.cfg.hf_revision,
            )
@@ -263,31 +191,3 @@ def run(self) -> Dataset:
            )

        return dataset
-
-     def _create_shard(
-         self,
-         buffer: Float[torch.Tensor, "(bs context_size) num_layers d_in"],
-     ) -> Dataset:
-         hook_names = [self.cfg.hook_name]
-
-         buffer = einops.rearrange(
-             buffer,
-             "(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
-             bs=self.cfg.rows_in_buffer,
-             context_size=self.context_size,
-             d_in=self.cfg.d_in,
-             num_layers=len(hook_names),
-         )
-         shard = Dataset.from_dict(
-             {hook_name: act for hook_name, act in zip(hook_names, buffer)},
-             features=self.features,
-         )
-         return shard
-
-     @staticmethod
-     def _get_sliced_context_size(
-         context_size: int, seqpos_slice: tuple[int | None, ...] | None
-     ) -> int:
-         if seqpos_slice is not None:
-             context_size = len(range(context_size)[slice(*seqpos_slice)])
-         return context_size