-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark_head_latency.py
More file actions
72 lines (54 loc) · 2.18 KB
/
benchmark_head_latency.py
File metadata and controls
72 lines (54 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import time
import statistics
from transformers import AutoModelForCausalLM
# Hugging Face Hub identifier of the checkpoint whose LM head is benchmarked.
MODEL_ID = "swiss-ai/Apertus-8B-Instruct-2509"

# Fractions of the LM-head vocabulary rows kept for each measurement, in
# ascending order. The benchmark walks this list in reverse, so the full
# head (1.0) is timed first and serves as the speedup baseline.
PERCENTAGES = [
    0.01,
    0.02,
    0.04,
    0.08,
    0.16,
    0.32,
    0.64,
    1.0,
]

# The random input fed to the head has shape (BATCH_SIZE, SEQ_LEN, hidden).
BATCH_SIZE = 1
SEQ_LEN = 1

# Number of timed forward calls averaged for each vocabulary size.
NUM_RUNS = 5_000
def _time_linear_us(inputs, weight, bias, num_runs, warmup_runs=100):
    """Return the mean GPU time, in microseconds, of one ``F.linear`` call.

    Runs ``warmup_runs`` untimed calls, synchronizes, then brackets
    ``num_runs`` calls with a pair of CUDA events and divides the elapsed
    time by ``num_runs``.
    """
    linear = torch.nn.functional.linear
    # The weight is a model parameter (requires_grad=True); without this
    # context every call would also record an autograd node, adding per-call
    # overhead that is not part of the matmul being measured.
    with torch.inference_mode():
        for _ in range(warmup_runs):
            linear(inputs, weight, bias)
        # Drain the warm-up kernels so they do not leak into the timed region.
        torch.cuda.synchronize()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        start_event.record()
        for _ in range(num_runs):
            linear(inputs, weight, bias)
        end_event.record()
        # elapsed_time() is only valid once both events have completed.
        torch.cuda.synchronize()

    total_ms = start_event.elapsed_time(end_event)
    return (total_ms * 1000) / num_runs


def benchmark_real_head():
    """Benchmark the model's LM head at several truncated vocabulary sizes.

    Loads ``MODEL_ID`` onto ``cuda:0`` in bfloat16. For each fraction ``p`` in
    ``PERCENTAGES`` (visited in reverse order) it keeps the first
    ``int(vocab * p)`` rows of ``lm_head.weight`` (and of the bias, when the
    head has one) and times ``torch.nn.functional.linear`` on a random input
    of shape ``(BATCH_SIZE, SEQ_LEN, hidden_size)``.

    Prints one table row per fraction: percentage, number of rows kept, mean
    time per call in microseconds, and the speedup relative to the ``p == 1.0``
    measurement ("n/a" if no such measurement has been taken yet).

    Requires a CUDA device and access to the model weights. Returns ``None``.
    """
    print(f"Loading model: {MODEL_ID}...")
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the checkpoint repository; only use it with a repository you trust.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map="cuda:0",
        trust_remote_code=True,
    )

    weight_matrix = model.lm_head.weight
    bias = model.lm_head.bias
    # The weight is indexed as (vocab, hidden): rows are sliced below and
    # shape[1] sizes the last dimension of the input.
    original_vocab, hidden_size = weight_matrix.shape[0], weight_matrix.shape[1]
    dtype = weight_matrix.dtype
    device = weight_matrix.device

    print(f"Head Shape: {weight_matrix.shape}")
    print(f"Precision: {dtype}")
    print("-" * 80)
    print(f"{'Percent':<8} | {'Vocab':<8} | {'Time (us)':<15} | {'Speedup':<10}")
    print("-" * 80)

    dummy_input = torch.randn(
        BATCH_SIZE, SEQ_LEN, hidden_size, device=device, dtype=dtype
    )

    # None until the full-size head (p == 1.0) has been measured.
    baseline_time = None

    # PERCENTAGES is ascending, so iterating in reverse measures the full
    # head first and every later row can be compared against it.
    for p in reversed(PERCENTAGES):
        new_vocab_size = int(original_vocab * p)

        # Leading-row slices: views on the original parameters, no copy.
        sliced_weight = weight_matrix[:new_vocab_size, :]
        sliced_bias = bias[:new_vocab_size] if bias is not None else None

        avg_us = _time_linear_us(dummy_input, sliced_weight, sliced_bias, NUM_RUNS)

        if p == 1.0:
            baseline_time = avg_us
            speedup = "1.00x"
        elif baseline_time is None:
            # No full-head measurement to compare against; reporting a
            # ratio against zero would be misleading.
            speedup = "n/a"
        else:
            speedup = f"{baseline_time / avg_us:.2f}x"

        print(f"{p*100:>6.0f}% | {new_vocab_size:<8} | {avg_us:<15.2f} | {speedup:<10}")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    benchmark_real_head()