Skip to content

Commit 221ae39

Browse files
erwei-xilinx and claude authored
Fix DMA offset for transposed memrefs and simplify test 55 (#1530)
* Fix DMA offset for transposed memrefs and simplify test 55 to 3 herds

Fix a bug in extractOperandsFromReinterpretCast where a single flat offset from a reinterpret_cast was placed in the wrong DMA dimension. For transposed memrefs (e.g. strides [1, 504]), the old code padded zeros at the front, assigning the flat offset to the highest-stride dimension. This caused the offset to be multiplied by the wrong stride, producing out-of-bounds reads and NaN results for multi-launch-tile kernels.

The fix finds the stride-1 dimension and places the flat offset there, so the offset is multiplied by 1 (correct) rather than by the column stride.

Also simplify test 55 from a 4-herd pattern to a 3-herd pattern by merging the truncf herd into the compute herd. Testing on NPU1 hardware confirms the combined truncf+matmul pattern works correctly on aie2, making the separate truncf herd unnecessary.

Add FileCheck unit tests for standalone reinterpret_cast with transposed and normal layouts, including a constant-offset variant.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Fix stride-1 search direction to avoid NPU2 regression

Search backward for the stride-1 dimension so that ambiguous cases (e.g., strides=[1,1] in test_40 triton_vec_add) default to the last dimension, matching the original prepend-zeros behavior. Forward search picked dim 0 for strides=[1,1], which changed the intermediate DMA dimension structure and caused all-zeros output on NPU2.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f876db8 commit 221ae39

4 files changed

Lines changed: 116 additions & 57 deletions

File tree

mlir/lib/Conversion/ConvertToAIRPass.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,36 @@ static void extractOperandsFromReinterpretCast(
9999
sizes.push_back(getValueOrCreateConstantIndexOp(builder, loc, ofr));
100100
for (auto ofr : reinterpretCast.getMixedStrides())
101101
strides.push_back(getValueOrCreateConstantIndexOp(builder, loc, ofr));
102-
while (offsets.size() < sizes.size())
103-
offsets.insert(offsets.begin(),
104-
arith::ConstantIndexOp::create(builder, loc, 0));
102+
// When the reinterpret_cast has fewer offset dimensions than the memref
103+
// rank (e.g., a single flat offset for a 2D memref), we need to place
104+
// the flat offset in the correct dimension. For transposed memrefs
105+
// (stride-1 in the first dimension), the flat offset corresponds to the
106+
// stride-1 dimension, not the last dimension. Find the stride-1
107+
// dimension and place the offset there; pad others with zero.
108+
// Search backward so that ambiguous cases (e.g., strides=[1,1]) default
109+
// to the last dimension, matching the original prepend-zeros behavior.
110+
if (offsets.size() < sizes.size()) {
111+
int strideOneIdx = static_cast<int>(strides.size()) - 1;
112+
for (int i = static_cast<int>(strides.size()) - 1; i >= 0; --i) {
113+
if (auto cst =
114+
getConstantIntValue(reinterpretCast.getMixedStrides()[i])) {
115+
if (*cst == 1) {
116+
strideOneIdx = i;
117+
break;
118+
}
119+
}
120+
}
121+
// Save existing offsets (typically just one flat offset).
122+
SmallVector<Value, 4> existingOffsets(offsets);
123+
offsets.clear();
124+
for (size_t i = 0; i < sizes.size(); ++i) {
125+
if (static_cast<int>(i) == strideOneIdx && !existingOffsets.empty()) {
126+
offsets.push_back(existingOffsets[0]);
127+
} else {
128+
offsets.push_back(arith::ConstantIndexOp::create(builder, loc, 0));
129+
}
130+
}
131+
}
105132
}
106133

107134
// Detect self-copies that would produce invalid self-DMAs. After unwrapping

mlir/test/Conversion/ConvertToAIR/subview_reinterpret_cast_to_dma.mlir

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
// RUN: air-opt %s -air-copy-to-dma | FileCheck %s
99

10-
// Test that air-copy-to-dma correctly handles subview(reinterpret_cast) chains
11-
// by placing the reinterpret_cast flat offset in the stride-1 dimension.
10+
// Test that air-copy-to-dma correctly handles reinterpret_cast offsets
11+
// by placing flat offsets in the stride-1 dimension.
1212

1313
// CHECK-LABEL: func.func @transposed_a
1414
// The transposed A has strides [1, 512]. The reinterpret_cast offset %arg1
@@ -57,3 +57,61 @@ func.func @normal_layout(%arg0: memref<*xf32>, %arg1: index, %arg2: index) {
5757
to memref<16x256xf32, strided<[256, 1], offset: ?>, 1>
5858
return
5959
}
60+
61+
// -----
62+
63+
// Tests for standalone reinterpret_cast (no subview wrapper).
64+
// These exercise extractOperandsFromReinterpretCast directly.
65+
66+
// CHECK-LABEL: func.func @standalone_transposed
67+
// Transposed layout strides [1, 504]. The single flat offset %arg1 must go
68+
// in dim0 (stride=1), producing offsets [%arg1, 0].
69+
// CHECK: air.dma_memcpy_nd
70+
// CHECK-SAME: %arg0[%arg1, %c0]
71+
// CHECK-SAME: [%c256, %c16]
72+
// CHECK-SAME: [%c1, %c504]
73+
func.func @standalone_transposed(%arg0: memref<*xf32>, %arg1: index) {
74+
%alloc = memref.alloc() : memref<256x16xf32, 1>
75+
%rc = memref.reinterpret_cast %arg0 to
76+
offset: [%arg1], sizes: [256, 16], strides: [1, 504]
77+
: memref<*xf32> to memref<256x16xf32, strided<[1, 504], offset: ?>>
78+
memref.copy %rc, %alloc
79+
: memref<256x16xf32, strided<[1, 504], offset: ?>>
80+
to memref<256x16xf32, 1>
81+
return
82+
}
83+
84+
// CHECK-LABEL: func.func @standalone_normal
85+
// Normal layout strides [504, 1]. The single flat offset %arg1 must go
86+
// in dim1 (stride=1), producing offsets [0, %arg1].
87+
// CHECK: air.dma_memcpy_nd
88+
// CHECK-SAME: %arg0[%c0, %arg1]
89+
// CHECK-SAME: [%c16, %c256]
90+
// CHECK-SAME: [%c504, %c1]
91+
func.func @standalone_normal(%arg0: memref<*xf32>, %arg1: index) {
92+
%alloc = memref.alloc() : memref<16x256xf32, 1>
93+
%rc = memref.reinterpret_cast %arg0 to
94+
offset: [%arg1], sizes: [16, 256], strides: [504, 1]
95+
: memref<*xf32> to memref<16x256xf32, strided<[504, 1], offset: ?>>
96+
memref.copy %rc, %alloc
97+
: memref<16x256xf32, strided<[504, 1], offset: ?>>
98+
to memref<16x256xf32, 1>
99+
return
100+
}
101+
102+
// CHECK-LABEL: func.func @standalone_transposed_const_offset
103+
// Transposed layout with constant offset 256. Should produce offsets [c256, 0].
104+
// CHECK: air.dma_memcpy_nd
105+
// CHECK-SAME: %arg0[%c256, %c0]
106+
// CHECK-SAME: [%c64, %c16]
107+
// CHECK-SAME: [%c1, %c504]
108+
func.func @standalone_transposed_const_offset(%arg0: memref<*xf32>) {
109+
%alloc = memref.alloc() : memref<64x16xf32, 1>
110+
%rc = memref.reinterpret_cast %arg0 to
111+
offset: [256], sizes: [64, 16], strides: [1, 504]
112+
: memref<*xf32> to memref<64x16xf32, strided<[1, 504], offset: 256>>
113+
memref.copy %rc, %alloc
114+
: memref<64x16xf32, strided<[1, 504], offset: 256>>
115+
to memref<64x16xf32, 1>
116+
return
117+
}

test/xrt/55_matmul_padding_bf16_npu1/run.py

Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
# Non-tile-aligned f32 matmul with bf16 computation on NPU1.
77
# Host data is f32. A is stored in K×M layout (same as test 54).
88
# L3→L2 DMA transposes A from K×M to M×K using f32 strides (4-byte aligned).
9-
# A dedicated truncf herd converts f32→bf16 in L1 before the compute herd.
10-
# This 4-herd pattern (prologue, truncf, compute, epilogue) avoids the
11-
# problematic combined truncf+matmul pattern that fails on NPU1.
9+
# The compute herd DMAs f32 from L2→L1, truncates f32→bf16 in-register,
10+
# and runs block_matmul with bf16 inputs and f32 accumulation.
1211
# Output is f32.
1312
#
13+
# Uses a 3-herd pattern (prologue, compute, epilogue) — the combined
14+
# truncf+matmul pattern works correctly on NPU1 (aie2).
15+
#
1416
# Target: NPU1/Phoenix, aie2 architecture with native 4x8x4 bf16 matmul.
1517

1618
import argparse
@@ -42,7 +44,7 @@
4244
range_ = for_
4345

4446

45-
# Element-wise truncation: f32 → bf16
47+
# Element-wise truncation: f32 → bf16, applied in-register inside compute herd
4648
@linalg_structured_op()
4749
def truncf_op(
4850
A=TensorDef(linalg_lang.TV.T1, S.a, S.b, S.c, S.d, S.e, S.f),
@@ -80,19 +82,19 @@ def build_module(
8082
herd_m,
8183
herd_n,
8284
):
83-
"""Build matmul module with 4-herd pattern: prologue, truncf, compute, epilogue.
85+
"""Build matmul module with 3-herd pattern: prologue, compute, epilogue.
8486
8587
L3 inputs are f32 in K×M / K×N layout. L3→L2 DMA transposes A to M×K.
86-
A dedicated truncf herd converts f32→bf16 in L1.
87-
The compute herd reads bf16 from L1 and runs block_matmul.
88-
This avoids the problematic combined truncf+matmul herd pattern on NPU1."""
88+
The compute herd DMAs f32 from L2→L1, truncates f32→bf16 in-register,
89+
and runs block_matmul with bf16 inputs and f32 accumulation.
90+
The combined truncf+matmul pattern works on NPU1 (aie2)."""
8991
assert m % tile_m == 0
9092
assert k % tile_k_l2 == 0
9193
assert tile_k_l2 % tile_k_l1 == 0
9294
assert n % tile_n == 0
9395
assert (
9496
tile_k_l2 == tile_k_l1
95-
), "truncf herd approach requires tile_k_l2 == tile_k_l1"
97+
), "single-herd approach requires tile_k_l2 == tile_k_l1"
9698

9799
mmul_mkn = [4, 8, 4] # aie2 native bf16 matmul
98100

@@ -131,7 +133,7 @@ def build_module(
131133
mmul_mkn[2],
132134
]
133135

134-
# L1 buffers: f32 for DMA input, bf16 for matmul, f32 for output
136+
# L1 buffers: f32 for DMA input, bf16 for matmul input, f32 for accumulator
135137
l1MemrefTyA_f32 = MemRefType.get(
136138
shape=a_l1_size, element_type=xrt_dtype_f32, memory_space=l1_mem_space
137139
)
@@ -274,7 +276,7 @@ def prologue_herd(
274276
src_strides=[n_alloc * tile_k_l2, tile_n, n_alloc, 1],
275277
)
276278

277-
# Herd 2 (truncf): DMA f32 L2→L1, convert f32→bf16 in L1
279+
# Herd 2 (compute): DMA f32 L2→L1, truncf→matmul in one herd
278280
@herd(
279281
name="herd_0",
280282
sizes=[herd_m, herd_n],
@@ -288,7 +290,7 @@ def prologue_herd(
288290
l2_b,
289291
],
290292
)
291-
def truncf_herd(
293+
def compute_herd(
292294
_tx,
293295
_ty,
294296
_sx,
@@ -345,37 +347,9 @@ def truncf_herd(
345347
1,
346348
],
347349
)
348-
# Convert f32→bf16 in L1
350+
# Convert f32→bf16 in L1 and run matmul (combined)
349351
truncf_op(_l1_a_f32, outs=[_l1_a_bf16])
350352
truncf_op(_l1_b_f32, outs=[_l1_b_bf16])
351-
352-
# Herd 3 (compute): read bf16 from L1, block_matmul
353-
@herd(
354-
name="herd_0",
355-
sizes=[herd_m, herd_n],
356-
operands=[
357-
l1_a_f32,
358-
l1_b_f32,
359-
l1_a_bf16,
360-
l1_b_bf16,
361-
l1_c,
362-
l2_a,
363-
l2_b,
364-
],
365-
)
366-
def compute_herd(
367-
_tx,
368-
_ty,
369-
_sx,
370-
_sy,
371-
_af,
372-
_bf,
373-
_l1_a,
374-
_l1_b,
375-
_l1_c,
376-
_l2a,
377-
_l2b,
378-
):
379353
l1_c_sv = subview(
380354
_l1_c,
381355
offsets=[_tx, _ty, 0, 0, 0, 0],
@@ -389,11 +363,11 @@ def compute_herd(
389363
],
390364
strides=[1, 1, 1, 1, 1, 1],
391365
)
392-
block_matmul(_l1_a, _l1_b, outs=[l1_c_sv])
366+
block_matmul(_l1_a_bf16, _l1_b_bf16, outs=[l1_c_sv])
393367

394368
yield_([])
395369

396-
# Herd 4 (epilogue): write C from L1→L2
370+
# Herd 3 (epilogue): write C from L1→L2
397371
@herd(
398372
name="herd_0",
399373
sizes=[herd_m, herd_n],
@@ -541,9 +515,8 @@ def epilogue_herd(
541515
)
542516

543517
# Vectorization transform: tile truncf and block_matmul for vectorization.
544-
# 4 herds → split_handle produces 4 handles.
545-
# Truncf herd (herd2) has 2 truncf_op generics.
546-
# Compute herd (herd3) has 1 block_matmul generic.
518+
# 3 herds → split_handle produces 3 handles.
519+
# Compute herd has 2 truncf_op generics + 1 block_matmul generic.
547520
transform_ir_string = """
548521
module attributes {transform.with_named_sequence} {
549522
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -584,10 +557,10 @@ def epilogue_herd(
584557
transform.structured.tile_using_for %linalg_fills tile_sizes [0, 0, 1, 1]
585558
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
586559
587-
// Vectorize all herds (4 herds: prologue, truncf, compute, epilogue)
560+
// Vectorize all herds (3 herds: prologue, compute, epilogue)
588561
%herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
589562
%vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op
590-
%herd1, %herd2, %herd3, %herd4 = transform.split_handle %vectorized_herds : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
563+
%herd1, %herd2, %herd3 = transform.split_handle %vectorized_herds : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
591564
592565
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
593566
transform.apply_patterns to %func1 {
@@ -605,11 +578,12 @@ def epilogue_herd(
605578
// Re-vectorize after cleanup
606579
%herds_1 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
607580
%vectorized_herds_1 = transform.air.herd_vectorize %herds_1 : (!transform.any_op) -> !transform.any_op
608-
%h1, %h2, %h3, %h4 = transform.split_handle %vectorized_herds_1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
581+
%h1, %h2, %h3 = transform.split_handle %vectorized_herds_1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
609582
610583
// No vector_type_cast needed — accumulator is already f32.
611-
// The arith.extf on bf16 inputs before vector.contract will be
612-
// fused into aievec.matmul by convert-vector-to-aievec in aircc.
584+
// The arith.truncf on f32 inputs and arith.extf before
585+
// vector.contract will be fused into aievec.matmul by
586+
// convert-vector-to-aievec in aircc.
613587
614588
%func2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
615589
transform.apply_patterns to %func2 {

test/xrt/55_matmul_padding_bf16_npu1/run_npu1_peano.lit

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// REQUIRES: ryzen_ai_npu1, peano
55
//
66
// Non-tile-aligned f32 matmul with on-device bf16 truncation on NPU1.
7-
// Inputs are f32; a dedicated truncf herd converts f32→bf16 in L1.
7+
// Inputs are f32; truncf f32→bf16 and matmul run in the same compute herd.
88
// Uses native 4x8x4 bf16 matmul with f32 accumulation.
99
// Host-side padding pads inputs to tile-aligned sizes.
1010
//

0 commit comments

Comments
 (0)