-
Notifications
You must be signed in to change notification settings - Fork 37
Open
Description
# vit_prisma.sae.training.activations_store.py
class VisionActivationsStore:
    """
    Class for streaming tokens and generating and storing activations
    while training SAEs.
    """

    def __init__(
        self,
        cfg: Any,
        model: HookedViT,
        dataset,
        create_dataloader: bool = True,
        eval_dataset=None,
        num_workers=0,
    ):
        """Set up the model and the training / evaluation image loaders.

        Args:
            cfg: Configuration object; ``cfg.device`` and
                ``cfg.store_batch_size`` are read here.
            model: Model to collect activations from; moved to ``cfg.device``.
            dataset: Training dataset, wrapped in a shuffled ``DataLoader``.
            create_dataloader: Not consulted in this part of the constructor.
                NOTE(review): presumably used further down -- confirm.
            eval_dataset: Optional evaluation dataset. When ``None`` no
                evaluation loader is built and ``image_dataloader_eval`` /
                ``image_dataloader_eval_iter`` are set to ``None``.
            num_workers: Worker count forwarded to every ``DataLoader``.
        """
        self.cfg = cfg
        self.model = model
        self.model.to(cfg.device)
        self.dataset = dataset

        self.image_dataloader = torch.utils.data.DataLoader(
            self.dataset,
            shuffle=True,
            num_workers=num_workers,
            batch_size=self.cfg.store_batch_size,
            collate_fn=collate_fn,
            drop_last=True,
        )

        # eval_dataset defaults to None; handing None to DataLoader with
        # shuffle=True fails, so only build the eval loader when a dataset
        # was actually supplied.
        if eval_dataset is not None:
            self.image_dataloader_eval = torch.utils.data.DataLoader(
                eval_dataset,
                shuffle=True,
                num_workers=num_workers,
                batch_size=self.cfg.store_batch_size,
                collate_fn=collate_fn_eval,
                drop_last=True,
            )
        else:
            self.image_dataloader_eval = None

        self.image_dataloader_iter = self.get_batch_tokens_internal()

        # Without an eval loader there is nothing to iterate over; expose None
        # so callers can detect the absence of evaluation data explicitly.
        if self.image_dataloader_eval is not None:
            self.image_dataloader_eval_iter = self.get_val_batch_tokens_internal()
        else:
            self.image_dataloader_eval_iter = None
....
This issue serves as documentation and will be addressed later today.
Ensuring that eval_dataset is properly checked will prevent runtime errors and improve the robustness of the VisionActivationsStore class.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels