Commit cae5931

Add benchmark script

1 parent 9dda6da commit cae5931

2 files changed: 169 additions & 11 deletions

benchmark/scripts/benchmark_dyt.py

Lines changed: 147 additions & 0 deletions

```python
import os
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    hidden_size = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    BT = extra_benchmark_config["BT"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (BT, hidden_size)
    torch_y = lambda x: TorchDyT(hidden_size=hidden_size).to(device)(x)
    torch_compile_y = lambda x: torch.compile(
        TorchDyT(hidden_size=hidden_size).to(device)
    )(x)
    triton_y = lambda x: LigerDyT(hidden_size=hidden_size).to(device)(x)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return triton_y(x)
        elif provider == "torch":
            return torch_y(x)
        elif provider == "torch_compile":
            return torch_compile_y(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward(dy)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full, quantiles=QUANTILES, grad_to_none=[x], rep=500
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    hidden_size = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    BT = extra_benchmark_config["BT"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (BT, hidden_size)
    torch_y = lambda x: TorchDyT(hidden_size=hidden_size).to(device)(x)
    torch_compile_y = lambda x: torch.compile(
        TorchDyT(hidden_size=hidden_size).to(device)
    )(x)
    triton_y = lambda x: LigerDyT(hidden_size=hidden_size).to(device)(x)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return triton_y(x)
        elif provider == "torch":
            return torch_y(x)
        elif provider == "torch_compile":
            return torch_compile_y(x)

    def full():
        y = fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "dyt",
        "x_name": "hidden_size",
        "x_label": "hidden size",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "torch", "torch_compile"],
        "extra_benchmark_configs": [{"BT": 4096, "dtype": torch.float32}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_dyt,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_dyt,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
```
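A note on the timing collection: `triton.testing.do_bench` returns one measurement per requested quantile, in the order given, which is why the script can unpack `ms_50, ms_20, ms_80` directly. That only lines up if `QUANTILES` is ordered median-first. A minimal sketch, assuming `QUANTILES = [0.5, 0.2, 0.8]` in `benchmark/scripts/utils.py` (that file is not part of this commit) and a CUDA device:

```python
import torch
import triton

QUANTILES = [0.5, 0.2, 0.8]  # assumed definition; not shown in this commit

x = torch.randn(4096, 1024, device="cuda")  # assumes a CUDA device

# do_bench returns the timings at the requested quantiles, in the same
# order, so the median comes first and p20/p80 follow.
ms_50, ms_20, ms_80 = triton.testing.do_bench(
    lambda: torch.tanh(x), quantiles=QUANTILES, rep=500
)
print(f"p50={ms_50:.3f} ms  p20={ms_20:.3f} ms  p80={ms_80:.3f} ms")
```

The script itself is presumably launched from `benchmark/scripts/` as `python benchmark_dyt.py`, since the bare `from utils import ...` requires that working directory; the only flag it reads here is whatever `parse_benchmark_script_args` exposes as `args.overwrite`.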

test/transformers/test_dyt.py

Lines changed: 22 additions & 11 deletions

```diff
@@ -13,15 +13,14 @@
 
 
 class TorchDyT(nn.Module):
-    def __init__(self, hidden_size, init_alpha, dtype):
+    def __init__(self, hidden_size, init_alpha=0.5):
         super().__init__()
         self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
         self.gamma = nn.Parameter(torch.ones(hidden_size))
         self.beta = nn.Parameter(torch.zeros(hidden_size))
-        self.dtype = dtype
 
     def forward(self, x):
-        return (self.gamma * torch.tanh((self.alpha * x).to(torch.float32)) + self.beta).to(self.dtype)
+        return self.gamma * torch.tanh(self.alpha * x) + self.beta
 
 
 set_seed(42)
@@ -55,12 +54,16 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
     gamma = torch.randn(hidden_size, device=device, dtype=dtype)
     beta = torch.randn(hidden_size, device=device, dtype=dtype)
 
-    torch_dyt = TorchDyT(hidden_size=hidden_size, init_alpha=init_alpha, dtype=dtype).to(device).to(dtype)
+    torch_dyt = (
+        TorchDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
+    )
     torch_dyt.alpha.data = alpha.clone()
     torch_dyt.gamma.data = gamma.clone()
     torch_dyt.beta.data = beta.clone()
 
-    liger_dyt = LigerDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
+    liger_dyt = (
+        LigerDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
+    )
     liger_dyt.alpha.data = alpha.clone()
     liger_dyt.gamma.data = gamma.clone()
     liger_dyt.beta.data = beta.clone()
@@ -75,9 +78,15 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
     liger_output.backward(grad_output)
 
     assert_verbose_allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
-    assert_verbose_allclose(torch_dyt.alpha.grad, liger_dyt.alpha.grad, rtol=rtol, atol=atol)
-    assert_verbose_allclose(torch_dyt.gamma.grad, liger_dyt.gamma.grad, rtol=rtol, atol=atol)
-    assert_verbose_allclose(torch_dyt.beta.grad, liger_dyt.beta.grad, rtol=rtol, atol=atol)
+    assert_verbose_allclose(
+        torch_dyt.alpha.grad, liger_dyt.alpha.grad, rtol=rtol, atol=atol
+    )
+    assert_verbose_allclose(
+        torch_dyt.gamma.grad, liger_dyt.gamma.grad, rtol=rtol, atol=atol
+    )
+    assert_verbose_allclose(
+        torch_dyt.beta.grad, liger_dyt.beta.grad, rtol=rtol, atol=atol
+    )
 
 
 @pytest.mark.parametrize(
@@ -99,7 +108,9 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
             torch.bfloat16,
             1e-8,
             5e-2,
-            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
         ),
     ],
 )
@@ -128,8 +139,8 @@ def test_liger_dyt_functional(B, T, hidden_size, dtype, atol, rtol):
     assert_verbose_allclose(output1, output2, rtol=rtol, atol=atol)
 
     grad_output = torch.randn_like(_input)
-    output1.backward(grad_output, retain_graph=True)
-    output2.backward(grad_output, retain_graph=True)
+    output1.backward(grad_output)
+    output2.backward(grad_output)
 
     assert_verbose_allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
     assert_verbose_allclose(alpha1.grad, alpha2.grad, rtol=rtol, atol=atol)
```
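For reference, the simplified `TorchDyT` above is self-contained: DyT (Dynamic Tanh) computes `y = gamma * tanh(alpha * x) + beta` elementwise, with a scalar `alpha` and per-feature `gamma`/`beta`. A minimal sketch copied from the hunk above, runnable on CPU in float32:

```python
import torch
import torch.nn as nn


class TorchDyT(nn.Module):
    # Reference DyT module as it stands after this commit.
    def __init__(self, hidden_size, init_alpha=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)  # scalar scale
        self.gamma = nn.Parameter(torch.ones(hidden_size))     # per-feature scale
        self.beta = nn.Parameter(torch.zeros(hidden_size))     # per-feature shift

    def forward(self, x):
        # tanh now runs in the input dtype; the explicit float32
        # upcast/downcast was removed by this commit.
        return self.gamma * torch.tanh(self.alpha * x) + self.beta


x = torch.randn(4096, 1024)
y = TorchDyT(hidden_size=1024)(x)
print(y.shape)  # torch.Size([4096, 1024])
```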
