Skip to content

Commit 4e466d8

Browse files
authored
Fix infinite loop in dotCanBeProperlyAsync (#9282)
The `checkOperand` traversal can run forever if either: (1) a block argument participates in a cycle containing only permitted instructions, or (2) a block argument is defined outside of `forOp`, in which case `transitiveOperand` never advances. To fix (1), we track the set of visited block arguments: if we visit the same block argument again, we are in a cycle originating from the iter arg's init value, which is defined outside the loop. To fix (2), we check for values defined outside the loop as we iterate; this way, whenever we are evaluating a block argument, we know it must be an iter arg of the loop.
1 parent 2e42027 commit 4e466d8

2 files changed

Lines changed: 76 additions & 13 deletions

File tree

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -438,21 +438,30 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
438438
// come from a MemDescIndex op. Only ConvertLayout and view ops are
439439
// allowed in between.
440440
Value transitiveOperand = operand;
441-
while (isa_and_nonnull<ttg::ConvertLayoutOp, ttg::MemDescTransOp,
442-
ttg::MemDescReshapeOp, ttg::MemDescSubsliceOp>(
443-
transitiveOperand.getDefiningOp()) ||
444-
isa<BlockArgument>(transitiveOperand)) {
445-
auto blockArg = dyn_cast<BlockArgument>(transitiveOperand);
446-
if (blockArg && blockArg.getOwner() == forOp.getBody()) {
447-
transitiveOperand =
448-
cast<scf::YieldOp>(blockArg.getOwner()->getTerminator())
449-
.getOperand(blockArg.getArgNumber() - 1);
450-
} else if (Operation *def = transitiveOperand.getDefiningOp()) {
451-
transitiveOperand = def->getOperand(0);
441+
DenseSet<BlockArgument> visitedBlockArgs;
442+
while (!forOp.isDefinedOutsideOfLoop(transitiveOperand)) {
443+
if (auto *definingOp = transitiveOperand.getDefiningOp()) {
444+
if (isa<ttg::ConvertLayoutOp, ttg::MemDescTransOp,
445+
ttg::MemDescReshapeOp, ttg::MemDescSubsliceOp>(definingOp)) {
446+
transitiveOperand = definingOp->getOperand(0);
447+
continue;
448+
}
449+
return isa<ttg::MemDescIndexOp>(definingOp);
452450
}
451+
auto blockArg = cast<BlockArgument>(transitiveOperand);
452+
// We know that the dotOp is a top level operation in the loop body, and
453+
// we have already checked that transitiveOperand is not defined outside
454+
// the loop, therefore the block arg must be an iter arg of this loop.
455+
assert(dotOp->getParentOp() == forOp);
456+
assert(blockArg.getOwner() == forOp.getBody());
457+
// If we have already visited this block arg, that means that it
458+
// participates in a cycle containing only permitted operations. The
459+
// initial value therefore originates outside the loop, making this valid.
460+
if (!visitedBlockArgs.insert(blockArg).second)
461+
return true;
462+
transitiveOperand = forOp.getTiedLoopYieldedValue(blockArg)->get();
453463
}
454-
return forOp.isDefinedOutsideOfLoop(transitiveOperand) ||
455-
transitiveOperand.getDefiningOp<ttg::MemDescIndexOp>();
464+
return true;
456465
};
457466

458467
// Rule 1: All shmem operands are multi-buffered.

test/TritonGPU/loop-pipeline-hopper.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,3 +1080,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
10801080
tt.return %2 : tensor<64x32xf32, #mma>
10811081
}
10821082
}
1083+
1084+
// -----
1085+
1086+
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
1087+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
1088+
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 16}>
1089+
#smem = #ttg.shared_memory
1090+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
1091+
// CHECK-LABEL: dot_outer_loop_arg
1092+
// CHECK: scf.for
1093+
// CHECK-NEXT: scf.for
1094+
// CHECK-NEXT: ttng.warp_group_dot
1095+
// CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32}
1096+
// CHECK-NEXT: scf.yield
1097+
// CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
1098+
tt.func public @dot_outer_loop_arg(%arg0: i32, %arg2: !ttg.memdesc<64x32xbf16, #shared, #smem, mutable>, %arg3: !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>) -> tensor<64x32xf32, #mma> {
1099+
%c0_i32 = arith.constant 0 : i32
1100+
%c32_i32 = arith.constant 32 : i32
1101+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
1102+
%outer:2 = scf.for %arg4 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg5 = %arg3, %arg8 = %cst_0) -> (!ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>, tensor<64x32xf32, #mma>) : i32 {
1103+
%0 = scf.for %arg6 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg7 = %arg8) -> (tensor<64x32xf32, #mma>) : i32 {
1104+
%1 = ttng.warp_group_dot %arg2, %arg5, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<64x32xbf16, #shared, #smem, mutable> * !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable> -> tensor<64x32xf32, #mma>
1105+
scf.yield %1 : tensor<64x32xf32, #mma>
1106+
}
1107+
scf.yield %arg5, %0 : !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>, tensor<64x32xf32, #mma>
1108+
}
1109+
tt.return %outer#1 : tensor<64x32xf32, #mma>
1110+
}
1111+
}
1112+
1113+
// -----
1114+
1115+
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
1116+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
1117+
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 16}>
1118+
#smem = #ttg.shared_memory
1119+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
1120+
// CHECK-LABEL: loop_arg_cycle
1121+
// CHECK: scf.for
1122+
// CHECK-NEXT: ttng.warp_group_dot
1123+
// CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32}
1124+
// CHECK-NEXT: scf.yield
1125+
// CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
1126+
tt.func public @loop_arg_cycle(%arg0: i32, %arg2: !ttg.memdesc<64x32xbf16, #shared, #smem, mutable>, %arg3: !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>) -> tensor<64x32xf32, #mma> {
1127+
%c0_i32 = arith.constant 0 : i32
1128+
%c32_i32 = arith.constant 32 : i32
1129+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
1130+
%0:2 = scf.for %arg4 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg5 = %arg3, %arg7 = %cst_0) -> (!ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>, tensor<64x32xf32, #mma>) : i32 {
1131+
%1 = ttng.warp_group_dot %arg2, %arg5, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<64x32xbf16, #shared, #smem, mutable> * !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable> -> tensor<64x32xf32, #mma>
1132+
scf.yield %arg5, %1 : !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>, tensor<64x32xf32, #mma>
1133+
}
1134+
tt.return %0#1 : tensor<64x32xf32, #mma>
1135+
}
1136+
}

0 commit comments

Comments
 (0)