Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
name: clang_format_diffs

- name: Check C/C++ format
uses: reviewdog/action-suggester@v1.22
uses: reviewdog/action-suggester@v1.24
with:
tool_name: clang-format
level: error
Expand All @@ -78,7 +78,7 @@ jobs:

- name: Check Python format
if: success() || failure()
uses: reviewdog/action-suggester@v1.22
uses: reviewdog/action-suggester@v1.24
with:
tool_name: black
level: error
Expand Down
6 changes: 6 additions & 0 deletions examples/generate_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@
"path": "rms_norm",
"datatypes": "bf16",
},
{
"category": "Normalization",
"name": "Weighted RMS Normalization",
"path": "weighted_rms_norm",
"datatypes": "bf16",
},
{
"category": "Normalization",
"name": "Softmax",
Expand Down
94 changes: 94 additions & 0 deletions examples/weighted_rms_norm/transform_aie2p.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT

// Weighted RMS Norm transform for AIE2P.
// y = x * rsqrt(mean(x^2) + eps) * w
//
// No fuse_multi_op_linalg (creates L2 refs). CSE merges duplicate X
// extract_slices so linalg_promote shares one L1 X buffer across ops.

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization
transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op
transform.apply_cse to %func0 : !transform.any_op
%reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op
%func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op
transform.apply_cse to %func1a : !transform.any_op
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
%fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op
transform.apply_cse to %fa : !transform.any_op

%ag = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%sq, %out = transform.split_handle %ag : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op

%ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
%t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%f2, %fl2 = transform.structured.fuse_into_containing_op %sq into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%f3, %fl3 = transform.structured.fuse_into_containing_op %fill into %fl2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

// Fill dest → L1
%fills3 = transform.structured.match ops{["linalg.fill"]} in %fl3 : (!transform.any_op) -> !transform.any_op
%fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op

// CSE to merge duplicate X extract_slices from fuse_into_containing_op.
// sq and out both slice from the same X tensor -- CSE merges them so
// linalg_promote's promotedValueMap shares one L1 buffer.
%func_cse = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_cse {
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.apply_cse to %func_cse : !transform.any_op

// Bufferize
%fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op
%f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op
transform.apply_cse to %f6 : !transform.any_op
%lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op
%fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op)
%fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op)

// linalg_promote for ALL operands (X, weight, output) with consistent
// memory space attributes. No tensor-domain promotion needed.
%forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op
%reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op
%all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op
%promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op

// No post-promote canonicalization (it hoists allocs out of forall region)

// Herd + DMA
%fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op
%h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op

// Override herd-internal allocs to L1
%herd_ms = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%herd_ms_done = transform.air.override_memref_memory_space %herd_ms {memory_space = 2 : i32} : (!transform.any_op) -> !transform.any_op

%h_fresh = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op

// Skip copy_to_dma here -- driver's step 3 runs air-copy-to-dma pass
// which handles self-copy elimination (PR #1390).

// No post-herd cleanup (canonicalization hoists allocs across regions)

// STOP HERE to check domination
transform.yield
}
}
95 changes: 95 additions & 0 deletions examples/weighted_rms_norm/weighted_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

# Weighted RMS Normalization kernel for AMD XDNA NPU
# Computes: y = x * rsqrt(mean(x^2) + eps) * w per row
#
# Extends rms_norm by multiplying each element by a learned weight vector w.
# Uses BLOCK_M=2 (2D tiling) to avoid the scalar chain issue.

import torch
import triton
import triton.language as tl
import sys, os

sys.path.append(os.path.abspath(".."))
import benchmark

EPS = 1e-5


@triton.jit
def weighted_rms_norm_kernel(
X,
W,
Y,
N: tl.constexpr,
eps: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid = tl.program_id(0)
row_start = pid * BLOCK_M
rows = row_start + tl.arange(0, BLOCK_M)
cols = tl.arange(0, BLOCK_N)

# Load BLOCK_M rows at once (2D block)
offsets = rows[:, None] * N + cols[None, :]
x = tl.load(X + offsets)

# Load weight vector (same for all rows, 1D [BLOCK_N])
w = tl.load(W + cols)

# Compute mean of squares per row in bf16
x_f32 = x.to(tl.float32)
x_sq = x_f32 * x_f32
x_sq_bf16 = x_sq.to(x.dtype)
sum_sq_bf16 = tl.sum(x_sq_bf16, axis=1)
sum_sq = sum_sq_bf16.to(tl.float32)

# Compute rsqrt per row
mean_sq = sum_sq / N
rstd = tl.math.rsqrt(mean_sq + eps)

# Normalize and multiply by weight: y = x * rstd * w
w_f32 = w.to(tl.float32)
y = x_f32 * rstd[:, None] * w_f32[None, :]
y = y.to(x.dtype)
tl.store(Y + offsets, y)


def bench_weighted_rms_norm(M, N, provider):
device = "cpu"
dtype = torch.bfloat16
BLOCK_M = 2
x = torch.randn(M, N, device=device, dtype=dtype)
w = torch.randn(N, device=device, dtype=dtype)
y = torch.empty(M, N, device=device, dtype=dtype)
if provider == "torch" or provider == "test":
x_f32 = x.float()
w_f32 = w.float()
mean_sq = (x_f32 * x_f32).mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(mean_sq + EPS)
y_ref = (x_f32 * rstd * w_f32.unsqueeze(0)).to(dtype)
if provider == "triton" or provider == "test":
grid = (M // BLOCK_M,)
compiled_kernel = weighted_rms_norm_kernel[grid](
x,
w,
y,
N,
EPS,
BLOCK_M=BLOCK_M,
BLOCK_N=N,
)
with open("tt.shared.mlir", "w") as f:
f.write(str(compiled_kernel.asm["ttsharedir"]))
if provider == "test":
torch.testing.assert_close(y, y_ref, atol=5e-1, rtol=1e-1)


if __name__ == "__main__":
benchmark.select_npu_backend()
for M in [32, 64]:
for N in [256]:
bench_weighted_rms_norm(M, N, "test")
4 changes: 2 additions & 2 deletions utils/mlir-air-hash.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Commit: 4bc5734
Timestamp: 2026031020
Commit: 98f2fc3
Timestamp: 2026031205
Version: 0.0.1
Loading