Skip to content

Commit 7fd75d1

Browse files
erwei-xilinxclaude
andcommitted
Add SwiGLU example (out = SiLU(gate) * up)
New elementwise example computing SwiGLU activation with two input streams, verified on NPU2 (Strix/AIE2P) hardware. Extends the silu pattern with 3-operand tiling (gate, up, out). AIE2 transform links extern_func.o for math.exp; AIE2P uses native bf16 exp intrinsic. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 548c4f3 commit 7fd75d1

4 files changed

Lines changed: 439 additions & 0 deletions

File tree

examples/swiglu/extern_func.o

5.86 KB
Binary file not shown.

examples/swiglu/swiglu.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
import sys, os
8+
9+
sys.path.append(os.path.abspath(".."))
10+
import benchmark
11+
12+
13+
@triton.jit
14+
def swiglu_kernel(
15+
GATE,
16+
UP,
17+
OUT,
18+
n_elements: tl.constexpr,
19+
BLOCK_SIZE: tl.constexpr,
20+
):
21+
pid = tl.program_id(0)
22+
block_start = pid * BLOCK_SIZE
23+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
24+
25+
gate = tl.load(GATE + offsets[:])
26+
up = tl.load(UP + offsets[:])
27+
# SwiGLU(gate, up) = SiLU(gate) * up = gate * sigmoid(gate) * up
28+
# sigmoid requires f32 input
29+
gate_f32 = gate.to(tl.float32)
30+
sig = tl.sigmoid(gate_f32)
31+
silu_gate = (gate_f32 * sig).to(gate.dtype)
32+
out = silu_gate * up
33+
tl.store(OUT + offsets[:], out)
34+
35+
36+
def bench_swiglu(N, provider):
37+
device = "cpu"
38+
dtype = torch.bfloat16
39+
gate = torch.randn(N, device=device, dtype=dtype)
40+
up = torch.randn(N, device=device, dtype=dtype)
41+
out = torch.empty(N, device=device, dtype=dtype)
42+
if provider == "torch" or provider == "test":
43+
out_ref = torch.nn.functional.silu(gate) * up
44+
if provider == "triton" or provider == "test":
45+
grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
46+
compiled_kernel = swiglu_kernel[grid](
47+
gate,
48+
up,
49+
out,
50+
N,
51+
BLOCK_SIZE=1024,
52+
)
53+
with open("tt.shared.mlir", "w") as f:
54+
f.write(str(compiled_kernel.asm["ttsharedir"]))
55+
if provider == "test":
56+
torch.testing.assert_close(out, out_ref, atol=1e-1, rtol=1e-1)
57+
58+
59+
if __name__ == "__main__":
60+
benchmark.select_npu_backend()
61+
for N in [2**i for i in range(10, 16, 1)]:
62+
bench_swiglu(N, "test")
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
////////////////////////////////////////////////////////////////////////////////
5+
// Transform Script for SwiGLU (AIE2): out = SiLU(gate) * up
6+
//
7+
// SwiGLU(gate, up) = gate * sigmoid(gate) * up
8+
//
9+
// The Linalg IR has the silu chain (extf, negf/subf, exp, addf, divf, mulf)
10+
// plus an additional mulf for the final * up. After fuse_elementwise_linalg,
11+
// this becomes a single generic with 2 bf16 inputs (gate, up) and 1 bf16
12+
// output (out).
13+
//
14+
// AIE2 type mapping:
15+
// - math.exp: bf16 ONLY -> needs vector_type_cast
16+
// - arith.divf: f32 ONLY -> keep as f32
17+
// - arith.subf/addf/mulf: bf16 ONLY -> needs vector_type_cast
18+
//
19+
// Strategy: fuse_elementwise_linalg -> 3-operand tiling (like axpy) ->
20+
// vectorize at 16 -> cast exp, subf, addf, mulf to bf16; divf stays f32.
21+
//
22+
// AIE2 requires extern_func.o for math.exp (no native bf16 exp intrinsic).
23+
////////////////////////////////////////////////////////////////////////////////
24+
25+
module attributes {transform.with_named_sequence} {
26+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
27+
28+
//===================================================================
29+
// PHASE 1: Initial Canonicalization
30+
//===================================================================
31+
%func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
32+
transform.apply_patterns to %func0 {
33+
transform.apply_patterns.linalg.tiling_canonicalization
34+
transform.apply_patterns.scf.for_loop_canonicalization
35+
transform.apply_patterns.canonicalization
36+
transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes
37+
} : !transform.any_op
38+
transform.apply_cse to %func0 : !transform.any_op
39+
40+
//===================================================================
41+
// PHASE 2: Fuse Elementwise Chain
42+
//===================================================================
43+
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
44+
%func1_fused = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
45+
46+
%func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
47+
transform.apply_patterns to %func1a {
48+
transform.apply_patterns.canonicalization
49+
} : !transform.any_op
50+
transform.apply_cse to %func1a : !transform.any_op
51+
52+
//===================================================================
53+
// PHASE 3: Vec-Add-Style Tiling Pattern
54+
//===================================================================
55+
%op = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56+
57+
%op_flattened = transform.structured.flatten_elementwise %op
58+
: (!transform.any_op) -> !transform.any_op
59+
60+
%op_res_shared, %new_op = transform.structured.bufferize_to_allocation %op_flattened
61+
{memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
62+
63+
%op_1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
64+
%tiled_op_1, %forall_op_1 =
65+
transform.structured.tile_using_forall %op_1 tile_sizes [256] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
66+
67+
//===================================================================
68+
// PHASE 4: Canonicalization
69+
//===================================================================
70+
%func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
71+
transform.apply_patterns to %func_2 {
72+
transform.apply_patterns.linalg.tiling_canonicalization
73+
transform.apply_patterns.scf.for_loop_canonicalization
74+
transform.apply_patterns.canonicalization
75+
} : !transform.any_op
76+
transform.apply_cse to %func_2 : !transform.any_op
77+
78+
//===================================================================
79+
// PHASE 5: Pad and Promote to L1 (3 operands: gate, up, out)
80+
//===================================================================
81+
%op_2 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
82+
83+
%padded_op, %pad_op, %__ = transform.structured.pad %op_2 {
84+
padding_values=[0.0 : bf16, 0.0 : bf16, 0.0 : bf16],
85+
padding_dimensions=[0, 1, 2],
86+
nofold_flags=[1, 1, 1],
87+
copy_back_op="linalg.copy"
88+
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
89+
90+
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op : (!transform.any_op) -> !transform.any_op
91+
92+
%padded_gate = transform.get_producer_of_operand %padded_op[0] : (!transform.any_op) -> (!transform.any_op)
93+
%padded_gate_buffer, %padded_gate_new = transform.structured.bufferize_to_allocation %padded_gate
94+
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
95+
96+
%padded_up = transform.get_producer_of_operand %padded_op[1] : (!transform.any_op) -> (!transform.any_op)
97+
%padded_up_buffer, %padded_up_new = transform.structured.bufferize_to_allocation %padded_up
98+
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
99+
100+
%padded_out = transform.get_producer_of_operand %padded_op[2] : (!transform.any_op) -> (!transform.any_op)
101+
%padded_out_buffer, %padded_out_new = transform.structured.bufferize_to_allocation %padded_out
102+
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
103+
104+
//===================================================================
105+
// PHASE 6: Canonicalization
106+
//===================================================================
107+
%func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
108+
transform.apply_patterns to %func_3 {
109+
transform.apply_patterns.linalg.tiling_canonicalization
110+
transform.apply_patterns.scf.for_loop_canonicalization
111+
transform.apply_patterns.canonicalization
112+
} : !transform.any_op
113+
transform.apply_cse to %func_3 : !transform.any_op
114+
115+
//===================================================================
116+
// PHASE 7: Bufferization
117+
//===================================================================
118+
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
119+
%func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op
120+
121+
//===================================================================
122+
// PHASE 8: Post-Bufferization Cleanup
123+
//===================================================================
124+
%func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
125+
transform.apply_patterns to %func6 {
126+
transform.apply_patterns.linalg.tiling_canonicalization
127+
transform.apply_patterns.scf.for_loop_canonicalization
128+
transform.apply_patterns.canonicalization
129+
} : !transform.any_op
130+
transform.apply_cse to %func6 : !transform.any_op
131+
transform.apply_patterns to %func6 {
132+
transform.apply_patterns.canonicalization
133+
} : !transform.any_op
134+
%linalg_copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
135+
%memref_copies = transform.structured.linalg_copy_to_memref %linalg_copies : (!transform.any_op) -> !transform.any_op
136+
%func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op
137+
%func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op
138+
139+
//===================================================================
140+
// PHASE 9: Vectorization Tiling (16-lane for bf16)
141+
//===================================================================
142+
%linalg_generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
143+
%inner_most_generics, %vec_loops:1 =
144+
transform.structured.tile_using_for %linalg_generics tile_sizes [16]
145+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
146+
147+
//===================================================================
148+
// PHASE 10: AIR Constructs Mapping + Type Casts
149+
//===================================================================
150+
%forall_as_herd = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
151+
%parallel = transform.loop.forall_to_parallel %forall_as_herd : (!transform.any_op) -> !transform.any_op
152+
%herd = transform.air.par_to_herd %parallel : (!transform.any_op) -> !transform.any_op
153+
154+
// AIE2 needs extern_func.o for math.exp (no native bf16 exp intrinsic)
155+
%extern_func_param = transform.param.constant "extern_func.o" -> !transform.any_param
156+
transform.annotate %herd "link_with" = %extern_func_param : !transform.any_op, !transform.any_param
157+
158+
%copies_in_herd = transform.structured.match ops{["memref.copy", "linalg.copy"]} in %herd : (!transform.any_op) -> !transform.any_op
159+
%dmas_from_copies = transform.air.copy_to_dma %copies_in_herd : (!transform.any_op) -> !transform.any_op
160+
161+
%vectorized_herd = transform.air.herd_vectorize %herd : (!transform.any_op) -> !transform.any_op
162+
163+
// math.exp -> bf16
164+
%vector_exps = transform.structured.match ops{["math.exp"]} in %vectorized_herd : (!transform.any_op) -> !transform.any_op
165+
%exp_cast = transform.air.vector_type_cast %vector_exps {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
166+
167+
// arith.subf -> bf16
168+
%vector_subs = transform.structured.match ops{["arith.subf"]} in %vectorized_herd : (!transform.any_op) -> !transform.any_op
169+
%sub_cast = transform.air.vector_type_cast %vector_subs {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
170+
171+
// arith.addf -> bf16
172+
%vector_adds = transform.structured.match ops{["arith.addf"]} in %vectorized_herd : (!transform.any_op) -> !transform.any_op
173+
%add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
174+
175+
// arith.mulf -> bf16
176+
%vector_muls = transform.structured.match ops{["arith.mulf"]} in %vectorized_herd : (!transform.any_op) -> !transform.any_op
177+
%mul_cast = transform.air.vector_type_cast %vector_muls {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
178+
179+
// arith.divf stays f32
180+
181+
transform.yield
182+
}
183+
}

0 commit comments

Comments
 (0)