Skip to content

Commit fc10388

Browse files
committed
Add AIE2P (npu2) elementwise herd width: 8 threads for 8-column array
The shared @flatten_tile_forall sequence tiles into num_threads [4] for npu1's 4-column array. On npu2 (AIE2P / Strix) the array is 8 columns wide, so 4 threads leave half the array idle. Add @flatten_tile_forall_aie2p, an 8-thread variant, and point every AIE2P elementwise script (vec-add, relu, silu, gelu, sigmoid, swiglu, axpy, leaky_relu) at it. The npu1 sequence and the aie2 scripts are unchanged. NOTE: correct multi-program (grid > 1) execution on npu2 depends on mlir-air PR #1696 (Xilinx/mlir-air), which fixes air-split-l2-memref dropping the per-iteration air.launch base offset when it splits the L2 buffer across the 8 columns. The 8-way split added here is what exposes that bug. Without an mlir-air build containing the fix, grid > 1 elementwise kernels move only the first program's data on npu2; grid == 1 (one large block split across the herd) is correct regardless. See the dependency note on @flatten_tile_forall_aie2p in elementwise.mlir.
1 parent 4eba6c2 commit fc10388

9 files changed

Lines changed: 41 additions & 15 deletions

File tree

amd_triton_npu/backend/transform_library/elementwise.mlir

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@ transform.named_sequence @fuse_elementwise_and_canonicalize(
2828
// wider than the target's column count for large blocks (placement fails). A
2929
// fixed thread count avoids both.
3030
//
31-
// The count is intentionally hardcoded to 4 for the npu1 4-column array. This
32-
// sequence is also included by AIE2P (npu2) elementwise scripts, where 4 caps
33-
// the herd at 4 of the 8 available columns -- correct, but it under-utilizes
34-
// the array for large blocks. Making the count target-aware (a per-target
35-
// sequence, or a driver-injected parameter) is left as a follow-up; 4 is kept
36-
// for now because it is the value validated on npu1 hardware.
31+
// The count is hardcoded to 4 for the npu1 4-column array. AIE2P (npu2)
32+
// elementwise scripts include @flatten_tile_forall_aie2p below instead, which
33+
// tiles into 8 threads to fill the 8-column Strix array.
3734
transform.named_sequence @flatten_tile_forall(
3835
%module: !transform.any_op {transform.readonly}) {
3936
%op = transform.structured.match ops{["linalg.generic"]} in %module
@@ -47,12 +44,41 @@ transform.named_sequence @flatten_tile_forall(
4744
%op_1 = transform.structured.match ops{["linalg.generic"]} in %module
4845
: (!transform.any_op) -> !transform.any_op
4946
%tiled_op_1, %forall_op_1 =
50-
// 4 = npu1 column count (hardcoded; see note above for AIE2P/npu2).
47+
// 4 = npu1 column count (hardcoded; AIE2P uses the _aie2p variant below).
5148
transform.structured.tile_using_forall %op_1 num_threads [4]
5249
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
5350
transform.yield
5451
}
5552

53+
// AIE2P (npu2) variant of @flatten_tile_forall: 8 threads for the 8-column
54+
// Strix array instead of npu1's 4. Identical otherwise.
55+
//
56+
// DEPENDS ON mlir-air PR #1696 (Xilinx/mlir-air): "Preserve launch base offset
57+
// when splitting L2 memref". The 8-way split this triggers exposed a bug in
58+
// air-split-l2-memref where the per-iteration air.launch base offset was
59+
// dropped, so a multi-program (grid > 1) elementwise kernel silently moved
60+
// only the first program's data on npu2. Without an mlir-air build that
61+
// contains that fix, grid > 1 produces wrong results here; grid == 1 (a single
62+
// large block split across the herd) is correct regardless.
63+
transform.named_sequence @flatten_tile_forall_aie2p(
64+
%module: !transform.any_op {transform.readonly}) {
65+
%op = transform.structured.match ops{["linalg.generic"]} in %module
66+
: (!transform.any_op) -> !transform.any_op
67+
%op_flattened = transform.structured.flatten_elementwise %op
68+
: (!transform.any_op) -> !transform.any_op
69+
%op_res_shared, %new_op = transform.structured.bufferize_to_allocation
70+
%op_flattened
71+
{memory_space = 1, bufferize_destination_only, emit_dealloc}
72+
: !transform.any_op
73+
%op_1 = transform.structured.match ops{["linalg.generic"]} in %module
74+
: (!transform.any_op) -> !transform.any_op
75+
%tiled_op_1, %forall_op_1 =
76+
// 8 = npu2/AIE2P column count (Strix). See dependency note above.
77+
transform.structured.tile_using_forall %op_1 num_threads [8]
78+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
79+
transform.yield
80+
}
81+
5682
// Unary variant: 1 input + 1 output = 2 operands (relu, sigmoid, silu, gelu).
5783
transform.named_sequence @pad_and_promote_unary_bf16(
5884
%module: !transform.any_op {transform.readonly}) {

examples/axpy/transform_aie2p.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module attributes {transform.with_named_sequence} {
1616
(%arg1) : (!transform.any_op) -> ()
1717
transform.include @fuse_elementwise_and_canonicalize failures(propagate)
1818
(%arg1) : (!transform.any_op) -> ()
19-
transform.include @flatten_tile_forall failures(propagate)
19+
transform.include @flatten_tile_forall_aie2p failures(propagate)
2020
(%arg1) : (!transform.any_op) -> ()
2121
transform.include @canonicalize_with_cse failures(propagate)
2222
(%arg1) : (!transform.any_op) -> ()

examples/gelu/transform_aie2p.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ module attributes {transform.with_named_sequence} {
1818
(%arg1) : (!transform.any_op) -> ()
1919
transform.include @fuse_elementwise_and_canonicalize failures(propagate)
2020
(%arg1) : (!transform.any_op) -> ()
21-
transform.include @flatten_tile_forall failures(propagate)
21+
transform.include @flatten_tile_forall_aie2p failures(propagate)
2222
(%arg1) : (!transform.any_op) -> ()
2323
transform.include @canonicalize_with_cse failures(propagate)
2424
(%arg1) : (!transform.any_op) -> ()

examples/leaky_relu/transform_aie2p.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module attributes {transform.with_named_sequence} {
1616
(%arg1) : (!transform.any_op) -> ()
1717
transform.include @fuse_elementwise_and_canonicalize failures(propagate)
1818
(%arg1) : (!transform.any_op) -> ()
19-
transform.include @flatten_tile_forall failures(propagate)
19+
transform.include @flatten_tile_forall_aie2p failures(propagate)
2020
(%arg1) : (!transform.any_op) -> ()
2121
transform.include @canonicalize_with_cse failures(propagate)
2222
(%arg1) : (!transform.any_op) -> ()

examples/relu/transform_aie2p.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ module attributes {transform.with_named_sequence} {
1717
(%arg1) : (!transform.any_op) -> ()
1818
transform.include @fuse_elementwise_and_canonicalize failures(propagate)
1919
(%arg1) : (!transform.any_op) -> ()
20-
transform.include @flatten_tile_forall failures(propagate)
20+
transform.include @flatten_tile_forall_aie2p failures(propagate)
2121
(%arg1) : (!transform.any_op) -> ()
2222
transform.include @canonicalize_with_cse failures(propagate)
2323
(%arg1) : (!transform.any_op) -> ()

examples/sigmoid/transform_aie2p.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ module attributes {transform.with_named_sequence} {
2525
(%arg1) : (!transform.any_op) -> ()
2626

2727
// Phase 3: Flatten + tile forall [256]
28-
transform.include @flatten_tile_forall failures(propagate)
28+
transform.include @flatten_tile_forall_aie2p failures(propagate)
2929
(%arg1) : (!transform.any_op) -> ()
3030

3131
// Phase 4: Canonicalization

examples/silu/transform_aie2p.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ module attributes {transform.with_named_sequence} {
1717
(%arg1) : (!transform.any_op) -> ()
1818
transform.include @fuse_elementwise_and_canonicalize failures(propagate)
1919
(%arg1) : (!transform.any_op) -> ()
20-
transform.include @flatten_tile_forall failures(propagate)
20+
transform.include @flatten_tile_forall_aie2p failures(propagate)
2121
(%arg1) : (!transform.any_op) -> ()
2222
transform.include @canonicalize_with_cse failures(propagate)
2323
(%arg1) : (!transform.any_op) -> ()

examples/swiglu/transform_aie2p.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module attributes {transform.with_named_sequence} {
1616
(%arg1) : (!transform.any_op) -> ()
1717
transform.include @fuse_elementwise_and_canonicalize failures(propagate)
1818
(%arg1) : (!transform.any_op) -> ()
19-
transform.include @flatten_tile_forall failures(propagate)
19+
transform.include @flatten_tile_forall_aie2p failures(propagate)
2020
(%arg1) : (!transform.any_op) -> ()
2121
transform.include @canonicalize_with_cse failures(propagate)
2222
(%arg1) : (!transform.any_op) -> ()

examples/vec-add/transform_aie2p.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ module attributes {transform.with_named_sequence} {
1414
%arg1: !transform.any_op {transform.readonly}) {
1515

1616
// No Phase 1/2 for vec-add (no elementwise fusion needed)
17-
transform.include @flatten_tile_forall failures(propagate)
17+
transform.include @flatten_tile_forall_aie2p failures(propagate)
1818
(%arg1) : (!transform.any_op) -> ()
1919
transform.include @canonicalize_with_cse failures(propagate)
2020
(%arg1) : (!transform.any_op) -> ()

0 commit comments

Comments
 (0)