Commit 70e57e0

Merge branch 'main' into add_llava
2 parents af1420b + 812b050

7 files changed

Lines changed: 528 additions & 2 deletions

benchmark/benchmarks_visualizer.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 import pandas as pd
 import seaborn as sns

-DATA_PATH = "data/all_benchmark_data.csv"
-VISUALIZATIONS_PATH = "visualizations/"
+DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv"))
+VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/"))


 @dataclass

benchmark/scripts/benchmark_dyt.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
import os
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    BT = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (BT, hidden_size)
    torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
    torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
    triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return triton_dyt(x)
        elif provider == "torch":
            return torch_dyt(x)
        elif provider == "torch_compile":
            return torch_compile_dyt(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward(dy)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    BT = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (BT, hidden_size)
    torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
    torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
    triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return triton_dyt(x)
        elif provider == "torch":
            return torch_dyt(x)
        elif provider == "torch_compile":
            return torch_compile_dyt(x)

    def full():
        y = fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "dyt",
        "x_name": "BT",
        "x_label": "batch_size * seq_len",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "torch", "torch_compile"],
        "extra_benchmark_configs": [{"hidden_size": 4096, "dtype": torch.float32}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_dyt,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_dyt,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
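Note: for a quick sanity check outside the harness, `bench_speed_dyt` can be driven directly by duck-typing the only fields it reads from its input (`x`, `kernel_provider`, `kernel_operation_mode`, `extra_benchmark_config`). The sketch below is not part of the commit; the direct import of `benchmark_dyt` and the use of `SimpleNamespace` in place of `SingleBenchmarkRunInput` are assumptions for illustration, and it needs a GPU/XPU plus the repo's `benchmark/scripts` directory on `sys.path` so `utils` and `test.transformers.test_dyt` resolve.

# Minimal sketch: call the speed benchmark once with a stand-in input object.
from types import SimpleNamespace

import torch

from benchmark_dyt import bench_speed_dyt  # hypothetical direct import of the script above

run_input = SimpleNamespace(
    x=4096,                            # BT = batch_size * seq_len
    kernel_provider="liger",           # or "torch" / "torch_compile"
    kernel_operation_mode="forward",   # "forward", "backward", or "full"
    extra_benchmark_config={"hidden_size": 4096, "dtype": torch.float32},
)
out = bench_speed_dyt(run_input)
print(out.y_20, out.y_50, out.y_80)    # 20th/50th/80th percentile latency in ms

The usual path is simply running the script, which sweeps the configured `x_values` via `run_benchmarks` and writes into the shared benchmark CSV.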

src/liger_kernel/ops/dyt.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
import operator

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import infer_device

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import tanh
else:
    from triton.language.math import tanh


@triton.jit
def _dyt_fwd_kernel(
    x_ptr,
    x_row_stride,
    alpha_ptr,
    gamma_ptr,
    beta_ptr,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - beta: (C)
    """
    row_idx = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    x_ptr += row_idx * x_row_stride
    y_ptr += row_idx * y_row_stride

    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + offsets, mask=mask)
    beta = tl.load(beta_ptr + offsets, mask=mask)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
    tl.store(y_ptr + offsets, y, mask=mask)


@triton.jit
def _dyt_bwd_kernel(
    x_ptr,
    x_row_stride,
    dy_ptr,
    dy_row_stride,
    dx_ptr,
    dx_row_stride,
    alpha_ptr,
    dalpha_ptr,
    gamma_ptr,
    dgamma_ptr,
    dgamma_row_stride,
    n_cols,
    n_rows,
    ROWS_PER_PROGRAM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - dx: (BT, C)
        - dy: (BT, C)
        - dgamma: (sm_count, C)
        - dalpha: (sm_count,)
    """
    # d(gamma * tanh(alpha * x) + beta) / dx
    #   = gamma * (1 - tanh^2(alpha * x)) * alpha
    # d(gamma * tanh(alpha * x) + beta) / dalpha
    #   = gamma * (1 - tanh^2(alpha * x)) * x
    # d(gamma * tanh(alpha * x) + beta) / dgamma
    #   = tanh(alpha * x)
    # d(gamma * tanh(alpha * x) + beta) / dbeta = 1
    pid = tl.program_id(0)

    row_start = pid * ROWS_PER_PROGRAM
    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    dalpha = 0.0
    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    x_ptr += row_start * x_row_stride
    dx_ptr += row_start * dx_row_stride
    dy_ptr += row_start * dy_row_stride
    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)

    for _ in tl.range(row_start, row_end):
        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        tanh_ax = tanh((alpha * x).cast(tl.float32))
        sech2_ax = 1 - tanh_ax * tanh_ax

        dx = dy * gamma * sech2_ax * alpha
        dalpha += tl.sum(dy * gamma * sech2_ax * x)
        dgamma += dy * tanh_ax
        tl.store(dx_ptr + offsets, dx, mask=mask)

        dy_ptr += dy_row_stride
        x_ptr += x_row_stride
        dx_ptr += dx_row_stride

    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
    tl.store(dalpha_ptr + pid, dalpha)

    pass


def liger_dyt_fwd(x, alpha, gamma, beta):
    shape = x.shape
    dim = shape[-1]
    x = x.view(-1, dim)
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    _dyt_fwd_kernel[(n_rows,)](
        x_ptr=x,
        alpha_ptr=alpha,
        gamma_ptr=gamma,
        beta_ptr=beta,
        y_ptr=y,
        x_row_stride=x.stride(0),
        y_row_stride=y.stride(0),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y.view(*shape)


def liger_dyt_bwd(dy, x, alpha, gamma):
    shape = dy.shape
    dtype = x.dtype
    dim = shape[-1]
    dy = dy.view(-1, dim)
    x = x.view(-1, dim)
    n_rows, n_cols = dy.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    sm_count = 1
    device = infer_device()
    if device == "cuda":
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    dx = torch.empty_like(x, dtype=torch.float32)
    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)

    grid = (sm_count,)
    rows_per_program = triton.cdiv(n_rows, sm_count)
    _dyt_bwd_kernel[grid](
        x_ptr=x,
        x_row_stride=x.stride(0),
        dy_ptr=dy,
        dy_row_stride=dy.stride(0),
        dx_ptr=dx,
        dx_row_stride=dx.stride(0),
        alpha_ptr=alpha,
        dalpha_ptr=_dalpha,
        gamma_ptr=gamma,
        dgamma_ptr=_dgamma,
        dgamma_row_stride=_dgamma.stride(0),
        n_cols=n_cols,
        n_rows=n_rows,
        ROWS_PER_PROGRAM=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
    dgamma = _dgamma.sum(dim=0).to(dtype)
    dbeta = dy.sum(dim=0).to(dtype)
    return dx.view(*shape), dalpha, dgamma, dbeta


class LigerDyTFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, x, alpha, gamma, beta):
        y = liger_dyt_fwd(x, alpha, gamma, beta)
        ctx.save_for_backward(x, alpha, gamma)
        return y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        x, alpha, gamma = ctx.saved_tensors
        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
            grad_output,
            x,
            alpha,
            gamma,
        )

        return (dx, dalpha, dgamma, dbeta)
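Note: the two kernels implement Dynamic Tanh (DyT, arXiv:2503.10622). The forward computes y = gamma * tanh(alpha * x) + beta one row per program; the backward has each program accumulate partial dalpha/dgamma over its slice of rows into the (sm_count,) and (sm_count, C) buffers, which liger_dyt_bwd then reduces on the host. As a sanity check on the analytic derivatives spelled out in the kernel comments, here is a small plain-PyTorch sketch (not part of the commit, CPU-only) that reproduces the same math and compares against autograd:

# Reference DyT in plain PyTorch, mirroring the fused kernels' math.
import torch


def dyt_reference(x, alpha, gamma, beta):
    return gamma * torch.tanh(alpha * x) + beta


def dyt_reference_bwd(dy, x, alpha, gamma):
    tanh_ax = torch.tanh(alpha * x)
    sech2_ax = 1 - tanh_ax * tanh_ax
    dx = dy * gamma * sech2_ax * alpha          # d/dx
    dalpha = (dy * gamma * sech2_ax * x).sum()  # d/dalpha, reduced over all elements
    dgamma = (dy * tanh_ax).sum(dim=0)          # d/dgamma, reduced over rows
    dbeta = dy.sum(dim=0)                       # d/dbeta
    return dx, dalpha, dgamma, dbeta


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 16, dtype=torch.float64, requires_grad=True)
    alpha = torch.full((1,), 0.5, dtype=torch.float64, requires_grad=True)
    gamma = torch.randn(16, dtype=torch.float64, requires_grad=True)
    beta = torch.randn(16, dtype=torch.float64, requires_grad=True)
    dy = torch.randn_like(x)

    dyt_reference(x, alpha, gamma, beta).backward(dy)
    dx, dalpha, dgamma, dbeta = dyt_reference_bwd(dy, x.detach(), alpha.detach(), gamma.detach())
    assert torch.allclose(x.grad, dx) and torch.allclose(alpha.grad.sum(), dalpha)
    assert torch.allclose(gamma.grad, dgamma) and torch.allclose(beta.grad, dbeta)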

src/liger_kernel/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
+from liger_kernel.transformers.dyt import LigerDyT  # noqa: F401
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
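Note: with this export, the module is importable from the package root. A minimal usage sketch, assuming LigerDyT takes hidden_size the way the benchmark constructs it (the benchmark imports a wrapper from test.transformers.test_dyt, so the exact constructor signature should be confirmed against liger_kernel/transformers/dyt.py):

# Hedged usage sketch; constructor signature assumed from the benchmark script.
import torch

from liger_kernel.transformers import LigerDyT

device = "cuda"  # the fused Triton kernel targets GPU (CUDA/XPU) devices
dyt = LigerDyT(hidden_size=4096).to(device)

x = torch.randn(8, 512, 4096, device=device, requires_grad=True)
y = dyt(x)           # y = gamma * tanh(alpha * x) + beta, applied over the last dim
y.sum().backward()   # gradients flow through LigerDyTFunction.backward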
