Skip to content

Commit 6f6c73f

Browse files
authored
Merge pull request #17 from amd/add-average-pool-example
Add average pooling example (row-wise mean reduction)
2 parents 2715977 + b1d91d3 commit 6f6c73f

5 files changed

Lines changed: 326 additions & 4 deletions

File tree

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
4+
# Mean subtraction (centering) kernel for AMD XDNA NPU
5+
# Computes: y[i,j] = x[i,j] - mean(x[i,:]) per row
6+
#
7+
# This is the 2D-output form of average pooling that matches the rms_norm
8+
# reduction pattern. The 1D output form (just storing the mean) hits a
9+
# 4-byte DMA alignment constraint on AIE (memref<1xbf16> = 2 bytes < 4).
10+
# Broadcasting the mean back to [BLOCK_M, BLOCK_N] via subtraction avoids
11+
# this constraint (output DMA is [1, 256] = 512 bytes per tile).
12+
#
13+
# Uses BLOCK_M=2 (2D tiling) to avoid the scalar chain issue.
14+
15+
import torch
16+
import triton
17+
import triton.language as tl
18+
import sys, os
19+
20+
sys.path.append(os.path.abspath(".."))
21+
import benchmark
22+
23+
24+
@triton.jit
25+
def avg_pool_kernel(
26+
X,
27+
Y,
28+
N: tl.constexpr,
29+
BLOCK_M: tl.constexpr,
30+
BLOCK_N: tl.constexpr,
31+
):
32+
pid = tl.program_id(0)
33+
row_start = pid * BLOCK_M
34+
rows = row_start + tl.arange(0, BLOCK_M)
35+
cols = tl.arange(0, BLOCK_N)
36+
37+
# Load BLOCK_M rows at once (2D block)
38+
offsets = rows[:, None] * N + cols[None, :]
39+
x = tl.load(X + offsets)
40+
41+
# Sum per row in bf16 (AIE2P only supports bf16 vector add)
42+
row_sum = tl.sum(x, axis=1) # [BLOCK_M], bf16
43+
44+
# Divide by N in f32 (divf is f32-only on AIE2P)
45+
mean = row_sum.to(tl.float32) / N
46+
47+
# Subtract mean from input (2D output, broadcasts mean across columns)
48+
x_f32 = x.to(tl.float32)
49+
y = x_f32 - mean[:, None]
50+
y = y.to(x.dtype)
51+
52+
tl.store(Y + offsets, y)
53+
54+
55+
def bench_avg_pool(M, N, provider):
56+
device = "cpu"
57+
dtype = torch.bfloat16
58+
BLOCK_M = 2
59+
x = torch.randn(M, N, device=device, dtype=dtype)
60+
y = torch.empty(M, N, device=device, dtype=dtype)
61+
if provider == "torch" or provider == "test":
62+
x_f32 = x.float()
63+
mean = x_f32.mean(dim=-1, keepdim=True)
64+
y_ref = (x_f32 - mean).to(dtype)
65+
if provider == "triton" or provider == "test":
66+
grid = (M // BLOCK_M,)
67+
compiled_kernel = avg_pool_kernel[grid](
68+
x,
69+
y,
70+
N,
71+
BLOCK_M=BLOCK_M,
72+
BLOCK_N=N,
73+
)
74+
with open("tt.shared.mlir", "w") as f:
75+
f.write(str(compiled_kernel.asm["ttsharedir"]))
76+
if provider == "test":
77+
torch.testing.assert_close(y, y_ref, atol=5e-1, rtol=1e-1)
78+
79+
80+
if __name__ == "__main__":
81+
benchmark.select_npu_backend()
82+
for M in [32, 64]:
83+
for N in [256]:
84+
bench_avg_pool(M, N, "test")
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
// Mean subtraction transform for AIE2.
5+
// y = x - mean(x, dim=-1)
6+
//
7+
// 2D kernel (BLOCK_M=2 x BLOCK_N=256) with 2D output [BLOCK_M, BLOCK_N].
8+
// Follows the rms_norm reduction pattern exactly: tile [1], fuse_multi_op,
9+
// linalg_promote, herd + DMA, vectorize, type casts.
10+
11+
module attributes {transform.with_named_sequence} {
12+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
13+
%func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
14+
transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization
15+
transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op
16+
transform.apply_cse to %func0 : !transform.any_op
17+
%reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
18+
%tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op
19+
%func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
20+
transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op
21+
transform.apply_cse to %func1a : !transform.any_op
22+
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
23+
%f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
24+
%fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
25+
transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op
26+
transform.apply_cse to %fa : !transform.any_op
27+
28+
// After fusion: "out" generic [2, 256] (x - mean broadcast), reduce, fill
29+
// No "sq" generic (unlike rms_norm) -- the reduce directly sums x.
30+
%out = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
31+
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
32+
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
33+
34+
// L2 output alloc
35+
%ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
36+
// Tile at [1] on row dim (same as rms_norm)
37+
%t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
38+
// Fuse all into forall
39+
%f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
40+
%f2, %fl2 = transform.structured.fuse_into_containing_op %fill into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
41+
42+
// L1 for fills only (destination-only)
43+
%fills3 = transform.structured.match ops{["linalg.fill"]} in %fl2 : (!transform.any_op) -> !transform.any_op
44+
%fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3
45+
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
46+
47+
// Canonicalize + bufferize
48+
%f2c = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
49+
transform.apply_patterns to %f2c { transform.apply_patterns.canonicalization } : !transform.any_op
50+
transform.apply_cse to %f2c : !transform.any_op
51+
%fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
52+
%fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op
53+
%f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
54+
transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op
55+
transform.apply_cse to %f6 : !transform.any_op
56+
%lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
57+
%mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op
58+
%fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op)
59+
%fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op)
60+
61+
// L1 promote
62+
%forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
63+
%gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op
64+
%reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op
65+
%all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op
66+
%promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op
67+
68+
// Post-promote cleanup
69+
%f_pp = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
70+
transform.apply_patterns to %f_pp { transform.apply_patterns.canonicalization } : !transform.any_op
71+
transform.apply_cse to %f_pp : !transform.any_op
72+
%f_pp2 = transform.air.remove_uninitialized_copy %f_pp : (!transform.any_op) -> (!transform.any_op)
73+
%f_pp3 = transform.air.eliminate_cascade_memcpy %f_pp2 : (!transform.any_op) -> (!transform.any_op)
74+
75+
// Herd + DMA
76+
%fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
77+
%pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op
78+
%h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op
79+
%lc2 = transform.structured.match ops{["linalg.copy"]} in %h : (!transform.any_op) -> !transform.any_op
80+
%mc2 = transform.structured.match ops{["memref.copy"]} in %h : (!transform.any_op) -> !transform.any_op
81+
%mc3 = transform.structured.linalg_copy_to_memref %lc2 : (!transform.any_op) -> !transform.any_op
82+
%ac = transform.merge_handles %mc2, %mc3 { deduplicate } : !transform.any_op
83+
%dm = transform.air.copy_to_dma %ac : (!transform.any_op) -> !transform.any_op
84+
85+
// Vectorization (same as rms_norm)
86+
%h2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
87+
%gens_h = transform.structured.match ops{["linalg.generic"]} in %h2 : (!transform.any_op) -> !transform.any_op
88+
%inner_g, %inner_gl:1 = transform.structured.tile_using_for %gens_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
89+
%reds_h = transform.structured.match ops{["linalg.reduce"]} in %h2 : (!transform.any_op) -> !transform.any_op
90+
%inner_r, %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
91+
%fills_h = transform.structured.match ops{["linalg.fill"]} in %h2 : (!transform.any_op) -> !transform.any_op
92+
%fill_scl = transform.structured.convert_to_loops %fills_h : (!transform.any_op) -> !transform.any_op
93+
%vh = transform.air.herd_vectorize %h2 : (!transform.any_op) -> !transform.any_op
94+
95+
// Lower vector reductions
96+
%func_final = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
97+
transform.apply_patterns to %func_final {
98+
transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
99+
transform.apply_patterns.vector.lower_contraction
100+
transform.apply_patterns.vector.lower_transfer
101+
} : !transform.any_op
102+
transform.apply_cse to %func_final : !transform.any_op
103+
104+
// AIE2 type casts: mulf/addf/subf bf16-only, divf f32-only
105+
%vh2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
106+
%vector_adds = transform.structured.match ops{["arith.addf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
107+
%add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
108+
%vector_subs = transform.structured.match ops{["arith.subf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
109+
%sub_cast = transform.air.vector_type_cast %vector_subs {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
110+
%func_s1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
111+
%func_s1_done = transform.air.convert_size1_vector_to_scalar %func_s1 : (!transform.any_op) -> !transform.any_op
112+
transform.apply_patterns to %func_s1_done {
113+
transform.apply_patterns.vector.cast_away_vector_leading_one_dim
114+
transform.apply_patterns.canonicalization
115+
} : !transform.any_op
116+
transform.apply_cse to %func_s1_done : !transform.any_op
117+
transform.yield
118+
}
119+
}

0 commit comments

Comments
 (0)