Skip to content

Commit 7a83ed7

Browse files
authored
[NFC] Correct async_copy_global_to_local operand type inference conditions (#6448)
The `async_copy_global_to_local` operation currently exhibits incorrect type inference behavior when the optional `mask` or `other` operands are provided. Adjust the operand count checks to:
- `<= 2` for `mask`, since it is the 3rd operand
- `<= 3` for `other`, since it is the 4th operand
---------
Co-authored-by: junjian.zhan <junjian.zhan@iluvatar.com>
1 parent f2e9247 commit 7a83ed7

2 files changed

Lines changed: 38 additions & 6 deletions

File tree

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -77,12 +77,10 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
7777

7878
def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
7979
AttrSizedOperandSegments,
80-
TypesMatchWith<"infer mask type from src type",
81-
"src", "mask", "getI1SameShape($_self)",
82-
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
83-
TypesMatchWith<"infer other type from src type",
84-
"src", "other", "getPointeeType($_self)",
85-
"($_op.getOperands().size() <= 4) || std::equal_to<>()">
80+
OptionalTypesMatchWith<"infer mask type from src type",
81+
"src", "mask", "getI1SameShape($_self)">,
82+
OptionalTypesMatchWith<"infer other type from src type",
83+
"src", "other", "getPointeeType($_self)">,
8684
]> {
8785
let summary = "copy data from global memory to local memory asynchronously";
8886

test/TritonGPU/invalid.mlir

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -345,3 +345,37 @@ tt.func @partition_no_terminator() {
345345
} : () -> ()
346346
tt.return
347347
}
348+
349+
// -----
350+
351+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
352+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
353+
#smem = #ttg.shared_memory
354+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
355+
tt.func @async_copy_invalid_mask_type(%input: tensor<64x64x!tt.ptr<f16>, #blocked>,
356+
%view: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>,
357+
%invalid_mask: tensor<64x64xi32, #blocked> // expected-note {{prior use here}}
358+
) {
359+
// expected-error @+1 {{expects different type than prior uses}}
360+
%token = ttg.async_copy_global_to_local %input, %view mask %invalid_mask
361+
: tensor<64x64x!tt.ptr<f16>, #blocked> -> <64x64xf16, #shared, #smem, mutable>
362+
tt.return
363+
}
364+
}
365+
366+
// -----
367+
368+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
369+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
370+
#smem = #ttg.shared_memory
371+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
372+
tt.func @async_copy_invalid_other_type(%input: tensor<64x64x!tt.ptr<f16>, #blocked>,
373+
%view: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>,
374+
%mask: tensor<64x64xi1, #blocked>,
375+
%invalid_other: tensor<64x64xf32, #blocked> // expected-note {{prior use here}}
376+
) {
377+
// expected-error @+1 {{expects different type than prior uses}}
378+
%token = ttg.async_copy_global_to_local %input, %view mask %mask other %invalid_other : tensor<64x64x!tt.ptr<f16>, #blocked> -> <64x64xf16, #shared, #smem, mutable>
379+
tt.return
380+
}
381+
}

0 commit comments

Comments
 (0)