Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions examples/average_pool/average_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

# Mean subtraction (centering) kernel for AMD XDNA NPU
# Computes: y[i,j] = x[i,j] - mean(x[i,:]) per row
#
# This is the 2D-output form of average pooling that matches the rms_norm
# reduction pattern. The 1D output form (just storing the mean) hits a
# 4-byte DMA alignment constraint on AIE (memref<1xbf16> = 2 bytes < 4).
# Broadcasting the mean back to [BLOCK_M, BLOCK_N] via subtraction avoids
# this constraint (output DMA is [1, 256] = 512 bytes per tile).
#
# Uses BLOCK_M=2 (2D tiling) to avoid the scalar chain issue.

import torch
import triton
import triton.language as tl
import sys, os

sys.path.append(os.path.abspath(".."))
import benchmark


@triton.jit
def avg_pool_kernel(
    X,
    Y,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Subtract each row's mean from that row: y[i, j] = x[i, j] - mean(x[i, :]).

    Each program instance handles BLOCK_M consecutive rows, starting at row
    ``program_id(0) * BLOCK_M``.

    Args:
        X: Base pointer of the input; element (i, j) is read from
            ``X + i * N + j``.
        Y: Base pointer of the output; written with the same offsets as X.
        N: Row stride in elements, and the divisor used to form the mean.
        BLOCK_M: Number of rows processed by one program instance.
        BLOCK_N: Number of columns loaded (and summed) per row.

    NOTE(review): the load and store are unmasked, and the row sum covers
    BLOCK_N columns while the divisor is N. The result is therefore a true
    row mean only when BLOCK_N == N, and every access is in bounds only when
    the launch grid times BLOCK_M does not exceed the number of rows. The
    caller in this file passes BLOCK_N=N -- confirm before reusing the kernel
    with other shapes.
    """
    pid = tl.program_id(0)
    # First row owned by this program instance, then the BLOCK_M row indices.
    row_start = pid * BLOCK_M
    rows = row_start + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, BLOCK_N)

    # Load BLOCK_M rows at once (2D block)
    # offsets has shape [BLOCK_M, BLOCK_N]: row index * row stride + column.
    offsets = rows[:, None] * N + cols[None, :]
    x = tl.load(X + offsets)

    # Sum per row in bf16 (AIE2P only supports bf16 vector add)
    row_sum = tl.sum(x, axis=1)  # [BLOCK_M], bf16

    # Divide by N in f32 (divf is f32-only on AIE2P)
    mean = row_sum.to(tl.float32) / N

    # Subtract mean from input (2D output, broadcasts mean across columns)
    x_f32 = x.to(tl.float32)
    y = x_f32 - mean[:, None]
    # Cast back to the input element type before storing.
    y = y.to(x.dtype)

    tl.store(Y + offsets, y)


def bench_avg_pool(M, N, provider):
    """Run the row mean-subtraction example for one ``(M, N)`` shape.

    A random bf16 input of shape ``(M, N)`` is created on the CPU device.
    Depending on ``provider`` the float32 torch reference is computed, the
    Triton kernel is launched, or both are run and compared.

    Args:
        M: Number of rows. Must be a multiple of the kernel's row block
            (``BLOCK_M = 2``) whenever the kernel is launched.
        N: Number of columns; also passed to the kernel as ``BLOCK_N`` so one
            program instance covers whole rows.
        provider: ``"torch"`` computes only the reference, ``"triton"``
            launches the kernel and writes its ``ttsharedir`` IR to
            ``tt.shared.mlir`` in the current directory, ``"test"`` does both
            and checks the kernel output against the reference. Any other
            value only allocates the tensors.

    Raises:
        ValueError: If the kernel is requested and ``M`` is not a multiple of
            ``BLOCK_M``.
        AssertionError: In ``"test"`` mode, if the kernel output is not within
            tolerance of the reference.
    """
    device = "cpu"
    dtype = torch.bfloat16
    BLOCK_M = 2
    run_reference = provider in ("torch", "test")
    run_kernel = provider in ("triton", "test")

    # The kernel loads and stores without masks and the grid is a floor
    # division, so a ragged final block would silently leave trailing rows of
    # `y` uninitialised. Fail loudly instead of comparing against garbage.
    if run_kernel and M % BLOCK_M != 0:
        raise ValueError(
            f"M ({M}) must be a multiple of BLOCK_M ({BLOCK_M}) to launch the kernel"
        )

    x = torch.randn(M, N, device=device, dtype=dtype)
    y = torch.empty(M, N, device=device, dtype=dtype)
    if run_reference:
        # Reference is computed in float32 and rounded to bf16 once at the end.
        x_f32 = x.float()
        mean = x_f32.mean(dim=-1, keepdim=True)
        y_ref = (x_f32 - mean).to(dtype)
    if run_kernel:
        grid = (M // BLOCK_M,)
        compiled_kernel = avg_pool_kernel[grid](
            x,
            y,
            N,
            BLOCK_M=BLOCK_M,
            BLOCK_N=N,
        )
        # Dump the intermediate IR of the compiled kernel for inspection.
        with open("tt.shared.mlir", "w") as f:
            f.write(str(compiled_kernel.asm["ttsharedir"]))
    if provider == "test":
        # Loose tolerances: the kernel accumulates the row sum in bf16.
        torch.testing.assert_close(y, y_ref, atol=5e-1, rtol=1e-1)


if __name__ == "__main__":
    # Select the NPU backend through the shared benchmark helper first.
    benchmark.select_npu_backend()
    # (rows, cols) shapes exercised in "test" mode, in launch order.
    for num_rows, num_cols in [(32, 256), (64, 256)]:
        bench_avg_pool(num_rows, num_cols, "test")
119 changes: 119 additions & 0 deletions examples/average_pool/transform_aie2.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT

// Mean subtraction transform for AIE2.
// y = x - mean(x, dim=-1)
//
// 2D kernel (BLOCK_M=2 x BLOCK_N=256) with 2D output [BLOCK_M, BLOCK_N].
// Follows the rms_norm reduction pattern exactly: tile [1], fuse_multi_op,
// linalg_promote, herd + DMA, vectorize, type casts.

// Transform-dialect schedule. The entry point @__transform_main receives the
// payload root as %arg1 and rewrites it step by step. Steps are strictly
// order-dependent, and handles are re-matched from %arg1 after most rewrites.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Initial cleanup on every func.func: canonicalize, fold unit-extent dims
    // via reshapes, then CSE.
    %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization
    transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op
    transform.apply_cse to %func0 : !transform.any_op
    // Apply the AIR transpose_reduce transform to every linalg.reduce.
    // NOTE(review): exact rewrite is defined by the AIR dialect -- confirm there.
    %reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op
    %func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op
    transform.apply_cse to %func1a : !transform.any_op
    // Fuse elementwise linalg ops, then clean up again.
    %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
    %fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op
    transform.apply_cse to %fa : !transform.any_op

    // After fusion: "out" generic [2, 256] (x - mean broadcast), reduce, fill
    // No "sq" generic (unlike rms_norm) -- the reduce directly sums x.
    // Each match below is expected to yield exactly one op of its kind.
    %out = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op

    // L2 output alloc
    // (memory_space = 1; only the destination operand is bufferized.)
    %ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
    // Tile at [1] on row dim (same as rms_norm)
    %t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Fuse all into forall
    // Producers are fused one at a time; each fusion returns the updated
    // containing op, which is the target of the next fusion.
    %f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %f2, %fl2 = transform.structured.fuse_into_containing_op %fill into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // L1 for fills only (destination-only)
    // (memory_space = 2; matched inside the forall produced above.)
    %fills3 = transform.structured.match ops{["linalg.fill"]} in %fl2 : (!transform.any_op) -> !transform.any_op
    %fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3
    {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op

    // Canonicalize + bufferize
    %f2c = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f2c { transform.apply_patterns.canonicalization } : !transform.any_op
    transform.apply_cse to %f2c : !transform.any_op
    %fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op
    %f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op
    transform.apply_cse to %f6 : !transform.any_op
    // Turn remaining linalg.copy ops into memref.copy, then run the AIR copy
    // cleanups (remove_uninitialized_copy, eliminate_cascade_memcpy).
    %lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op
    %fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op)
    %fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op)

    // L1 promote
    // Applies to every linalg.generic and linalg.reduce inside the scf.forall.
    %forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op
    %reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op
    %all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op
    %promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op

    // Post-promote cleanup
    %f_pp = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f_pp { transform.apply_patterns.canonicalization } : !transform.any_op
    transform.apply_cse to %f_pp : !transform.any_op
    %f_pp2 = transform.air.remove_uninitialized_copy %f_pp : (!transform.any_op) -> (!transform.any_op)
    %f_pp3 = transform.air.eliminate_cascade_memcpy %f_pp2 : (!transform.any_op) -> (!transform.any_op)

    // Herd + DMA
    // scf.forall -> scf.parallel -> air.herd; then every copy inside the herd
    // (memref.copy plus converted linalg.copy) becomes a DMA.
    %fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op
    %h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op
    %lc2 = transform.structured.match ops{["linalg.copy"]} in %h : (!transform.any_op) -> !transform.any_op
    %mc2 = transform.structured.match ops{["memref.copy"]} in %h : (!transform.any_op) -> !transform.any_op
    %mc3 = transform.structured.linalg_copy_to_memref %lc2 : (!transform.any_op) -> !transform.any_op
    %ac = transform.merge_handles %mc2, %mc3 { deduplicate } : !transform.any_op
    %dm = transform.air.copy_to_dma %ac : (!transform.any_op) -> !transform.any_op

    // Vectorization (same as rms_norm)
    // Generics and reduces in the herd are tiled with sizes [0, 16] (a size of
    // 0 leaves that dimension untiled); fills are lowered to loops; then the
    // whole herd is vectorized.
    %h2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %gens_h = transform.structured.match ops{["linalg.generic"]} in %h2 : (!transform.any_op) -> !transform.any_op
    %inner_g, %inner_gl:1 = transform.structured.tile_using_for %gens_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %reds_h = transform.structured.match ops{["linalg.reduce"]} in %h2 : (!transform.any_op) -> !transform.any_op
    %inner_r, %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %fills_h = transform.structured.match ops{["linalg.fill"]} in %h2 : (!transform.any_op) -> !transform.any_op
    %fill_scl = transform.structured.convert_to_loops %fills_h : (!transform.any_op) -> !transform.any_op
    %vh = transform.air.herd_vectorize %h2 : (!transform.any_op) -> !transform.any_op

    // Lower vector reductions
    // Also lowers vector contractions and vector transfers.
    %func_final = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func_final {
    transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
    transform.apply_patterns.vector.lower_contraction
    transform.apply_patterns.vector.lower_transfer
    } : !transform.any_op
    transform.apply_cse to %func_final : !transform.any_op

    // AIE2 type casts: mulf/addf/subf bf16-only, divf f32-only
    // Only arith.addf and arith.subf inside the herd are cast to bf16 here;
    // no arith.mulf or arith.divf ops are matched by this schedule.
    %vh2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %vector_adds = transform.structured.match ops{["arith.addf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
    %add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
    %vector_subs = transform.structured.match ops{["arith.subf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
    %sub_cast = transform.air.vector_type_cast %vector_subs {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
    // Final cleanup: turn size-1 vectors into scalars, drop leading unit
    // vector dims, canonicalize and CSE.
    %func_s1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %func_s1_done = transform.air.convert_size1_vector_to_scalar %func_s1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func_s1_done {
    transform.apply_patterns.vector.cast_away_vector_leading_one_dim
    transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.apply_cse to %func_s1_done : !transform.any_op
    transform.yield
  }
}
Loading
Loading