Skip to content

Commit c4173cc

Browse files
erwei-xilinx and claude committed
Add weighted RMS normalization example
Weighted RMS norm: y = x * rsqrt(mean(x^2) + eps) * w

Extends rms_norm by multiplying each normalized element by a learned weight vector. Uses BLOCK_M=2 (2D tiling) with 3 memref arguments (X, W, Y), where W has broadcast indexing.

The transform script relies on mlir-air PR #1412 (cross-op buffer sharing in linalg_promote) to share the X subview buffer across the squaring and output generics, keeping the DMA count within AIE tile limits.

Update mlir-air to 98f2fc3, which includes all necessary fixes:
- PR #1407: broadcast operand promotion
- PR #1408: dead memref.global cleanup
- PR #1411: memory space comparison fix
- PR #1412: cross-op promotedValueMap with DominanceInfo

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d8186c4 commit c4173cc

4 files changed

Lines changed: 197 additions & 2 deletions

File tree

examples/generate_readme.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@
9292
"path": "rms_norm",
9393
"datatypes": "bf16",
9494
},
95+
{
96+
"category": "Normalization",
97+
"name": "Weighted RMS Normalization",
98+
"path": "weighted_rms_norm",
99+
"datatypes": "bf16",
100+
},
95101
{
96102
"category": "Normalization",
97103
"name": "Softmax",
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
// Weighted RMS Norm transform for AIE2P.
5+
// y = x * rsqrt(mean(x^2) + eps) * w
6+
//
7+
// No fuse_multi_op_linalg (creates L2 refs). CSE merges duplicate X
8+
// extract_slices so linalg_promote shares one L1 X buffer across ops.
9+
10+
module attributes {transform.with_named_sequence} {
11+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
12+
%func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
13+
transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization
14+
transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op
15+
transform.apply_cse to %func0 : !transform.any_op
16+
%reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
17+
%tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op
18+
%func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
19+
transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op
20+
transform.apply_cse to %func1a : !transform.any_op
21+
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
22+
%f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
23+
%fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
24+
transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op
25+
transform.apply_cse to %fa : !transform.any_op
26+
27+
%ag = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
28+
%sq, %out = transform.split_handle %ag : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
29+
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
30+
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
31+
32+
%ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
33+
%t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
34+
%f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
35+
%f2, %fl2 = transform.structured.fuse_into_containing_op %sq into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
36+
%f3, %fl3 = transform.structured.fuse_into_containing_op %fill into %fl2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
37+
38+
    // Bufferize the fill destination into L1 (memory_space = 2)
39+
%fills3 = transform.structured.match ops{["linalg.fill"]} in %fl3 : (!transform.any_op) -> !transform.any_op
40+
%fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3
41+
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
42+
43+
// CSE to merge duplicate X extract_slices from fuse_into_containing_op.
44+
// sq and out both slice from the same X tensor -- CSE merges them so
45+
// linalg_promote's promotedValueMap shares one L1 buffer.
46+
%func_cse = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
47+
transform.apply_patterns to %func_cse {
48+
transform.apply_patterns.linalg.tiling_canonicalization
49+
transform.apply_patterns.scf.for_loop_canonicalization
50+
transform.apply_patterns.canonicalization
51+
} : !transform.any_op
52+
transform.apply_cse to %func_cse : !transform.any_op
53+
54+
// Bufferize
55+
%fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56+
%fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op
57+
%f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
58+
transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op
59+
transform.apply_cse to %f6 : !transform.any_op
60+
%lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
61+
%mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op
62+
%fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op)
63+
%fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op)
64+
65+
// linalg_promote for ALL operands (X, weight, output) with consistent
66+
// memory space attributes. No tensor-domain promotion needed.
67+
%forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
68+
%gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op
69+
%reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op
70+
%all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op
71+
%promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op
72+
73+
// No post-promote canonicalization (it hoists allocs out of forall region)
74+
75+
// Herd + DMA
76+
%fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
77+
%pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op
78+
%h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op
79+
80+
// Override herd-internal allocs to L1
81+
%herd_ms = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
82+
%herd_ms_done = transform.air.override_memref_memory_space %herd_ms {memory_space = 2 : i32} : (!transform.any_op) -> !transform.any_op
83+
84+
%h_fresh = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
85+
86+
// Skip copy_to_dma here -- driver's step 3 runs air-copy-to-dma pass
87+
// which handles self-copy elimination (PR #1390).
88+
89+
// No post-herd cleanup (canonicalization hoists allocs across regions)
90+
91+
    // Stop here (checkpoint for verifying SSA dominance); the driver performs the remaining lowering.
92+
transform.yield
93+
}
94+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
4+
# Weighted RMS Normalization kernel for AMD XDNA NPU
5+
# Computes: y = x * rsqrt(mean(x^2) + eps) * w per row
6+
#
7+
# Extends rms_norm by multiplying each element by a learned weight vector w.
8+
# Uses BLOCK_M=2 (2D tiling) to avoid the scalar chain issue.
9+
10+
import torch
11+
import triton
12+
import triton.language as tl
13+
import sys, os
14+
15+
sys.path.append(os.path.abspath(".."))
16+
import benchmark
17+
18+
EPS = 1e-5
19+
20+
21+
@triton.jit
22+
def weighted_rms_norm_kernel(
23+
X,
24+
W,
25+
Y,
26+
N: tl.constexpr,
27+
eps: tl.constexpr,
28+
BLOCK_M: tl.constexpr,
29+
BLOCK_N: tl.constexpr,
30+
):
31+
pid = tl.program_id(0)
32+
row_start = pid * BLOCK_M
33+
rows = row_start + tl.arange(0, BLOCK_M)
34+
cols = tl.arange(0, BLOCK_N)
35+
36+
# Load BLOCK_M rows at once (2D block)
37+
offsets = rows[:, None] * N + cols[None, :]
38+
x = tl.load(X + offsets)
39+
40+
# Load weight vector (same for all rows, 1D [BLOCK_N])
41+
w = tl.load(W + cols)
42+
43+
    # Sum of squares per row: square in f32, then reduce in the input dtype (bf16 in the benchmark)
44+
x_f32 = x.to(tl.float32)
45+
x_sq = x_f32 * x_f32
46+
x_sq_bf16 = x_sq.to(x.dtype)
47+
sum_sq_bf16 = tl.sum(x_sq_bf16, axis=1)
48+
sum_sq = sum_sq_bf16.to(tl.float32)
49+
50+
    # Per-row mean of squares, then rstd = rsqrt(mean_sq + eps)
51+
mean_sq = sum_sq / N
52+
rstd = tl.math.rsqrt(mean_sq + eps)
53+
54+
# Normalize and multiply by weight: y = x * rstd * w
55+
w_f32 = w.to(tl.float32)
56+
y = x_f32 * rstd[:, None] * w_f32[None, :]
57+
y = y.to(x.dtype)
58+
tl.store(Y + offsets, y)
59+
60+
61+
def bench_weighted_rms_norm(M, N, provider):
62+
device = "cpu"
63+
dtype = torch.bfloat16
64+
BLOCK_M = 2
65+
x = torch.randn(M, N, device=device, dtype=dtype)
66+
w = torch.randn(N, device=device, dtype=dtype)
67+
y = torch.empty(M, N, device=device, dtype=dtype)
68+
if provider == "torch" or provider == "test":
69+
x_f32 = x.float()
70+
w_f32 = w.float()
71+
mean_sq = (x_f32 * x_f32).mean(dim=-1, keepdim=True)
72+
rstd = torch.rsqrt(mean_sq + EPS)
73+
y_ref = (x_f32 * rstd * w_f32.unsqueeze(0)).to(dtype)
74+
if provider == "triton" or provider == "test":
75+
grid = (M // BLOCK_M,)
76+
compiled_kernel = weighted_rms_norm_kernel[grid](
77+
x,
78+
w,
79+
y,
80+
N,
81+
EPS,
82+
BLOCK_M=BLOCK_M,
83+
BLOCK_N=N,
84+
)
85+
with open("tt.shared.mlir", "w") as f:
86+
f.write(str(compiled_kernel.asm["ttsharedir"]))
87+
if provider == "test":
88+
torch.testing.assert_close(y, y_ref, atol=5e-1, rtol=1e-1)
89+
90+
91+
if __name__ == "__main__":
92+
benchmark.select_npu_backend()
93+
for M in [32, 64]:
94+
for N in [256]:
95+
bench_weighted_rms_norm(M, N, "test")

utils/mlir-air-hash.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Commit: 4bc5734
2-
Timestamp: 2026031020
1+
Commit: 98f2fc3
2+
Timestamp: 2026031205
33
Version: 0.0.1

0 commit comments

Comments
 (0)