diff --git a/examples/average_pool/average_pool.py b/examples/average_pool/average_pool.py new file mode 100644 index 0000000..2c959aa --- /dev/null +++ b/examples/average_pool/average_pool.py @@ -0,0 +1,84 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT + +# Mean subtraction (centering) kernel for AMD XDNA NPU +# Computes: y[i,j] = x[i,j] - mean(x[i,:]) per row +# +# This is the 2D-output form of average pooling that matches the rms_norm +# reduction pattern. The 1D output form (just storing the mean) hits a +# 4-byte DMA alignment constraint on AIE (memref<1xbf16> = 2 bytes < 4). +# Broadcasting the mean back to [BLOCK_M, BLOCK_N] via subtraction avoids +# this constraint (output DMA is [1, 256] = 512 bytes per tile). +# +# Uses BLOCK_M=2 (2D tiling) to avoid the scalar chain issue. + +import torch +import triton +import triton.language as tl +import sys, os + +sys.path.append(os.path.abspath("..")) +import benchmark + + +@triton.jit +def avg_pool_kernel( + X, + Y, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * BLOCK_M + rows = row_start + tl.arange(0, BLOCK_M) + cols = tl.arange(0, BLOCK_N) + + # Load BLOCK_M rows at once (2D block) + offsets = rows[:, None] * N + cols[None, :] + x = tl.load(X + offsets) + + # Sum per row in bf16 (AIE2P only supports bf16 vector add) + row_sum = tl.sum(x, axis=1) # [BLOCK_M], bf16 + + # Divide by N in f32 (divf is f32-only on AIE2P) + mean = row_sum.to(tl.float32) / N + + # Subtract mean from input (2D output, broadcasts mean across columns) + x_f32 = x.to(tl.float32) + y = x_f32 - mean[:, None] + y = y.to(x.dtype) + + tl.store(Y + offsets, y) + + +def bench_avg_pool(M, N, provider): + device = "cpu" + dtype = torch.bfloat16 + BLOCK_M = 2 + x = torch.randn(M, N, device=device, dtype=dtype) + y = torch.empty(M, N, device=device, dtype=dtype) + if provider == "torch" or provider == "test": + x_f32 = x.float() + mean = x_f32.mean(dim=-1, keepdim=True) + y_ref = (x_f32 - mean).to(dtype) + if provider == "triton" or provider == "test": + grid = (M // BLOCK_M,) + compiled_kernel = avg_pool_kernel[grid]( + x, + y, + N, + BLOCK_M=BLOCK_M, + BLOCK_N=N, + ) + with open("tt.shared.mlir", "w") as f: + f.write(str(compiled_kernel.asm["ttsharedir"])) + if provider == "test": + torch.testing.assert_close(y, y_ref, atol=5e-1, rtol=1e-1) + + +if __name__ == "__main__": + benchmark.select_npu_backend() + for M in [32, 64]: + for N in [256]: + bench_avg_pool(M, N, "test") diff --git a/examples/average_pool/transform_aie2.mlir b/examples/average_pool/transform_aie2.mlir new file mode 100644 index 0000000..6b7d325 --- /dev/null +++ b/examples/average_pool/transform_aie2.mlir @@ -0,0 +1,119 @@ +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT + +// Mean subtraction transform for AIE2. +// y = x - mean(x, dim=-1) +// +// 2D kernel (BLOCK_M=2 x BLOCK_N=256) with 2D output [BLOCK_M, BLOCK_N]. +// Follows the rms_norm reduction pattern exactly: tile [1], fuse_multi_op, +// linalg_promote, herd + DMA, vectorize, type casts. + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization + transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op + transform.apply_cse to %func0 : !transform.any_op + %reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op + %func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %func1a : !transform.any_op + %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op + %fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %fa : !transform.any_op + + // After fusion: "out" generic [2, 256] (x - mean broadcast), reduce, fill + // No "sq" generic (unlike rms_norm) -- the reduce directly sums x. + %out = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op + + // L2 output alloc + %ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op + // Tile at [1] on row dim (same as rms_norm) + %t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // Fuse all into forall + %f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f2, %fl2 = transform.structured.fuse_into_containing_op %fill into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // L1 for fills only (destination-only) + %fills3 = transform.structured.match ops{["linalg.fill"]} in %fl2 : (!transform.any_op) -> !transform.any_op + %fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3 + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + + // Canonicalize + bufferize + %f2c = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f2c { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %f2c : !transform.any_op + %fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op + %f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %f6 : !transform.any_op + %lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op + %fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op) + %fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op) + + // L1 promote + %forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op + %reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op + %all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op + %promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op + + // Post-promote cleanup + %f_pp = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f_pp { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %f_pp : !transform.any_op + %f_pp2 = transform.air.remove_uninitialized_copy %f_pp : (!transform.any_op) -> (!transform.any_op) + %f_pp3 = transform.air.eliminate_cascade_memcpy %f_pp2 : (!transform.any_op) -> (!transform.any_op) + + // Herd + DMA + %fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op + %h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op + %lc2 = transform.structured.match ops{["linalg.copy"]} in %h : (!transform.any_op) -> !transform.any_op + %mc2 = transform.structured.match ops{["memref.copy"]} in %h : (!transform.any_op) -> !transform.any_op + %mc3 = transform.structured.linalg_copy_to_memref %lc2 : (!transform.any_op) -> !transform.any_op + %ac = transform.merge_handles %mc2, %mc3 { deduplicate } : !transform.any_op + %dm = transform.air.copy_to_dma %ac : (!transform.any_op) -> !transform.any_op + + // Vectorization (same as rms_norm) + %h2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %gens_h = transform.structured.match ops{["linalg.generic"]} in %h2 : (!transform.any_op) -> !transform.any_op + %inner_g, %inner_gl:1 = transform.structured.tile_using_for %gens_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %reds_h = transform.structured.match ops{["linalg.reduce"]} in %h2 : (!transform.any_op) -> !transform.any_op + %inner_r, %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fills_h = transform.structured.match ops{["linalg.fill"]} in %h2 : (!transform.any_op) -> !transform.any_op + %fill_scl = transform.structured.convert_to_loops %fills_h : (!transform.any_op) -> !transform.any_op + %vh = transform.air.herd_vectorize %h2 : (!transform.any_op) -> !transform.any_op + + // Lower vector reductions + %func_final = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_final { + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction" + transform.apply_patterns.vector.lower_contraction + transform.apply_patterns.vector.lower_transfer + } : !transform.any_op + transform.apply_cse to %func_final : !transform.any_op + + // AIE2 type casts: mulf/addf/subf bf16-only, divf f32-only + %vh2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %vector_adds = transform.structured.match ops{["arith.addf"]} in %vh2 : (!transform.any_op) -> !transform.any_op + %add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op + %vector_subs = transform.structured.match ops{["arith.subf"]} in %vh2 : (!transform.any_op) -> !transform.any_op + %sub_cast = transform.air.vector_type_cast %vector_subs {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op + %func_s1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %func_s1_done = transform.air.convert_size1_vector_to_scalar %func_s1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_s1_done { + transform.apply_patterns.vector.cast_away_vector_leading_one_dim + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_s1_done : !transform.any_op + transform.yield + } +} diff --git a/examples/average_pool/transform_aie2p.mlir b/examples/average_pool/transform_aie2p.mlir new file mode 100644 index 0000000..2a6e779 --- /dev/null +++ b/examples/average_pool/transform_aie2p.mlir @@ -0,0 +1,119 @@ +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT + +// Mean subtraction transform for AIE2P. +// y = x - mean(x, dim=-1) +// +// 2D kernel (BLOCK_M=2 x BLOCK_N=256) with 2D output [BLOCK_M, BLOCK_N]. +// Follows the rms_norm reduction pattern exactly: tile [1], fuse_multi_op, +// linalg_promote, herd + DMA, vectorize, type casts. + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization + transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op + transform.apply_cse to %func0 : !transform.any_op + %reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op + %func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %func1a : !transform.any_op + %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op + %fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %fa : !transform.any_op + + // After fusion: "out" generic [2, 256] (x - mean broadcast), reduce, fill + // No "sq" generic (unlike rms_norm) -- the reduce directly sums x. + %out = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op + + // L2 output alloc + %ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op + // Tile at [1] on row dim (same as rms_norm) + %t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // Fuse all into forall + %f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f2, %fl2 = transform.structured.fuse_into_containing_op %fill into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // L1 for fills only (destination-only) + %fills3 = transform.structured.match ops{["linalg.fill"]} in %fl2 : (!transform.any_op) -> !transform.any_op + %fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3 + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + + // Canonicalize + bufferize + %f2c = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f2c { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %f2c : !transform.any_op + %fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op + %f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %f6 : !transform.any_op + %lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op + %fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op) + %fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op) + + // L1 promote + %forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op + %reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op + %all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op + %promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op + + // Post-promote cleanup + %f_pp = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f_pp { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %f_pp : !transform.any_op + %f_pp2 = transform.air.remove_uninitialized_copy %f_pp : (!transform.any_op) -> (!transform.any_op) + %f_pp3 = transform.air.eliminate_cascade_memcpy %f_pp2 : (!transform.any_op) -> (!transform.any_op) + + // Herd + DMA + %fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op + %h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op + %lc2 = transform.structured.match ops{["linalg.copy"]} in %h : (!transform.any_op) -> !transform.any_op + %mc2 = transform.structured.match ops{["memref.copy"]} in %h : (!transform.any_op) -> !transform.any_op + %mc3 = transform.structured.linalg_copy_to_memref %lc2 : (!transform.any_op) -> !transform.any_op + %ac = transform.merge_handles %mc2, %mc3 { deduplicate } : !transform.any_op + %dm = transform.air.copy_to_dma %ac : (!transform.any_op) -> !transform.any_op + + // Vectorization (same as rms_norm) + %h2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %gens_h = transform.structured.match ops{["linalg.generic"]} in %h2 : (!transform.any_op) -> !transform.any_op + %inner_g, %inner_gl:1 = transform.structured.tile_using_for %gens_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %reds_h = transform.structured.match ops{["linalg.reduce"]} in %h2 : (!transform.any_op) -> !transform.any_op + %inner_r, %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fills_h = transform.structured.match ops{["linalg.fill"]} in %h2 : (!transform.any_op) -> !transform.any_op + %fill_scl = transform.structured.convert_to_loops %fills_h : (!transform.any_op) -> !transform.any_op + %vh = transform.air.herd_vectorize %h2 : (!transform.any_op) -> !transform.any_op + + // Lower vector reductions + %func_final = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_final { + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction" + transform.apply_patterns.vector.lower_contraction + transform.apply_patterns.vector.lower_transfer + } : !transform.any_op + transform.apply_cse to %func_final : !transform.any_op + + // AIE2P type casts: mulf/addf/subf bf16-only, divf f32-only + %vh2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %vector_adds = transform.structured.match ops{["arith.addf"]} in %vh2 : (!transform.any_op) -> !transform.any_op + %add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op + %vector_subs = transform.structured.match ops{["arith.subf"]} in %vh2 : (!transform.any_op) -> !transform.any_op + %sub_cast = transform.air.vector_type_cast %vector_subs {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op + %func_s1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %func_s1_done = transform.air.convert_size1_vector_to_scalar %func_s1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_s1_done { + transform.apply_patterns.vector.cast_away_vector_leading_one_dim + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_s1_done : !transform.any_op + transform.yield + } +} diff --git a/utils/mlir-aie-hash.txt b/utils/mlir-aie-hash.txt index 1f8c089..b8a74a2 100644 --- a/utils/mlir-aie-hash.txt +++ b/utils/mlir-aie-hash.txt @@ -1,3 +1,3 @@ -Commit: c5d4befdce2bef7a9219b742000cb2f8d9283f39 -Timestamp: 2026030304 +Commit: c668d2cb679fff72dbc67d21d041679169cb05cd +Timestamp: 2026030506 Version: 0.0.1 diff --git a/utils/mlir-air-hash.txt b/utils/mlir-air-hash.txt index 0226e81..44bed0a 100644 --- a/utils/mlir-air-hash.txt +++ b/utils/mlir-air-hash.txt @@ -1,3 +1,3 @@ -Commit: e2bed3f -Timestamp: 2026030919 +Commit: 4bc5734 +Timestamp: 2026031020 Version: 0.0.1