Skip to content

Commit 4eba6c2

Browse files
erwei-xilinxclaude
andcommitted
Document why num_threads is hardcoded to 4 (npu1 columns)
Address review feedback: explain in-code that the thread count is the npu1 4-column count, that AIE2P/npu2 scripts sharing this sequence are capped at 4 of 8 columns (correct but under-utilizing), and that a target-aware count is a follow-up. 4 is kept as the npu1-hardware-validated value. Co-Authored-By: Claude Sonnet 4 (1M context) <noreply@anthropic.com>
1 parent b8efaf1 commit 4eba6c2

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

amd_triton_npu/backend/transform_library/elementwise.mlir

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,14 @@ transform.named_sequence @fuse_elementwise_and_canonicalize(
2626
// With tile_sizes the width was ceildiv(block, tile): a single trip when the
2727
// block fits one tile (the forall is then folded away, leaving no herd) and
2828
// wider than the target's column count for large blocks (placement fails). A
29-
// fixed thread count avoids both. NOTE: the count below is sized for the npu1
30-
// 4-column array; targets with more columns (AIE2P) may want a larger value.
29+
// fixed thread count avoids both.
30+
//
31+
// The count is intentionally hardcoded to 4 for the npu1 4-column array. This
32+
// sequence is also included by AIE2P (npu2) elementwise scripts, where 4 caps
33+
// the herd at 4 of the 8 available columns -- correct, but it under-utilizes
34+
// the array for large blocks. Making the count target-aware (a per-target
35+
// sequence, or a driver-injected parameter) is left as a follow-up; 4 is kept
36+
// for now because it is the value validated on npu1 hardware.
3137
transform.named_sequence @flatten_tile_forall(
3238
%module: !transform.any_op {transform.readonly}) {
3339
%op = transform.structured.match ops{["linalg.generic"]} in %module
@@ -41,6 +47,7 @@ transform.named_sequence @flatten_tile_forall(
4147
%op_1 = transform.structured.match ops{["linalg.generic"]} in %module
4248
: (!transform.any_op) -> !transform.any_op
4349
%tiled_op_1, %forall_op_1 =
50+
// 4 = npu1 column count (hardcoded; see note above for AIE2P/npu2).
4451
transform.structured.tile_using_forall %op_1 num_threads [4]
4552
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
4653
transform.yield

0 commit comments

Comments
 (0)