From c4173cc542860de56b4e30bb5adb501d49d24919 Mon Sep 17 00:00:00 2001 From: erweiw Date: Wed, 11 Mar 2026 23:02:07 -0700 Subject: [PATCH 1/2] Add weighted RMS normalization example Weighted RMS norm: y = x * rsqrt(mean(x^2) + eps) * w Extends rms_norm by multiplying each normalized element by a learned weight vector. Uses BLOCK_M=2 (2D tiling) with 3 memref arguments (X, W, Y) where W has broadcast indexing. The transform script relies on mlir-air PR #1412 (cross-op buffer sharing in linalg_promote) to share the X subview buffer across the squaring and output generics, keeping DMA count within AIE tile limits. Update mlir-air to 98f2fc3 which includes all necessary fixes: - PR #1407: broadcast operand promotion - PR #1408: dead memref.global cleanup - PR #1411: memory space comparison fix - PR #1412: cross-op promotedValueMap with DominanceInfo Co-Authored-By: Claude Opus 4.6 --- examples/generate_readme.py | 6 ++ .../weighted_rms_norm/transform_aie2p.mlir | 94 ++++++++++++++++++ .../weighted_rms_norm/weighted_rms_norm.py | 95 +++++++++++++++++++ utils/mlir-air-hash.txt | 4 +- 4 files changed, 197 insertions(+), 2 deletions(-) create mode 100644 examples/weighted_rms_norm/transform_aie2p.mlir create mode 100644 examples/weighted_rms_norm/weighted_rms_norm.py diff --git a/examples/generate_readme.py b/examples/generate_readme.py index 3eec6bb..0bc8144 100644 --- a/examples/generate_readme.py +++ b/examples/generate_readme.py @@ -92,6 +92,12 @@ "path": "rms_norm", "datatypes": "bf16", }, + { + "category": "Normalization", + "name": "Weighted RMS Normalization", + "path": "weighted_rms_norm", + "datatypes": "bf16", + }, { "category": "Normalization", "name": "Softmax", diff --git a/examples/weighted_rms_norm/transform_aie2p.mlir b/examples/weighted_rms_norm/transform_aie2p.mlir new file mode 100644 index 0000000..ff6d9b3 --- /dev/null +++ b/examples/weighted_rms_norm/transform_aie2p.mlir @@ -0,0 +1,94 @@ +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT + +// Weighted RMS Norm transform for AIE2P. +// y = x * rsqrt(mean(x^2) + eps) * w +// +// No fuse_multi_op_linalg (creates L2 refs). CSE merges duplicate X +// extract_slices so linalg_promote shares one L1 X buffer across ops. + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization + transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op + transform.apply_cse to %func0 : !transform.any_op + %reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op + %func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %func1a : !transform.any_op + %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op + %fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %fa : !transform.any_op + + %ag = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %sq, %out = transform.split_handle %ag : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op + + %ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op + %t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f2, %fl2 = transform.structured.fuse_into_containing_op %sq into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f3, %fl3 = transform.structured.fuse_into_containing_op %fill into %fl2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Fill dest → L1 + %fills3 = transform.structured.match ops{["linalg.fill"]} in %fl3 : (!transform.any_op) -> !transform.any_op + %fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3 + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + + // CSE to merge duplicate X extract_slices from fuse_into_containing_op. + // sq and out both slice from the same X tensor -- CSE merges them so + // linalg_promote's promotedValueMap shares one L1 buffer. + %func_cse = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_cse { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_cse : !transform.any_op + + // Bufferize + %fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op + %f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op + transform.apply_cse to %f6 : !transform.any_op + %lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op + %fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op) + %fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op) + + // linalg_promote for ALL operands (X, weight, output) with consistent + // memory space attributes. No tensor-domain promotion needed. + %forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op + %reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op + %all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op + %promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op + + // No post-promote canonicalization (it hoists allocs out of forall region) + + // Herd + DMA + %fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op + %h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op + + // Override herd-internal allocs to L1 + %herd_ms = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %herd_ms_done = transform.air.override_memref_memory_space %herd_ms {memory_space = 2 : i32} : (!transform.any_op) -> !transform.any_op + + %h_fresh = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op + + // Skip copy_to_dma here -- driver's step 3 runs air-copy-to-dma pass + // which handles self-copy elimination (PR #1390). + + // No post-herd cleanup (canonicalization hoists allocs across regions) + + // STOP HERE to check domination + transform.yield + } +} diff --git a/examples/weighted_rms_norm/weighted_rms_norm.py b/examples/weighted_rms_norm/weighted_rms_norm.py new file mode 100644 index 0000000..e25ef0f --- /dev/null +++ b/examples/weighted_rms_norm/weighted_rms_norm.py @@ -0,0 +1,95 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT + +# Weighted RMS Normalization kernel for AMD XDNA NPU +# Computes: y = x * rsqrt(mean(x^2) + eps) * w per row +# +# Extends rms_norm by multiplying each element by a learned weight vector w. +# Uses BLOCK_M=2 (2D tiling) to avoid the scalar chain issue. + +import torch +import triton +import triton.language as tl +import sys, os + +sys.path.append(os.path.abspath("..")) +import benchmark + +EPS = 1e-5 + + +@triton.jit +def weighted_rms_norm_kernel( + X, + W, + Y, + N: tl.constexpr, + eps: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * BLOCK_M + rows = row_start + tl.arange(0, BLOCK_M) + cols = tl.arange(0, BLOCK_N) + + # Load BLOCK_M rows at once (2D block) + offsets = rows[:, None] * N + cols[None, :] + x = tl.load(X + offsets) + + # Load weight vector (same for all rows, 1D [BLOCK_N]) + w = tl.load(W + cols) + + # Compute mean of squares per row in bf16 + x_f32 = x.to(tl.float32) + x_sq = x_f32 * x_f32 + x_sq_bf16 = x_sq.to(x.dtype) + sum_sq_bf16 = tl.sum(x_sq_bf16, axis=1) + sum_sq = sum_sq_bf16.to(tl.float32) + + # Compute rsqrt per row + mean_sq = sum_sq / N + rstd = tl.math.rsqrt(mean_sq + eps) + + # Normalize and multiply by weight: y = x * rstd * w + w_f32 = w.to(tl.float32) + y = x_f32 * rstd[:, None] * w_f32[None, :] + y = y.to(x.dtype) + tl.store(Y + offsets, y) + + +def bench_weighted_rms_norm(M, N, provider): + device = "cpu" + dtype = torch.bfloat16 + BLOCK_M = 2 + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=dtype) + y = torch.empty(M, N, device=device, dtype=dtype) + if provider == "torch" or provider == "test": + x_f32 = x.float() + w_f32 = w.float() + mean_sq = (x_f32 * x_f32).mean(dim=-1, keepdim=True) + rstd = torch.rsqrt(mean_sq + EPS) + y_ref = (x_f32 * rstd * w_f32.unsqueeze(0)).to(dtype) + if provider == "triton" or provider == "test": + grid = (M // BLOCK_M,) + compiled_kernel = weighted_rms_norm_kernel[grid]( + x, + w, + y, + N, + EPS, + BLOCK_M=BLOCK_M, + BLOCK_N=N, + ) + with open("tt.shared.mlir", "w") as f: + f.write(str(compiled_kernel.asm["ttsharedir"])) + if provider == "test": + torch.testing.assert_close(y, y_ref, atol=5e-1, rtol=1e-1) + + +if __name__ == "__main__": + benchmark.select_npu_backend() + for M in [32, 64]: + for N in [256]: + bench_weighted_rms_norm(M, N, "test") diff --git a/utils/mlir-air-hash.txt b/utils/mlir-air-hash.txt index 44bed0a..c2e28d4 100644 --- a/utils/mlir-air-hash.txt +++ b/utils/mlir-air-hash.txt @@ -1,3 +1,3 @@ -Commit: 4bc5734 -Timestamp: 2026031020 +Commit: 98f2fc3 +Timestamp: 2026031205 Version: 0.0.1 From 6d0496119b2dcc1e950dd5de0b7ecb391ca4c59c Mon Sep 17 00:00:00 2001 From: erweiw Date: Wed, 11 Mar 2026 23:08:36 -0700 Subject: [PATCH 2/2] Fix CI format check: update reviewdog/action-suggester to v1.24 reviewdog v0.20.3 (used by action-suggester v1.22) was removed from GitHub releases, causing both clang-format and black check steps to fail with "unable to find 'v0.20.3'". Co-Authored-By: Claude Opus 4.6 --- .github/workflows/format.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 2730208..7371ea3 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -54,7 +54,7 @@ jobs: name: clang_format_diffs - name: Check C/C++ format - uses: reviewdog/action-suggester@v1.22 + uses: reviewdog/action-suggester@v1.24 with: tool_name: clang-format level: error @@ -78,7 +78,7 @@ jobs: - name: Check Python format if: success() || failure() - uses: reviewdog/action-suggester@v1.22 + uses: reviewdog/action-suggester@v1.24 with: tool_name: black level: error