"""
AttnRes Benchmark: Liger (Triton) vs PyTorch
Kimi Attention Residuals: softmax attention over depth blocks.
"""
import math
import os
import sys

import torch

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from benchmark_model_configs import MODEL_REGISTRY
from benchmark_model_configs import compute_model_config_sweep_config
from benchmark_model_configs import compute_seq_len_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import get_benchmark_model_config
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark

from liger_kernel.ops.attn_res import LigerAttnResFunction
from liger_kernel.utils import infer_device

device = infer_device()
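
# Rough sketch of the op, inferred from the tensor shapes and the docstring
# (the authoritative definitions are LigerAttnResFunction and pytorch_attn_res;
# treat this as an assumption, not a spec): given V: [N, B, T, D], each of the
# N depth blocks is scored per token by a dot product of the normalized block
# value (w_norm and eps suggest an RMSNorm-style normalization) with w_query,
# the N scores are softmaxed, and the output is the softmax-weighted sum over N.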


def _setup_attn_res(input: SingleBenchmarkRunInput):
    """Create input tensors for AttnRes from benchmark config."""
    cfg = input.extra_benchmark_config
    seq_len = input.x

    # V: [N, B, T, D]
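    # N = number of depth blocks, B = batch size, T = sequence length, D = hidden size.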
    V = torch.randn(
        cfg["N"],
        cfg["bsz"],
        seq_len,
        cfg["hidden_size"],
        device=device,
        dtype=cfg["dtype"],
        requires_grad=True,
    )
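
    # The 0.02 scale below mirrors a typical transformer init std (an assumption
    # about intent); it keeps the attention logits in a moderate range.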
    w_query = torch.randn(cfg["hidden_size"], device=device, dtype=cfg["dtype"]) * 0.02
    w_norm = torch.ones(cfg["hidden_size"], device=device, dtype=cfg["dtype"])
    eps = cfg.get("eps", 1e-6)

    if input.kernel_provider == "liger":
        fn = lambda: LigerAttnResFunction.apply(V, w_query, w_norm, eps)
    elif input.kernel_provider == "pytorch":
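        # Reuse the pure-PyTorch reference implementation from the test suite
        # as the baseline provider.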
        from test.transformers.test_attn_res import pytorch_attn_res

        fn = lambda: pytorch_attn_res(V, w_query, w_norm, eps)
    else:
        raise ValueError(f"Invalid provider: {input.kernel_provider}")
    return V, fn


def bench_speed_attn_res(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    V, fn = _setup_attn_res(input)
    return run_speed_benchmark(fn, input.kernel_operation_mode, [V])


def bench_memory_attn_res(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    V, fn = _setup_attn_res(input)
    return run_memory_benchmark(fn, input.kernel_operation_mode)


def _resolve_model_config_attn_res(input: SingleBenchmarkRunInput):
    """Resolve model-config-sweep input into standard setup args."""
    cfg = input.extra_benchmark_config
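    # In model-config sweeps, input.x is the model name; look up that model's
    # hidden_size and dtype, and reuse the sweep-wide seq_len chosen up front.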
    model_info = cfg["model_configs"][input.x]
    return _setup_attn_res(
        SingleBenchmarkRunInput(
            x=cfg["seq_len"],
            kernel_provider=input.kernel_provider,
            extra_benchmark_config={
                "N": cfg["N"],
                "bsz": cfg["bsz"],
                "hidden_size": model_info["hidden_size"],
                "dtype": model_info["dtype"],
                "eps": cfg.get("eps", 1e-6),
            },
        )
    )


def bench_speed_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    V, fn = _resolve_model_config_attn_res(input)
    return run_speed_benchmark(fn, input.kernel_operation_mode, [V])


def bench_memory_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    V, fn = _resolve_model_config_attn_res(input)
    return run_memory_benchmark(fn, input.kernel_operation_mode)


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    if args.sweep_mode == "model_config":
        all_model_configs = list(MODEL_REGISTRY.values())
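
        # The probe factory builds a small PyTorch forward pass per model at a
        # fixed probe length so compute_model_config_sweep_config can estimate
        # peak memory and pick a batch size / sequence length that fit on the
        # device.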
        def _probe_factory(model_cfg, probe_seq_len):
            def _probe():
                probe_input = SingleBenchmarkRunInput(
                    x=probe_seq_len,
                    kernel_provider="pytorch",
                    extra_benchmark_config={
                        "N": 8,
                        "bsz": 1,
                        "hidden_size": model_cfg.hidden_size,
                        "dtype": model_cfg.dtype,
                        "eps": 1e-6,
                    },
                )
                V, fn = _setup_attn_res(probe_input)
                return fn()

            return _probe

        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)

        model_configs_info = {
            cfg.name: {
                "hidden_size": cfg.hidden_size,
                "dtype": cfg.dtype,
            }
            for cfg in sweep.model_configs
        }

        common_configs = {
            "kernel_name": "attn_res",
            "x_name": "model_config",
            "x_label": "model configuration",
            "x_values": [cfg.name for cfg in sweep.model_configs],
            "kernel_providers": ["liger", "pytorch"],
            "extra_benchmark_configs": [
                {
                    "model_configs": model_configs_info,
                    "N": 8,
                    "bsz": sweep.batch_size,
                    "seq_len": sweep.seq_len,
                    "eps": 1e-6,
                }
            ],
            "overwrite": args.overwrite,
        }

        run_benchmarks(
            bench_test_fn=bench_speed_attn_res_model_config,
            kernel_operation_modes=["full", "forward", "backward"],
            metric_name="speed",
            metric_unit="ms",
            **common_configs,
        )
        run_benchmarks(
            bench_test_fn=bench_memory_attn_res_model_config,
            kernel_operation_modes=["full", "forward", "backward"],
            metric_name="memory",
            metric_unit="MB",
            **common_configs,
        )
    else:
        model = get_benchmark_model_config(args.model)
        probe_seq_len = 1024

        def _probe():
            probe_input = SingleBenchmarkRunInput(
                x=probe_seq_len,
                kernel_provider="pytorch",
                extra_benchmark_config={
                    "N": 8,
                    "bsz": 1,
                    "hidden_size": model.hidden_size,
                    "dtype": model.dtype,
                    "eps": 1e-6,
                },
            )
            V, fn = _setup_attn_res(probe_input)
            return fn()
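
        # Peak bytes from the probe divided by the probe's token count gives a
        # bytes-per-token estimate; this assumes peak memory grows roughly
        # linearly with tokens, and compute_seq_len_sweep_config uses it to cap
        # the sweep's batch size and maximum sequence length.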
        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
        kernel_bpt = peak_bytes // probe_seq_len
        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
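
        # x_values below sweep powers of two from 2**10 up to the memory-derived
        # maximum sequence length.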
        common_configs = {
            "kernel_name": "attn_res",
            "x_name": "T",
            "x_label": "sequence length",
            "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
            "kernel_providers": ["liger", "pytorch"],
            "extra_benchmark_configs": [
                {
                    "N": 8,
                    "bsz": config.batch_size,
                    "hidden_size": model.hidden_size,
                    "dtype": model.dtype,
                    "eps": 1e-6,
                }
            ],
            "overwrite": args.overwrite,
        }

        run_benchmarks(
            bench_test_fn=bench_speed_attn_res,
            kernel_operation_modes=["full", "forward", "backward"],
            metric_name="speed",
            metric_unit="ms",
            **common_configs,
        )
        run_benchmarks(
            bench_test_fn=bench_memory_attn_res,
            kernel_operation_modes=["full", "forward", "backward"],
            metric_name="memory",
            metric_unit="MB",
            **common_configs,
        )