
Commit f0530d2

[TritonNVIDIAGPU] Add dependency tokens to TMEM ops (#6520)
The Triton middle-end has perfect dependency and modref information about TMEM (and shared memory), because that memory is introduced by the middle-end itself when it expands chains of SSA ops. E.g. `HoistTMEMAlloc` is essentially a form of reg2mem for MMA accumulators. Despite this, the dependency and alias analyses needed by `HoistTMEMAlloc`, warp specialization, and the pipeliner rely on ad-hoc checks that are not always correct and that are becoming increasingly complex. Instead of building stronger memory analysis, we can simply stop discarding the information the compiler already has.

This PR adds tokens to all the ops that touch TMEM (except `TMEMCopyOp`, since it is not used in the middle-end). The tokens act as a form of MemorySSA (a memory-variable lattice encoded in the IR), and the middle-end leverages them to check aliasing, modref, etc. instead of scanning the IR. Consequently, the transformations are more robust and easier to maintain, at the cost of some necessary extra bookkeeping. This will greatly simplify the dependence analysis needed by more complex warp specialization, and help with composing warp specialization with the pipeliner (cc @htyu @manman-ren).

A bug in this PR (e.g. failing to pipeline or warp specialize) would show up as a large performance cliff, so I sanity-checked that it does not break pipelining.
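The token chain can be pictured with a schematic IR sketch (hypothetical, simplified syntax for illustration only; the exact printed form is defined by the parser/printer in this PR). Each op touching a TMEM buffer consumes the token produced by the previous access and yields a new one, so passes can answer ordering and modref queries by walking SSA use-def chains:

```mlir
// Hypothetical, simplified syntax for illustration only.
%acc, %t0 = ttng.tmem_alloc %init        // alloc produces the initial write token
%t1 = ttng.tc_gen5_mma %a, %b, %acc[%t0] // MMA writes the accumulator: consumes %t0, yields %t1
%res, %t2 = ttng.tmem_load %acc[%t1]     // load is ordered after the MMA write via %t1
```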
### Performance numbers after

```
├─  703.378   976.992  matmul_kernel [M=8192, N=8192, K=512]
├─  936.461   733.821  matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─  938.393   732.310  matmul_kernel_descriptor_persistent_ws [M=8192, N=8192, K=512]
├─  856.351   802.468  matmul_kernel_persistent [M=8192, N=8192, K=512]
├─  785.072   875.327  matmul_kernel_tma [M=8192, N=8192, K=512]
├─ 1024.165   670.981  matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
├─ 1125.056   610.810  matmul_kernel_tma_persistent_ws [M=8192, N=8192, K=512]
├─  800.940   857.986  matmul_kernel_tma_ws [M=8192, N=8192, K=512]
```

```
fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     183.032906    176.540661
1   2048.0     384.363999    417.633483
2   4096.0     471.816004    511.814693
3   8192.0     519.752669    566.761880
4  16384.0     545.707761    595.042579
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     364.631059    364.685641
1   2048.0     492.108137    536.102664
2   4096.0     532.795804    580.166599
3   8192.0     550.670842    599.591255
4  16384.0     559.480705    608.551411
fused-attention-batch4-head32-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     144.731066    152.721176
1   2048.0     234.101200    234.195236
2   4096.0     293.602665    293.519568
3   8192.0     331.644550    331.388321
4  16384.0     355.252999    354.861517
```

```
Problem Shape = 8192x8192x512
└─ 974.209   705.458  block_scaled_matmul_kernel_nvfp4 [M=8192, N=8192, K=512]
```

### Performance numbers before

```
├─  708.163   970.391  matmul_kernel [M=8192, N=8192, K=512]
├─  935.792   734.346  matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─  922.666   744.793  matmul_kernel_descriptor_persistent_ws [M=8192, N=8192, K=512]
├─  856.643   802.195  matmul_kernel_persistent [M=8192, N=8192, K=512]
├─  792.424   867.206  matmul_kernel_tma [M=8192, N=8192, K=512]
├─ 1020.997   673.063  matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
├─ 1134.083   605.948  matmul_kernel_tma_persistent_ws [M=8192, N=8192, K=512]
├─  799.650   859.369  matmul_kernel_tma_ws [M=8192, N=8192, K=512]
```

```
fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     181.507652    183.077756
1   2048.0     384.836411    416.908797
2   4096.0     471.260742    512.709282
3   8192.0     519.896730    566.172554
4  16384.0     545.181917    595.246382
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     368.266771    373.516950
1   2048.0     492.137719    535.968650
2   4096.0     533.092876    580.134559
3   8192.0     550.571575    599.455669
4  16384.0     559.555689    608.442981
fused-attention-batch4-head32-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     151.081525    155.745186
1   2048.0     234.359406    234.108984
2   4096.0     293.584945    293.689437
3   8192.0     331.633380    331.669234
4  16384.0     355.077635    354.963313
```

```
Problem Shape = 8192x8192x512
└─ 972.794   706.484  block_scaled_matmul_kernel_nvfp4 [M=8192, N=8192, K=512]
```
1 parent 3932686 commit f0530d2

31 files changed

Lines changed: 954 additions & 745 deletions

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 0 additions & 5 deletions
```diff
@@ -17,11 +17,6 @@ namespace triton::nvidia_gpu {
 // MMA Pipeline Analysis
 //===----------------------------------------------------------------------===//
 
-// Returns the TMEMAllocOp and TMEMLoadOp that are used to allocate and load the
-// accumulator for the given MMA operation. The TMEMAllocOp and TMEMLoadOp must
-// be in the same region as the MMA operation.
-std::optional<std::pair<TMEMAllocOp, TMEMLoadOp>>
-getTMemAllocAndLoad(MMAv5OpInterface mmaOp);
 // Given an MMAv5 operation in a loop, determine if its accumulator can be
 // multibuffered.
 bool isAccMultibufferingPossible(MMAv5OpInterface mma, scf::ForOp forOp);
```

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 9 additions & 0 deletions
```diff
@@ -48,6 +48,15 @@ void hoistOpsBefore(Operation *refOp,
 void hoistOpsBefore(Block *block, Block::iterator it,
                     const llvm::SetVector<Operation *> &toHoist);
 
+//===----------------------------------------------------------------------===//
+// Sinking Utilities
+//===----------------------------------------------------------------------===//
+
+// Sink a value redefinition into a block, provided that the block is dominated
+// by `in` and postdominated by `out`.
+Value sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out,
+                            Block *block);
+
 //===----------------------------------------------------------------------===//
 // Loop Pipelining Utilities
 //===----------------------------------------------------------------------===//
```

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 4 deletions
```diff
@@ -243,10 +243,6 @@ SetVector<Value> getNestedOperands(Operation *op);
 // Erase the given loop carried values from the loop, where `loop` is replaced
 // with a new loop.
 void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
-
-// Return true if two value sets may refer to the same allocation.
-bool mayAliasAllocations(const DenseSet<Value> &lhs,
-                         const DenseSet<Value> &rhs);
 } // namespace mlir
 
 namespace mlir::triton {
```

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 9 additions & 0 deletions
```diff
@@ -40,6 +40,15 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
                     "void",
                     "setPredicate",
                     (ins "::mlir::Value":$pred)>,
+    InterfaceMethod<"Get the memory dependencies of the accumulator.",
+                    "::mlir::Value",
+                    "getAccDep">,
+    InterfaceMethod<"Get the mutable memory dependencies of the accumulator.",
+                    "::mlir::MutableOperandRange",
+                    "getAccDepMutable">,
+    InterfaceMethod<"Get the produced write dependency of the accumulator.",
+                    "::mlir::Value",
+                    "getToken">,
  ];
 }
 #endif // TRITON_NVIDIAGPU_OP_INTERFACES
```

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 68 additions & 18 deletions
```diff
@@ -417,7 +417,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<DotOpInterface>,
     DeclareOpInterfaceMethods<MMAv5OpInterface>,
-    SameVariadicOperandSize
+    AttrSizedOperandSegments
  ]> {
   let summary = "block level op mapping to tensorcore gen5 mma";
@@ -427,29 +427,36 @@
     If there is a barrier the result will be safe to read after a barrier wait.
     If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
     and syncronize both CTAs if the op is synchronous.
+
+    This operation takes and produces an optional token to indicate TMEM read
+    and write on its accumulator operand. When the tokens are present, they can
+    be used to check aliasing and modref on the accumulator memory.
   }];
 
   let arguments = (ins
     TTG_MemDescType:$a,
     TTG_MemDescType:$b,
     TTG_MemDescType:$d,
+    Optional<TTG_AsyncToken>:$acc_dep,
     I1:$useD,
     I1:$pred,
     Variadic<TTG_MemDescType>:$barriers,
     Variadic<I1>:$barrier_preds,
     OptionalAttr<UnitAttr>:$two_ctas
   );
+  let results = (outs Optional<TTG_AsyncToken>:$token);
 
   let builders = [
-    OpBuilder<(ins
-      "Value":$a, "Value":$b, "Value":$d, "Value":$useD, "Value":$pred,
-      CArg<"bool", "false">:$two_ctas, CArg<"ValueRange", "{}">:$barriers,
+    OpBuilder<(ins "Type":$token,
+      "Value":$a, "Value":$b, "Value":$d, "Value":$acc_dep, "Value":$useD,
+      "Value":$pred, CArg<"bool", "false">:$two_ctas,
+      CArg<"ValueRange", "{}">:$barriers,
       CArg<"ValueRange", "{}">:$barrier_preds)>
   ];
 
   let assemblyFormat = [{
-    $a`,` $b`,` $d`,` $useD`,` $pred
-    `` custom<BarriersAndPreds>($barriers, $barrier_preds)
+    $a `,` $b `,` $d `` custom<Token>($acc_dep, type($token)) `,` $useD`,`
+    $pred `` custom<BarriersAndPreds>($barriers, $barrier_preds)
     attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
     qualified(type($d)) (`,` qualified(type($barriers))^)?
   }];
@@ -459,20 +466,25 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>,
     DeclareOpInterfaceMethods<MMAv5OpInterface>,
-    SameVariadicOperandSize
+    AttrSizedOperandSegments
  ]> {
   let summary = "block level op mapping to tensorcore gen5 mma";
 
   let description = [{
     $d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale))
     If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
     If there is a barrier the result will be safe to read after a barrier wait.
+
+    This operation takes and produces an optional token to indicate TMEM read
+    and write on its accumulator operand. When the tokens are present, they can
+    be used to check aliasing and modref on the accumulator memory.
   }];
 
   let arguments = (ins
     TTG_MemDescType:$a,
     TTG_MemDescType:$b,
     TTG_MemDescType:$d,
+    Optional<TTG_AsyncToken>:$acc_dep,
     TTG_MemDescType:$a_scale,
     TTG_MemDescType:$b_scale,
     TT_ScaleDotElemTypeAttr:$a_type,
@@ -482,6 +494,8 @@
     Variadic<TTG_MemDescType>:$barriers,
     Variadic<I1>:$barrier_preds
   );
+  let results = (outs Optional<TTG_AsyncToken>:$token);
+
   let extraClassDeclaration = [{
     int64_t getBlockM();
     int64_t getBlockN();
@@ -491,19 +505,19 @@
   let builders = [
     // Namespaces need to be prefixed so ODS prefers our
     // custom builder signature over the default-generated one.
-    OpBuilder<(ins
+    OpBuilder<(ins "::mlir::Type":$token,
      "::mlir::Value":$a, "::mlir::Value":$b, "::mlir::Value":$d,
-     "::mlir::Value":$a_scale, "::mlir::Value":$b_scale,
-     "::mlir::triton::ScaleDotElemType":$a_type,
+     "::mlir::Value":$acc_dep, "::mlir::Value":$a_scale,
+     "::mlir::Value":$b_scale, "::mlir::triton::ScaleDotElemType":$a_type,
      "::mlir::triton::ScaleDotElemType":$b_type,
      "::mlir::Value":$useD, "::mlir::Value":$pred,
      CArg<"::mlir::ValueRange", "{}">:$barriers,
      CArg<"::mlir::ValueRange", "{}">:$barrier_preds)>
   ];
 
   let assemblyFormat = [{
-    $a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred
-    `lhs` `=` $a_type `rhs` `=` $b_type
+    $a `,` $b `,` $d `` custom<Token>($acc_dep, type($token)) `,` $a_scale `,`
+    $b_scale `,` $useD `,` $pred `lhs` `=` $a_type `rhs` `=` $b_type
     `` custom<BarriersAndPreds>($barriers, $barrier_preds)
     attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
     qualified(type($d)) `,` qualified(type($a_scale)) `,`
@@ -517,27 +531,55 @@ def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load"> {
   let description = [{
     This is similar to ttg.local_load except the result layout is restricted to only few possibility.
     Therefore we cannot combine this op with any convert layout like local_load.
+
+    This operation takes and produces an optional token to indicate TMEM read
+    on its source operand. When the tokens are present, they can
+    be used to check aliasing and modref on the TMEM buffer.
+  }];
+  let arguments = (ins
+    Arg<TTG_MemDescType, "", [MemRead<TensorMemory>]>:$src,
+    Optional<TTG_AsyncToken>:$dep
+  );
+  let results = (outs
+    TT_Tensor:$result,
+    Optional<TTG_AsyncToken>:$token
+  );
+
+  let assemblyFormat = [{
+    $src `` custom<Token>($dep, type($token))
+    attr-dict `:` qualified(type($src)) `->` type($result)
   }];
-  let arguments = (ins Arg<TTG_MemDescType, "", [MemRead<TensorMemory>]>:$src);
 
-  let assemblyFormat = [{$src attr-dict `:` qualified(type($src)) `->` type($result)}];
-  let results = (outs TT_Tensor:$result);
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    RankedTensorType getType() { return getResult().getType(); }
+    operator TypedValue<RankedTensorType>() { return getResult(); }
+  }];
 }
 
 def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store"> {
   let summary = "Store a distributed tensor into a buffer in tensor memory";
 
   let description = [{
-    This is similar to ttg.local_local except the source layout is restricted to only few possibility.
+    This is similar to ttg.local_store except the source layout is restricted to only few possibility.
+
+    This operation takes and produces an optional token to indicate TMEM write
+    on its source operand. When the tokens are present, they can
+    be used to check aliasing and modref on the TMEM buffer.
   }];
   let arguments = (ins
     Arg<TTG_MemDescType, "", [MemWrite<TensorMemory>]>:$dst,
+    Optional<TTG_AsyncToken>:$dep,
    TT_Tensor:$src,
    I1:$pred
   );
+  let results = (outs Optional<TTG_AsyncToken>:$token);
 
-  let assemblyFormat = [{$src `,` $dst `,` $pred attr-dict `:` type($src) `->` qualified(type($dst))}];
+  let assemblyFormat = [{
+    $src `,` $dst `` custom<Token>($dep, type($token)) `,` $pred
+    attr-dict `:` type($src) `->` qualified(type($dst))
+  }];
   let hasVerifier = 1;
 }
 
@@ -551,13 +593,21 @@ def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [DeclareOpInterfaceMethods<MemoryEf
     Explicitly deallocating a buffer is optional; see local_dealloc.
   }];
   let arguments = (ins Optional<TT_Tensor>:$src);
+  let results = (outs
+    TTG_MemDescType:$result,
+    Optional<TTG_AsyncToken>:$token
+  );
 
   let assemblyFormat = [{
     ($src^)? attr-dict `:` functional-type(operands, results)
   }];
 
-  let results = (outs TTG_MemDescType:$result);
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    triton::gpu::MemDescType getType() { return getResult().getType(); }
+    operator TypedValue<triton::gpu::MemDescType>() { return getResult(); }
+  }];
 }
 
 def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> {
```
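With the optional token operands and results above, the TMEM ops print roughly as follows (an illustrative sketch only: types are abbreviated with `...`, and the exact `custom<Token>` syntax is defined by the parser/printer in this PR):

```mlir
// Illustrative sketch: abbreviated types, schematic token syntax.
%buf, %t0 = ttng.tmem_alloc : () -> (!ttg.memdesc<...>, !ttg.async.token)
%t1 = ttng.tmem_store %val, %buf[%t0], %pred : tensor<...> -> !ttg.memdesc<...>
%out, %t2 = ttng.tmem_load %buf[%t1] : !ttg.memdesc<...> -> tensor<...>
```

Because every token is optional, IR without tokens (e.g. after the removal pass) still parses with the same format.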

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -58,6 +58,8 @@ std::unique_ptr<Pass> createTritonNvidiaGPUMMALoweringPass();
 
 std::unique_ptr<Pass> createTritonNvidiaGPUPromoteLHSToTMemPass();
 
+std::unique_ptr<Pass> createTritonNvidiaGPURemoveTMEMTokensPass();
+
 std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();
 
 std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemSubtilingPass();
```

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
```diff
@@ -142,4 +142,13 @@ def TritonNvidiaGPUOptimizeTMemSubtilingPass : Pass<"triton-nvidia-optimize-tmem
                              "mlir::triton::TritonDialect"];
 }
 
+def TritonNvidiaGPURemoveTMEMTokensPass : Pass<"triton-nvidia-gpu-remove-tmem-tokens", "mlir::ModuleOp"> {
+  let summary = "remove TMEM tokens";
+
+  let description = [{
+    The `triton-nvidia-gpu-remove-tmem-tokens` pass removes TMEM memory
+    dependency tokens from the IR, after they are no longer needed.
+  }];
+}
+
 #endif
```

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 22 additions & 14 deletions
```diff
@@ -544,15 +544,17 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern<DotOp> {
         newDistributedEncoding);
     Value cvtAcc =
         rewriter.create<ConvertLayoutOp>(loc, newAccType, dotOp.getOperand(2));
+    auto tokType = rewriter.getType<AsyncTokenType>();
     auto acc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
-        loc, accMemDescType, cvtAcc);
+        loc, accMemDescType, tokType, cvtAcc);
     auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
     auto mma = rewriter.create<triton::nvidia_gpu::TCGen5MMAOp>(
-        loc, a, b, acc, /*useD=*/vTrue, /*pred=*/vTrue);
+        loc, tokType, a, b, acc, acc.getToken(), /*useD=*/vTrue,
+        /*pred=*/vTrue);
     mma.setTwoCtas(useTwoCTAs);
 
-    auto ld =
-        rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(loc, newAccType, acc);
+    auto ld = rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(
+        loc, newAccType, tokType, acc, /*dep=*/mma.getToken());
     rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType, ld);
     return success();
   }
@@ -697,8 +699,9 @@ class ScaledBlockedToMMAv5
         newDistributedEncoding);
     Value cvtAcc =
         rewriter.create<ConvertLayoutOp>(loc, newAccType, dotOp.getOperand(2));
+    auto tokType = rewriter.getType<AsyncTokenType>();
     auto acc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
-        loc, accMemDescType, cvtAcc);
+        loc, accMemDescType, tokType, cvtAcc);
 
     RankedTensorType oldScaleAType = dotOp.getAScale().getType();
     RankedTensorType oldScaleBType = dotOp.getBScale().getType();
@@ -728,17 +731,22 @@ class ScaledBlockedToMMAv5
         rewriter.create<ConvertLayoutOp>(loc, newScaleAType, lhsScale);
     Value newScaleB =
         rewriter.create<ConvertLayoutOp>(loc, newScaleBType, rhsScale);
-    Value scaleA = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
-        loc, scaleAType, newScaleA);
-    Value scaleB = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
-        loc, scaleBType, newScaleB);
+
+    // We don't need to track memory dependencies for the scale operands since
+    // they are not pipelined.
+    auto scaleA = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
+        loc, scaleAType, /*token=*/Type(), newScaleA);
+    auto scaleB = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
+        loc, scaleBType, /*token=*/Type(), newScaleB);
+
     auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
-    rewriter.create<triton::nvidia_gpu::TCGen5MMAScaledOp>(
-        loc, a, b, acc, scaleA, scaleB, dotOp.getAElemType(),
-        dotOp.getBElemType(), /*useD=*/vTrue, /*pred=*/vTrue);
+    auto mmaOp = rewriter.create<triton::nvidia_gpu::TCGen5MMAScaledOp>(
+        loc, tokType, a, b, acc.getResult(), acc.getToken(), scaleA.getResult(),
+        scaleB.getResult(), dotOp.getAElemType(), dotOp.getBElemType(),
+        /*useD=*/vTrue, /*pred=*/vTrue);
 
-    auto ld =
-        rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(loc, newAccType, acc);
+    auto ld = rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(
+        loc, newAccType, tokType, acc, mmaOp.getToken());
     rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType, ld);
     return success();
   }
```
