|
| 1 | +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +# SPDX-License-Identifier: MIT |
| 3 | + |
| 4 | +# F32 matmul with BF16 emulation for NPU2 (AIE2P/Strix). |
| 5 | +# A is stored in K x M layout (transposed). Non-tile-aligned dimensions |
| 6 | +# are handled via air-split-launch-for-padding. |
| 7 | +# |
| 8 | +# Target: NPU2/Strix only (ELF output format, bf16 emulation). |
| 9 | +# Data types: F32 inputs/outputs, bf16 emulation on hardware |
| 10 | +# (hardware truncates f32 -> bf16 before multiply, f32 accumulation). |
| 11 | +# Tile sizes: TILE_M=64, TILE_N=32, HERD=4x4, LAUNCH_TILE=256x128. |
| 12 | + |
| 13 | +import math |
| 14 | +import os |
| 15 | +import sys |
| 16 | + |
| 17 | +import torch |
| 18 | +import triton |
| 19 | +import triton.language as tl |
| 20 | +import numpy as np |
| 21 | +from ml_dtypes import bfloat16 |
| 22 | + |
| 23 | +sys.path.append(os.path.abspath("..")) |
| 24 | +import benchmark |
| 25 | + |
# === Tile parameters (must match transform_aie2p.mlir) ===
TILE_M = 64
TILE_N = 32
K_L2_TILE = 16
HERD_M = 4
HERD_N = 4
LAUNCH_TILE_M = TILE_M * HERD_M  # 256
LAUNCH_TILE_N = TILE_N * HERD_N  # 128
INNER_BLOCK = 8

# === Problem dimensions ===
# M and N can be non-tile-aligned; padding is handled by air-split-launch-for-padding.
# K must be a power of 2 (Triton requires tl.arange sizes to be powers of 2)
# and a multiple of K_L2_TILE.
M_actual = 500
N_actual = 500
K_val = 1024

# Enforce both documented constraints on K up front, so a bad edit to K_val
# fails here with a clear message instead of deep inside kernel compilation.
assert K_val % K_L2_TILE == 0, f"K={K_val} must be divisible by K_L2_TILE={K_L2_TILE}"
assert (
    K_val > 0 and K_val & (K_val - 1) == 0
), f"K={K_val} must be a positive power of 2 (used as a tl.arange size)"

# === Padded/allocated dimensions ===
# *_padded: rounded up to a whole launch tile (used for the output buffer C).
# *_alloc:  rounded up to INNER_BLOCK only (used for the A/B input buffers).
M_padded = math.ceil(M_actual / LAUNCH_TILE_M) * LAUNCH_TILE_M  # 512
N_padded = math.ceil(N_actual / LAUNCH_TILE_N) * LAUNCH_TILE_N  # 512
M_alloc = math.ceil(M_actual / INNER_BLOCK) * INNER_BLOCK  # 504
N_alloc = math.ceil(N_actual / INNER_BLOCK) * INNER_BLOCK  # 504
| 51 | + |
| 52 | + |
# Triton kernel: each program instance computes one
# BLOCK_SIZE_M x BLOCK_SIZE_N tile of C from a BLOCK_SIZE_M x BLOCK_SIZE_K
# tile of A and a BLOCK_SIZE_K x BLOCK_SIZE_N tile of B with a single tl.dot.
#
# Arguments:
#   A, B, C        - base pointers of the two inputs and the output.
#   M, N, K        - problem sizes; accepted but not referenced in the body.
#                    NOTE(review): presumably consumed by the compiler /
#                    padding pass - confirm.
#   stride_*       - element strides used to form the 2-D addresses below.
#   BLOCK_SIZE_*   - compile-time tile extents along M, N and K.
#
# The loads and the store carry no mask, and there is no loop over K: the
# K offsets always start at 0 and span BLOCK_SIZE_K. Per the file header,
# out-of-range rows/columns are expected to be handled by
# air-split-launch-for-padding rather than by this kernel.
@triton.jit
def padded_matmul_kernel(
    A,
    B,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # 2-D launch grid: axis 0 indexes tiles along M, axis 1 along N.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Row/column indices covered by this program's output tile, plus the
    # K indices [0, BLOCK_SIZE_K) shared by both operand tiles.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # A tile is indexed [m, k] and B tile [k, n]; the physical layout is
    # decided entirely by the strides passed in by the caller.
    a_block = tl.load(A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_block = tl.load(B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)

    c_block = tl.dot(a_block, b_block)

    # Unmasked write of the full output tile.
    tl.store(C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, c_block)
| 84 | + |
| 85 | + |
def _boundary_indices(extent, tile):
    """Return de-duplicated indices that probe the edge of a padded dimension.

    Candidates are ``extent - 1``, ``extent - tile + 1`` and ``0``. Negative
    candidates are dropped and the rest are clamped to ``extent - 1``.
    """
    candidates = (extent - 1, extent - tile + 1, 0)
    return list({min(extent - 1, idx) for idx in candidates if idx >= 0})


def _count_sample_mismatches(C_np, A_bf16, B_bf16, points, max_reports=5):
    """Compare sampled outputs against a bf16-input / f32-accumulate golden.

    Args:
        C_np: kernel output, indexed ``[m, n]``.
        A_bf16: bf16 copy of A, indexed ``[k, m]``.
        B_bf16: bf16 copy of B, indexed ``[k, n]``.
        points: iterable of ``(m, n)`` integer pairs to check.
        max_reports: at most this many mismatches are printed.

    Returns:
        The number of sampled elements that fall outside the tolerance.
    """
    errors = 0
    for m, n in points:
        expected = np.sum(
            A_bf16[:, m].astype(np.float32) * B_bf16[:, n].astype(np.float32),
            dtype=np.float32,
        )
        actual = C_np[m, n]
        if np.isclose(actual, expected, rtol=0.1, atol=10.0):
            continue
        errors += 1
        if errors <= max_reports:
            print(f"Mismatch at ({m}, {n}): actual={actual}, expected={expected}")
    return errors


def run_padded_matmul():
    """Run the padded matmul kernel once and validate sampled outputs.

    Builds zero-padded f32 inputs, launches ``padded_matmul_kernel`` with
    launch-tile-sized blocks, dumps the ``ttsharedir`` IR to
    ``tt.shared.mlir`` and checks random plus boundary samples against a
    host-side golden. Prints PASS on success; prints FAIL and exits the
    process with status 1 if any sample mismatches.
    """
    np.random.seed(42)

    # Host data: A is K x M_alloc (transposed, block-aligned).
    # B is K x N_alloc. Zero-padded beyond M_actual/N_actual.
    A_np = np.zeros((K_val, M_alloc), dtype=np.float32)
    A_np[:, :M_actual] = (np.random.rand(K_val, M_actual) * 4).astype(np.float32)
    B_np = np.zeros((K_val, N_alloc), dtype=np.float32)
    B_np[:, :N_actual] = (np.random.rand(K_val, N_actual) * 4).astype(np.float32)

    A = torch.from_numpy(A_np)
    B = torch.from_numpy(B_np)
    C = torch.zeros((M_padded, N_padded), dtype=torch.float32)

    # Enable BF16 emulation for aircc
    os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1"

    grid = (
        triton.cdiv(M_actual, LAUNCH_TILE_M),
        triton.cdiv(N_actual, LAUNCH_TILE_N),
    )

    compiled_kernel = padded_matmul_kernel[grid](
        A,
        B,
        C,
        M_actual,
        N_actual,
        K_val,
        1,  # stride_am = 1 (A transposed: stored K x M)
        M_alloc,  # stride_ak = M_alloc
        N_alloc,  # stride_bk = N_alloc
        1,  # stride_bn = 1
        N_padded,  # stride_cm = N_padded
        1,  # stride_cn = 1
        BLOCK_SIZE_M=LAUNCH_TILE_M,  # 256
        BLOCK_SIZE_N=LAUNCH_TILE_N,  # 128
        BLOCK_SIZE_K=K_val,  # full K
    )

    # Dump intermediate IR for debugging
    with open("tt.shared.mlir", "w") as f:
        f.write(str(compiled_kernel.asm["ttsharedir"]))

    # Validate with stochastic sampling.
    # Golden: convert f32 inputs to bf16 (mirroring the hardware bf16
    # emulation path), then compute the dot product with f32 accumulation.
    A_bf16 = A_np.astype(bfloat16)
    B_bf16 = B_np.astype(bfloat16)

    # Keep the two randint draws in this order so the seeded sample set is
    # reproducible from run to run.
    num_samples = 100
    random_m = np.random.randint(0, M_actual, num_samples)
    random_n = np.random.randint(0, N_actual, num_samples)
    sample_points = list(zip(random_m.tolist(), random_n.tolist()))

    # Add deterministic boundary-tile samples to catch padding errors.
    boundary_m = _boundary_indices(M_actual, TILE_M)
    boundary_n = _boundary_indices(N_actual, TILE_N)
    sample_points.extend((bm, bn) for bm in boundary_m for bn in boundary_n)

    errors = _count_sample_mismatches(C.numpy(), A_bf16, B_bf16, sample_points)

    total = len(sample_points)
    if errors == 0:
        print(
            f"PASS: All {total} sampled elements match "
            f"(M={M_actual}, N={N_actual}, K={K_val})"
        )
    else:
        print(f"FAIL: {errors}/{total} samples mismatched")
        sys.exit(1)
| 187 | + |
| 188 | + |
| 189 | +if __name__ == "__main__": |
| 190 | + benchmark.select_npu_backend() |
| 191 | + run_padded_matmul() |
0 commit comments