Skip to content

Commit f7cad78

Browse files
authored
Merge pull request #28 from amd/padded-matmul-bf16-emulation
Add padded matmul with BF16 emulation and non-aligned dimensions
2 parents d491a48 + e0ad01f commit f7cad78

5 files changed

Lines changed: 534 additions & 15 deletions

File tree

amd_triton_npu/backend/driver.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def _get_transform_ir_string():
460460
"""
461461

462462

463-
def _ttshared_to_air(mod, gridX, gridY, gridZ):
463+
def _ttshared_to_air(mod, gridX, gridY, gridZ, actual_sizes=None):
464464
# Get Triton-Shared-MLIR as string
465465
with tempfile.TemporaryDirectory() as tmpdir:
466466
dst_path = os.path.join(tmpdir, "airinput.mlir")
@@ -485,17 +485,14 @@ def _ttshared_to_air(mod, gridX, gridY, gridZ):
485485
transform_ir = Module.parse(transform_ir_string, context=air_context)
486486
run_transform(transform_ir, air_module)
487487
# MLIR-AIR compilation step 3: converting to AIR
488+
wrap_params = f"loop-bounds={gridX},{gridY},{gridZ}"
489+
if actual_sizes:
490+
wrap_params += f" actual-sizes={actual_sizes}"
488491
pipeline = (
489492
"builtin.module("
490493
+ ",".join(
491494
[
492-
"func.func(air-wrap-func-with-parallel{loop-bounds="
493-
+ str(gridX)
494-
+ ","
495-
+ str(gridY)
496-
+ ","
497-
+ str(gridZ)
498-
+ "})",
495+
f"func.func(air-wrap-func-with-parallel{{{wrap_params}}})",
499496
"air-par-to-launch{depth=0 has-air-segment=true}",
500497
"canonicalize",
501498
"cse",
@@ -1144,7 +1141,9 @@ def _generate_elf_launcher(constants, signature, kernel_name):
11441141
"""
11451142

11461143

1147-
def compile_module(launcher_src, kernel_placeholder_name, output_format="xclbin"):
1144+
def compile_module(
1145+
launcher_src, kernel_placeholder_name, output_format="xclbin", actual_sizes=None
1146+
):
11481147
py_version = sys.version_info
11491148
if platform.system() == "Windows":
11501149
py_include_dir = os.path.join(sys.base_prefix, "include")
@@ -1188,7 +1187,9 @@ def launch(
11881187
air_proj_path = _get_air_project_path()
11891188
os.makedirs(air_proj_path, exist_ok=True)
11901189
Path(os.path.join(air_proj_path, "asm_src.mlir")).write_bytes(asm_src)
1191-
air_output = _ttshared_to_air(asm_src, gridX, gridY, gridZ)
1190+
air_output = _ttshared_to_air(
1191+
asm_src, gridX, gridY, gridZ, actual_sizes=actual_sizes
1192+
)
11921193
with open(Path(os.path.join(air_proj_path, "asm_air_output.mlir")), "w") as f:
11931194
f.write(str(air_output))
11941195

@@ -1198,6 +1199,7 @@ def launch(
11981199
+ f"_timing_{autotune_time}"
11991200
+ f"_format_{output_format}"
12001201
+ f"_npu_{npu_version}"
1202+
+ f"_bf16emu_{os.getenv('AMD_TRITON_NPU_BF16_EMULATION', '0')}"
12011203
)
12021204
key = hashlib.md5(key_data.encode("utf-8")).hexdigest()
12031205

@@ -1287,6 +1289,10 @@ def launch(
12871289
"--peano=",
12881290
air_mlir_path,
12891291
]
1292+
# Enable bf16 emulation: hardware truncates f32 -> bf16 before
1293+
# multiply, with f32 accumulation.
1294+
if os.getenv("AMD_TRITON_NPU_BF16_EMULATION", "0") == "1":
1295+
aircc_cmd.insert(-1, "--bf16-emulation")
12901296
subprocess.check_call(aircc_cmd)
12911297

12921298
# Cache format-specific artifacts first, then the .so last.
@@ -1384,10 +1390,50 @@ def __init__(self, src, metadata):
13841390
launcher_src = _generate_launcher(
13851391
constants, signature, kernel_placeholder_name
13861392
)
1393+
1394+
# Extract actual problem sizes from constexpr args for padding support.
1395+
# When the kernel has constexpr args named "M" and "N", their values
1396+
# are the actual (non-padded) problem dimensions. These are passed to
1397+
# air-wrap-func-with-parallel as actual-sizes to enable DMA padding
1398+
# via air-split-launch-for-padding on boundary tiles.
1399+
# Only set actual-sizes when dimensions are NOT tile-aligned (i.e.,
1400+
# M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0), to avoid triggering
1401+
# the padding split path when it's not needed.
1402+
actual_sizes = None
1403+
if hasattr(src, "fn") and hasattr(src.fn, "arg_names"):
1404+
arg_names = src.fn.arg_names
1405+
raw_constants = src.constants if hasattr(src, "constants") else {}
1406+
1407+
def _get_constexpr(name):
1408+
"""Look up a constexpr value by arg name, trying multiple key forms."""
1409+
if name not in arg_names:
1410+
return None
1411+
idx = arg_names.index(name)
1412+
# src.constants uses tuple keys (idx,) per ASTSource.__init__,
1413+
# but check multiple forms for robustness across versions.
1414+
for key in [(idx,), idx, name]:
1415+
if key in raw_constants:
1416+
return raw_constants[key]
1417+
return None
1418+
1419+
m_val = _get_constexpr("M")
1420+
n_val = _get_constexpr("N")
1421+
if m_val is not None and n_val is not None:
1422+
bsm = _get_constexpr("BLOCK_SIZE_M")
1423+
bsn = _get_constexpr("BLOCK_SIZE_N")
1424+
needs_padding = True
1425+
if bsm is not None and bsn is not None:
1426+
needs_padding = (m_val % bsm != 0) or (n_val % bsn != 0)
1427+
if needs_padding:
1428+
actual_sizes = f"{m_val},{n_val},1"
1429+
13871430
# Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name
13881431
# in the following launch function.
13891432
self.launch = compile_module(
1390-
launcher_src, kernel_placeholder_name, self.output_format
1433+
launcher_src,
1434+
kernel_placeholder_name,
1435+
self.output_format,
1436+
actual_sizes=actual_sizes,
13911437
)
13921438

13931439
def __call__(self, gridX, gridY, gridZ, stream, function, *args):
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
4+
# F32 matmul with BF16 emulation for NPU2 (AIE2P/Strix).
5+
# A is stored in K x M layout (transposed). Non-tile-aligned dimensions
6+
# are handled via air-split-launch-for-padding.
7+
#
8+
# Target: NPU2/Strix only (ELF output format, bf16 emulation).
9+
# Data types: F32 inputs/outputs, bf16 emulation on hardware
10+
# (hardware truncates f32 -> bf16 before multiply, f32 accumulation).
11+
# Tile sizes: TILE_M=64, TILE_N=32, HERD=4x4, LAUNCH_TILE=256x128.
12+
13+
import math
14+
import os
15+
import sys
16+
17+
import torch
18+
import triton
19+
import triton.language as tl
20+
import numpy as np
21+
from ml_dtypes import bfloat16
22+
23+
sys.path.append(os.path.abspath(".."))
24+
import benchmark
25+
26+
# === Tile parameters (must match transform_aie2p.mlir) ===
27+
TILE_M = 64
28+
TILE_N = 32
29+
K_L2_TILE = 16
30+
HERD_M = 4
31+
HERD_N = 4
32+
LAUNCH_TILE_M = TILE_M * HERD_M # 256
33+
LAUNCH_TILE_N = TILE_N * HERD_N # 128
34+
INNER_BLOCK = 8
35+
36+
# === Problem dimensions ===
37+
# M and N can be non-tile-aligned; padding is handled by air-split-launch-for-padding.
38+
# K must be a power of 2 (Triton requires tl.arange sizes to be powers of 2)
39+
# and a multiple of K_L2_TILE.
40+
M_actual = 500
41+
N_actual = 500
42+
K_val = 1024
43+
44+
assert K_val % K_L2_TILE == 0, f"K={K_val} must be divisible by K_L2_TILE={K_L2_TILE}"
45+
46+
# === Padded/allocated dimensions ===
47+
M_padded = math.ceil(M_actual / LAUNCH_TILE_M) * LAUNCH_TILE_M # 512
48+
N_padded = math.ceil(N_actual / LAUNCH_TILE_N) * LAUNCH_TILE_N # 512
49+
M_alloc = math.ceil(M_actual / INNER_BLOCK) * INNER_BLOCK # 504
50+
N_alloc = math.ceil(N_actual / INNER_BLOCK) * INNER_BLOCK # 504
51+
52+
53+
@triton.jit
54+
def padded_matmul_kernel(
55+
A,
56+
B,
57+
C,
58+
M: tl.constexpr,
59+
N: tl.constexpr,
60+
K: tl.constexpr,
61+
stride_am: tl.constexpr,
62+
stride_ak: tl.constexpr,
63+
stride_bk: tl.constexpr,
64+
stride_bn: tl.constexpr,
65+
stride_cm: tl.constexpr,
66+
stride_cn: tl.constexpr,
67+
BLOCK_SIZE_M: tl.constexpr,
68+
BLOCK_SIZE_N: tl.constexpr,
69+
BLOCK_SIZE_K: tl.constexpr,
70+
):
71+
pid_m = tl.program_id(0)
72+
pid_n = tl.program_id(1)
73+
74+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
75+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
76+
offs_k = tl.arange(0, BLOCK_SIZE_K)
77+
78+
a_block = tl.load(A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
79+
b_block = tl.load(B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
80+
81+
c_block = tl.dot(a_block, b_block)
82+
83+
tl.store(C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, c_block)
84+
85+
86+
def run_padded_matmul():
87+
np.random.seed(42)
88+
89+
# Host data: A is K x M_alloc (transposed, block-aligned).
90+
# B is K x N_alloc. Zero-padded beyond M_actual/N_actual.
91+
A_np = np.zeros((K_val, M_alloc), dtype=np.float32)
92+
A_np[:, :M_actual] = (np.random.rand(K_val, M_actual) * 4).astype(np.float32)
93+
B_np = np.zeros((K_val, N_alloc), dtype=np.float32)
94+
B_np[:, :N_actual] = (np.random.rand(K_val, N_actual) * 4).astype(np.float32)
95+
96+
A = torch.from_numpy(A_np)
97+
B = torch.from_numpy(B_np)
98+
C = torch.zeros((M_padded, N_padded), dtype=torch.float32)
99+
100+
# Enable BF16 emulation for aircc
101+
os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1"
102+
103+
grid = (
104+
triton.cdiv(M_actual, LAUNCH_TILE_M),
105+
triton.cdiv(N_actual, LAUNCH_TILE_N),
106+
)
107+
108+
compiled_kernel = padded_matmul_kernel[grid](
109+
A,
110+
B,
111+
C,
112+
M_actual,
113+
N_actual,
114+
K_val,
115+
1, # stride_am = 1 (A transposed: stored K x M)
116+
M_alloc, # stride_ak = M_alloc
117+
N_alloc, # stride_bk = N_alloc
118+
1, # stride_bn = 1
119+
N_padded, # stride_cm = N_padded
120+
1, # stride_cn = 1
121+
BLOCK_SIZE_M=LAUNCH_TILE_M, # 256
122+
BLOCK_SIZE_N=LAUNCH_TILE_N, # 128
123+
BLOCK_SIZE_K=K_val, # full K
124+
)
125+
126+
# Dump intermediate IR for debugging
127+
with open("tt.shared.mlir", "w") as f:
128+
f.write(str(compiled_kernel.asm["ttsharedir"]))
129+
130+
# Validate with stochastic sampling.
131+
# Golden: truncate f32 inputs to bf16 (matching hardware bf16_emulation
132+
# truncf_op), then compute dot product with f32 accumulation.
133+
A_bf16 = A_np.astype(bfloat16)
134+
B_bf16 = B_np.astype(bfloat16)
135+
136+
num_samples = 100
137+
sample_m = np.random.randint(0, M_actual, num_samples)
138+
sample_n = np.random.randint(0, N_actual, num_samples)
139+
140+
# Add deterministic boundary-tile samples to catch padding errors.
141+
boundary_m = list(
142+
set(
143+
[
144+
min(M_actual - 1, m)
145+
for m in [M_actual - 1, M_actual - TILE_M + 1, 0]
146+
if m >= 0
147+
]
148+
)
149+
)
150+
boundary_n = list(
151+
set(
152+
[
153+
min(N_actual - 1, n)
154+
for n in [N_actual - 1, N_actual - TILE_N + 1, 0]
155+
if n >= 0
156+
]
157+
)
158+
)
159+
for bm in boundary_m:
160+
for bn in boundary_n:
161+
sample_m = np.append(sample_m, bm)
162+
sample_n = np.append(sample_n, bn)
163+
164+
C_np = C.numpy()
165+
errors = 0
166+
for i in range(len(sample_m)):
167+
m, n = int(sample_m[i]), int(sample_n[i])
168+
expected = np.sum(
169+
A_bf16[:, m].astype(np.float32) * B_bf16[:, n].astype(np.float32),
170+
dtype=np.float32,
171+
)
172+
actual = C_np[m, n]
173+
if not np.isclose(actual, expected, rtol=0.1, atol=10.0):
174+
errors += 1
175+
if errors <= 5:
176+
print(f"Mismatch at ({m}, {n}): actual={actual}, expected={expected}")
177+
178+
total = len(sample_m)
179+
if errors == 0:
180+
print(
181+
f"PASS: All {total} sampled elements match "
182+
f"(M={M_actual}, N={N_actual}, K={K_val})"
183+
)
184+
else:
185+
print(f"FAIL: {errors}/{total} samples mismatched")
186+
sys.exit(1)
187+
188+
189+
if __name__ == "__main__":
190+
benchmark.select_npu_backend()
191+
run_padded_matmul()

0 commit comments

Comments
 (0)