-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbench_inference.py
More file actions
85 lines (66 loc) · 2.32 KB
/
bench_inference.py
File metadata and controls
85 lines (66 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import vllm
import os
from src.training.build_dataset import load_tl_model
PILE_PATH = "/root/pile_uncopyrighted"
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from typing import Optional
import pandas as pd
import numpy as np
import time
from tqdm.auto import tqdm
from functools import partial
import torch
def TL_inference(
    model,
    tokens : torch.Tensor,
    hook_datasets : dict[str, Optional[list[np.ndarray]]],
    stop_at_layer_idx : int
):
    """Run one hooked forward pass and accumulate activations on the CPU.

    Args:
        model: a HookedTransformer-like object; only ``run_with_hooks`` is used.
        tokens: token-id tensor fed straight to the model.
        hook_datasets: maps hook name -> list of captured numpy arrays
            (``None`` until the first capture). Mutated in place; one array
            is appended per hook per call.
        stop_at_layer_idx: forwarded as ``stop_at_layer`` so the pass stops
            early once the hooked layers have run.
    """
    def hook_fn(
        tensor: torch.Tensor,
        hook_name: str,
        hook: "HookPoint" = None  # quoted: evaluated lazily, keeps this self-contained
    ):
        # Copy the activation to host memory and time the transfer.
        t0 = time.time()
        processed_tensors = tensor.detach().cpu().numpy()
        # BUG FIX: original printed t0 - time.time(), a negative duration.
        print("cpu transfer took", time.time() - t0)
        if hook_datasets[hook_name] is None:
            hook_datasets[hook_name] = [processed_tensors]
        else:
            hook_datasets[hook_name].append(processed_tensors)
        return tensor
    with torch.no_grad():
        t0 = time.time()
        model.run_with_hooks(
            tokens,
            stop_at_layer = stop_at_layer_idx,
            fwd_hooks=[
                # FIX: derive hook names from the hook_datasets parameter
                # instead of silently depending on the module-level global
                # `hook_names` (same keys, no hidden coupling).
                (hook_name, partial(hook_fn, hook_name=hook_name))
                for hook_name in hook_datasets
            ]
        )
        # Kernels are async on GPU; sync before reading the clock. Guarded so
        # the benchmark also runs on CPU-only machines.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.time()
        print(f"forward pass with cache took {t1-t0} seconds")
        tok_per_sec = tokens.numel() / (t1-t0)
        print(f"tokens per second: {tok_per_sec}")
def compute_gb(hook_acts: np.ndarray):
    """Return the in-memory size of a numpy activation array in decimal GB (1e9 bytes)."""
    assert isinstance(hook_acts, np.ndarray)
    gigabytes = hook_acts.nbytes / 1e9
    return gigabytes
# --- Benchmark driver: time TransformerLens forward passes with capture hooks ---
tl_bench = True
print("gpu memory", torch.cuda.memory_summary())
if tl_bench:
    device = "cuda"
    # Project helper that loads gemma-2-2b as a HookedTransformer.
    model : HookedTransformer = load_tl_model("google/gemma-2-2b")
    print("gpu memory", torch.cuda.memory_summary())
    # Capture only the residual stream after block 3.
    hook_names = [f"blocks.{i}.hook_resid_post" for i in [3]]
    hook_datasets : dict[str, Optional[list[np.ndarray]]] = {
        hook_name: None for hook_name in hook_names
    }
    # FIX: removed the redundant `from tqdm import tqdm` here — it shadowed
    # the `tqdm.auto` import already made at the top of this file.
    for _ in tqdm(range(100)):
        # Random token ids in [1, 100), shape (batch=1, seq=512).
        tokens = torch.randint(1, 100, (1, 512))
        # Stop at layer 4: only blocks 0-3 run, which covers the hooked block.
        TL_inference(model, tokens, hook_datasets, stop_at_layer_idx=4)