Memmap shuffling #216

Open · wants to merge 24 commits into main

Commits (24)
ee78d91
call dataloader next consistent with async loading
jbloomAus Jul 2, 2024
eda7db7
start using np.memmap (but with none of the actual advantages)
jbloomAus Jul 2, 2024
bff803f
various changes, iterating toward np.memmap
jbloomAus Jul 6, 2024
c486f05
deduplicate small mathematical operations
Lewington-pitsos Jul 6, 2024
112960a
named parameters for shuffle idxs
Lewington-pitsos Jul 6, 2024
e01238d
add diagnostic prints, fix typing in activations_store
Lewington-pitsos Jul 6, 2024
545a33c
add diagnostic prints, fix typing in activations_store
Lewington-pitsos Jul 6, 2024
7eec1c2
rename shuffling methods
Lewington-pitsos Jul 6, 2024
a720ad0
add dataset override for cache activations runner, update test to use…
Lewington-pitsos Jul 7, 2024
18f4216
replicate error using activationstore alone
Lewington-pitsos Jul 7, 2024
c053910
replicate error using activationstore alone
Lewington-pitsos Jul 7, 2024
a76e4ed
fix float32 vs float16 memmap double size issue
Lewington-pitsos Jul 7, 2024
4b7fbdb
fix typing
Lewington-pitsos Jul 7, 2024
bf3e28d
skip test_load_cached_activations
Lewington-pitsos Jul 7, 2024
7a74cf6
get all unit test passing with new next_batch functionality
Lewington-pitsos Jul 7, 2024
9ca4aba
Merge branch 'main' into memmap_shuffling
Lewington-pitsos Jul 7, 2024
43ccbb8
merge with main
Lewington-pitsos Jul 7, 2024
d82b17c
format
Lewington-pitsos Jul 7, 2024
a955a7f
add small test of next_batch functionality
Lewington-pitsos Jul 7, 2024
ac45f4c
format again
Lewington-pitsos Jul 7, 2024
c80de83
map bfloat16 to np.float32
Lewington-pitsos Jul 7, 2024
fc147d5
reformat
Jul 7, 2024
4e8847b
add memory pinning
Jul 7, 2024
fd0d499
reformat
Lewington-pitsos Jul 8, 2024
1 change: 0 additions & 1 deletion docs/generate_sae_table.py
@@ -66,7 +66,6 @@ def generate_sae_table():
# )

for info in tqdm(model_info["saes"]):

# can remove this by explicitly overriding config in yaml. Do this later.
if model_info["conversion_func"] == "connor_rob_hook_z":
repo_id = model_info["repo_id"]
1 change: 0 additions & 1 deletion sae_lens/analysis/hooked_sae_transformer.py
@@ -57,7 +57,6 @@ def set_deep_attr(obj: Any, path: str, value: Any):


class HookedSAETransformer(HookedTransformer):

def __init__(
self,
*model_args: Any,
2 changes: 0 additions & 2 deletions sae_lens/analysis/neuronpedia_runner.py
@@ -62,7 +62,6 @@ def default(self, o: Any):


class NeuronpediaRunner:

def __init__(
self,
sae_id: str,
@@ -83,7 +82,6 @@ def __init__(
top_acts_group_size: int = 20,
quantile_group_size: int = 5,
):

self.device = "cpu"
if torch.backends.mps.is_available():
self.device = "mps"
4 changes: 0 additions & 4 deletions sae_lens/analysis/tsea.py
@@ -17,7 +17,6 @@ def get_enrichment_df(
features: list[int],
gene_sets_selected: dict[str, set[int]],
):

gene_sets_token_ids_padded = pad_gene_sets(gene_sets_selected)
gene_sets_token_ids_tensor = torch.tensor(list(gene_sets_token_ids_padded.values()))
enrichment_scores = calculate_batch_enrichment_scores(
@@ -91,7 +90,6 @@ def calculate_batch_enrichment_scores(scores: torch.Tensor, index_lists: torch.T
def manhattan_plot_enrichment_scores(
df_enrichment_scores: pd.DataFrame, label_threshold: float = 1.0, top_n: int = 3
):

tmp_df = df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x))

# wide to long format
@@ -167,7 +165,6 @@ def plot_top_k_feature_projections_by_token_and_category(
log_y: bool = True,
histnorm: Optional[str] = None,
):

if not os.path.exists("es_plots"):
os.makedirs("es_plots")

@@ -291,7 +288,6 @@ def get_gene_set_from_regex(vocab: dict[str, int], pattern: str) -> set[int]:


def get_test_gene_sets(model: HookedTransformer) -> dict[str, set[int]]:

colors = [
"red",
"blue",
164 changes: 96 additions & 68 deletions sae_lens/cache_activations_runner.py
@@ -1,18 +1,24 @@
import math
import os
from typing import Tuple

import numpy as np
import torch
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from tqdm import tqdm

from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
from sae_lens.load_model import load_model
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.activations_store import FILE_EXTENSION, ActivationsStore


class CacheActivationsRunner:

def __init__(self, cfg: CacheActivationsRunnerConfig):
def __init__(
self,
cfg: CacheActivationsRunnerConfig,
override_dataset: (
DatasetDict | Dataset | IterableDatasetDict | IterableDataset | None
) = None,
):
self.cfg = cfg
self.model = load_model(
model_class_name=cfg.model_class_name,
@@ -23,9 +29,10 @@ def __init__(self, cfg: CacheActivationsRunnerConfig):
self.activations_store = ActivationsStore.from_config(
self.model,
cfg,
override_dataset=override_dataset,
)

self.file_extension = "safetensors"
self.file_extension = FILE_EXTENSION

def __str__(self):
"""
@@ -40,27 +47,33 @@ def __str__(self):
if isinstance(self.cfg.dtype, torch.dtype)
else DTYPE_MAP[self.cfg.dtype].itemsize
)
tokens_in_buffer = (
self.cfg.n_batches_in_buffer
* self.cfg.store_batch_size_prompts
* self.cfg.context_size
)
total_training_tokens = self.cfg.training_tokens
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

return (
f"Activation Cache Runner:\n"
f"Total training tokens: {total_training_tokens}\n"
f"Number of buffers: {math.ceil(total_training_tokens / tokens_in_buffer)}\n"
f"Tokens per buffer: {tokens_in_buffer}\n"
f"Number of buffers: {self.n_buffers}\n"
f"Tokens per buffer: {self.tokens_in_buffer}\n"
f"Disk space required: {total_disk_space_gb:.2f} GB\n"
f"Configuration:\n"
f"{self.cfg}"
)

@property
def tokens_in_buffer(self):
return (
self.cfg.n_batches_in_buffer
* self.cfg.store_batch_size_prompts
* self.cfg.context_size
)

@property
def n_buffers(self):
return math.ceil(self.cfg.training_tokens / self.tokens_in_buffer)

@torch.no_grad()
def run(self):

new_cached_activations_path = self.cfg.new_cached_activations_path

# if the activations directory exists and has files in it, raise an exception
@@ -73,94 +86,109 @@ def run(self):
else:
os.makedirs(new_cached_activations_path)

print(f"Started caching {self.cfg.training_tokens} activations")
tokens_per_buffer = (
self.cfg.store_batch_size_prompts
* self.cfg.context_size
* self.cfg.n_batches_in_buffer
)

n_buffers = math.ceil(self.cfg.training_tokens / tokens_per_buffer)

for i in tqdm(range(n_buffers), desc="Caching activations"):
for i in tqdm(range(self.n_buffers), desc="Caching activations"):
try:
buffer = self.activations_store.get_buffer(self.cfg.n_batches_in_buffer)

self.activations_store.save_buffer(
buffer, f"{new_cached_activations_path}/{i}.safetensors"
)
buffer = self.activations_store.get_buffer()
buffer_path = f"{new_cached_activations_path}/{i}.{self.file_extension}"
self.activations_store.save_buffer(buffer, buffer_path)

del buffer

if i % self.cfg.shuffle_every_n_buffers == 0 and i > 0:
if i > 0 and i % self.cfg.shuffle_every_n_buffers == 0:
# Shuffle the buffers on disk

# Do random pairwise shuffling between the last shuffle_every_n_buffers buffers
for _ in range(self.cfg.n_shuffles_with_last_section):
self.shuffle_activations_pairwise(
self.shuffle_two_random_buffers(
new_cached_activations_path,
buffer_idx_range=(i - self.cfg.shuffle_every_n_buffers, i),
start_idx=i - self.cfg.shuffle_every_n_buffers,
end_idx=i,
)

# Do more random pairwise shuffling between all the buffers
for _ in range(self.cfg.n_shuffles_in_entire_dir):
self.shuffle_activations_pairwise(
new_cached_activations_path,
buffer_idx_range=(0, i),
self.shuffle_two_random_buffers(
new_cached_activations_path, start_idx=0, end_idx=i
)
except StopIteration:
print(
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {n_buffers} batches. No more caching will occur."
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.n_buffers} batches. No more caching will occur."
)
break

# More final shuffling (mostly in case we didn't end on an i divisible by shuffle_every_n_buffers)
if n_buffers > 1:
if self.n_buffers > 1:
for _ in tqdm(range(self.cfg.n_shuffles_final), desc="Final shuffling"):
self.shuffle_activations_pairwise(
self.shuffle_two_random_buffers(
new_cached_activations_path,
buffer_idx_range=(0, n_buffers),
start_idx=0,
end_idx=self.n_buffers,
)

@torch.no_grad()
def shuffle_activations_pairwise(
self, datapath: str, buffer_idx_range: Tuple[int, int]
):
def shuffle_two_random_buffers(self, datapath: str, start_idx: int, end_idx: int):
"""
Shuffles two buffers on disk.
Shuffles two randomly selected buffers on disk.
"""
assert (
buffer_idx_range[0] < buffer_idx_range[1] - 1
start_idx < end_idx - 1
), "buffer_idx_range[0] must be smaller than buffer_idx_range[1] by at least 1"

buffer_idx1 = torch.randint(
buffer_idx_range[0], buffer_idx_range[1], (1,)
).item()
buffer_idx2 = torch.randint(
buffer_idx_range[0], buffer_idx_range[1], (1,)
).item()
buffer_idx1 = int(torch.randint(start_idx, end_idx, (1,)).item())
buffer_idx2 = int(torch.randint(start_idx, end_idx, (1,)).item())
while buffer_idx1 == buffer_idx2: # Make sure they're not the same
buffer_idx2 = torch.randint(
buffer_idx_range[0], buffer_idx_range[1], (1,)
).item()
buffer_idx2 = int(torch.randint(start_idx, end_idx, (1,)).item())

buffer1 = self.activations_store.load_buffer(
f"{datapath}/{buffer_idx1}.{self.file_extension}"
)
buffer2 = self.activations_store.load_buffer(
f"{datapath}/{buffer_idx2}.{self.file_extension}"
self.shuffle_two_buffers(datapath, buffer_idx1, buffer_idx2)

@torch.no_grad()
def shuffle_two_buffers(self, datapath: str, buffer_idx1: int, buffer_idx2: int):
path1 = f"{datapath}/{buffer_idx1}.{self.file_extension}"
path2 = f"{datapath}/{buffer_idx2}.{self.file_extension}"

buffer1 = self.activations_store.load_buffer(path1)
buffer2 = self.activations_store.load_buffer(path2)

# Get total size and create a joint buffer
total_size = buffer1.shape[0] + buffer2.shape[0]
joint_buffer = np.memmap(
f"{datapath}/temp_joint_buffer",
dtype=buffer1.dtype,
mode="w+",
shape=(total_size,) + buffer1.shape[1:],
)
joint_buffer = torch.cat([buffer1, buffer2])

# Shuffle them
joint_buffer = joint_buffer[torch.randperm(joint_buffer.shape[0])]
shuffled_buffer1 = joint_buffer[: buffer1.shape[0]]
shuffled_buffer2 = joint_buffer[buffer1.shape[0] :]
# Copy data to joint buffer
joint_buffer[: buffer1.shape[0]] = buffer1
joint_buffer[buffer1.shape[0] :] = buffer2

# Save them back
self.activations_store.save_buffer(
shuffled_buffer1, f"{datapath}/{buffer_idx1}.{self.file_extension}"
# Generate random permutation
permutation = np.random.permutation(total_size)

# Create shuffled buffers
shuffled_buffer1 = np.memmap(
f"{datapath}/temp_shuffled_1",
dtype=buffer1.dtype,
mode="w+",
shape=buffer1.shape,
)
self.activations_store.save_buffer(
shuffled_buffer2, f"{datapath}/{buffer_idx2}.{self.file_extension}"
shuffled_buffer2 = np.memmap(
f"{datapath}/temp_shuffled_2",
dtype=buffer2.dtype,
mode="w+",
shape=buffer2.shape,
)

# Apply permutation
shuffled_buffer1[:] = joint_buffer[permutation[: buffer1.shape[0]]]
shuffled_buffer2[:] = joint_buffer[permutation[buffer1.shape[0] :]]

# Save shuffled buffers back to original files
self.activations_store.save_buffer(shuffled_buffer1, path1)
self.activations_store.save_buffer(shuffled_buffer2, path2)

# Clean up temporary files
import os

os.remove(f"{datapath}/temp_joint_buffer")
os.remove(f"{datapath}/temp_shuffled_1")
os.remove(f"{datapath}/temp_shuffled_2")
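The core change in this file is moving the buffer shuffle from an in-RAM `torch.cat` to on-disk staging with `np.memmap`. Below is a minimal, self-contained sketch of that pairwise shuffle written against plain raw `.bin` files; the file names, shape, dtype, and the helper name are illustrative assumptions, since the actual runner reads and writes buffers through `ActivationsStore.load_buffer` / `save_buffer`.

```python
# Minimal sketch of the on-disk pairwise shuffle, assuming raw .bin buffers.
# Names, shapes, and dtype are placeholders, not values taken from this PR.
import os

import numpy as np

D_IN = 16            # assumed activation width
DTYPE = np.float32   # assumed on-disk dtype


def shuffle_two_buffers_on_disk(path1: str, path2: str, n1: int, n2: int) -> None:
    """Jointly shuffle the rows of two memmapped activation buffers in place."""
    buf1 = np.memmap(path1, dtype=DTYPE, mode="r+", shape=(n1, D_IN))
    buf2 = np.memmap(path2, dtype=DTYPE, mode="r+", shape=(n2, D_IN))

    # Stage both buffers in a temporary joint memmap so the combined rows
    # live on disk rather than in RAM.
    joint_path = path1 + ".joint.tmp"
    joint = np.memmap(joint_path, dtype=DTYPE, mode="w+", shape=(n1 + n2, D_IN))
    joint[:n1] = buf1
    joint[n1:] = buf2

    # One permutation over all rows, split back across the two files.
    perm = np.random.permutation(n1 + n2)
    buf1[:] = joint[perm[:n1]]
    buf2[:] = joint[perm[n1:]]

    # Flush writes and drop the temporary file.
    buf1.flush()
    buf2.flush()
    del joint
    os.remove(joint_path)


if __name__ == "__main__":
    # Build two tiny buffers, then shuffle them jointly.
    for path, n in [("0.bin", 8), ("1.bin", 8)]:
        arr = np.memmap(path, dtype=DTYPE, mode="w+", shape=(n, D_IN))
        arr[:] = np.arange(n)[:, None]
        arr.flush()
    shuffle_two_buffers_on_disk("0.bin", "1.bin", 8, 8)
```

Staging the joint buffer in a memmap instead of concatenating tensors in RAM lets the operating system page rows in and out, so peak memory stays roughly at one buffer's worth rather than both at once. For a sense of the sizes involved, with illustrative settings n_batches_in_buffer=32, store_batch_size_prompts=16, and context_size=128, the new tokens_in_buffer property gives 32 × 16 × 128 = 65,536 tokens per buffer, and training_tokens=1,000,000 yields n_buffers = ceil(1,000,000 / 65,536) = 16.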
4 changes: 0 additions & 4 deletions sae_lens/config.py
@@ -229,7 +229,6 @@ class LanguageModelSAERunnerConfig:
sae_lens_training_version: str = field(default_factory=lambda: __version__)

def __post_init__(self):

if self.resume:
raise ValueError(
"Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
@@ -393,7 +392,6 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]:
}

def to_dict(self) -> dict[str, Any]:

cfg_dict = {
**self.__dict__,
# some args may not be serializable by default
@@ -405,7 +403,6 @@ def to_json(self, path: str) -> None:
return cfg_dict

def to_json(self, path: str) -> None:

if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))

@@ -483,7 +480,6 @@ def __post_init__(self):

@dataclass
class ToyModelSAERunnerConfig:

architecture: Literal["standard", "gated"] = "standard"

# ReLu Model Parameters
4 changes: 0 additions & 4 deletions sae_lens/evals.py
@@ -18,7 +18,6 @@ def run_evals(
eval_batch_size_prompts: int | None = None,
model_kwargs: Mapping[str, Any] = {},
) -> Mapping[str, Any]:

hook_name = sae.cfg.hook_name
hook_head_index = sae.cfg.hook_head_index
### Evals
@@ -153,7 +152,6 @@ def get_recons_loss(

# TODO(tomMcGrath): the rescaling below is a bit of a hack and could probably be tidied up
def standard_replacement_hook(activations: torch.Tensor, hook: Any):

original_device = activations.device
activations = activations.to(sae.device)

@@ -171,7 +169,6 @@ def standard_replacement_hook(activations: torch.Tensor, hook: Any):
return activations.to(original_device)

def all_head_replacement_hook(activations: torch.Tensor, hook: Any):

original_device = activations.device
activations = activations.to(sae.device)

@@ -195,7 +192,6 @@ def all_head_replacement_hook(activations: torch.Tensor, hook: Any):
return new_activations.to(original_device)

def single_head_replacement_hook(activations: torch.Tensor, hook: Any):

original_device = activations.device
activations = activations.to(sae.device)
