Skip to content

Commit 7b2682d

Browse files
authored
2CTA Block Scale MMA with tcgen05.cp (#9460)
# 2CTA Block Scale MMA with tcgen05.cp * **New Features** * Functional 2-CTA Block Scale MMA with tcgen05.cp: full path (TMA → cp → MMA → commit) for two CTAs * **Bug Fixes** * TMA barrier (2CTA hang): arrive on lead CTA barrier and use .dst shared::cluster * TCGen05.cp (2CTA hang): use Lead CTA predicate for tcgen05.cp instruction * Scaled MMA 2CTA: per-CTA M/N; double M in the scale descriptor * TCGen05.cp: only copy per-CTA scales (block=0) * **Tests** * test_mma_scaled_tcgen05_copy: 48 cases, parameters: num_ctas, multicast, block size, dtype * tritongpu_to_llvm_blackwell.mlir: tmem_copy cta_group::2, lead-CTA predicate # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent cad7253 commit 7b2682d

13 files changed

Lines changed: 343 additions & 84 deletions

File tree

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
6262
"void",
6363
"setIsAsync",
6464
(ins "bool":$isAsync)>,
65+
InterfaceMethod<"Return true if this MMA op uses two CTAs.",
66+
"bool",
67+
"getTwoCtas">,
6568
InterfaceMethod<"Return true if this MMA op executes asynchronously.",
6669
"bool",
6770
"isAsync">

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,9 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
633633

634634
let description = [{
635635
$d += matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale))
636-
if is_async is false, the op executes synchronously. The barrier operands must not be present in that case.
636+
If $two_ctas is set, the op will execute a matmul across two contiguous CTAs; it will read the data distributed across the two CTAs
637+
and synchronize both CTAs if the op is synchronous.
638+
If is_async is false, the op executes synchronously. The barrier operands must not be present in that case.
637639
Otherwise, if a barrier is given, the op will trigger a commit/arrive on it.
638640
The result will be safe to read after a barrier wait.
639641

@@ -655,6 +657,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
655657
I1:$pred,
656658
Variadic<TTG_MemDescType>:$barriers,
657659
Variadic<I1>:$barrier_preds,
660+
UnitAttr:$two_ctas,
658661
UnitAttr:$is_async
659662
);
660663
let results = (outs Optional<TTG_AsyncToken>:$token);
@@ -676,6 +679,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
676679
"::mlir::Value":$useD, "::mlir::Value":$pred,
677680
CArg<"::mlir::ValueRange", "{}">:$barriers,
678681
CArg<"::mlir::ValueRange", "{}">:$barrier_preds,
682+
CArg<"bool", "false">:$two_ctas,
679683
CArg<"bool", "false">:$is_async)>
680684
];
681685

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -894,15 +894,17 @@ void TCGen5MMAScaledOp::build(OpBuilder &builder, OperationState &state,
894894
Value accDep, Value aScale, Value bScale,
895895
ScaleDotElemType aType, ScaleDotElemType bType,
896896
Value useD, Value pred, ValueRange barriers,
897-
ValueRange barrierPreds, bool isAsync) {
897+
ValueRange barrierPreds, bool twoCTAs,
898+
bool isAsync) {
898899
MLIRContext *ctx = builder.getContext();
899900
if (!barriers.empty()) {
900901
isAsync = true;
901902
}
902903
build(builder, state, token, a, b, d, accDep, aScale, bScale,
903904
ScaleDotElemTypeAttr::get(ctx, aType),
904905
ScaleDotElemTypeAttr::get(ctx, bType), useD, pred, barriers,
905-
barrierPreds, isAsync ? builder.getUnitAttr() : UnitAttr());
906+
barrierPreds, twoCTAs ? builder.getUnitAttr() : UnitAttr(),
907+
isAsync ? builder.getUnitAttr() : UnitAttr());
906908
}
907909

908910
bool TCGen5MMAScaledOp::isAsync() { return getIsAsync(); }
@@ -1065,11 +1067,6 @@ LogicalResult TMEMCopyOp::verify() {
10651067
"representable in a matrix descriptor.");
10661068
}
10671069

1068-
auto mod = getOperation()->getParentOfType<ModuleOp>();
1069-
unsigned numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
1070-
if (numCTAs != 1)
1071-
return emitOpError("NYI: Only one CTA is supported for now.");
1072-
10731070
// Fp4 we could lift if we needed
10741071
auto nvmmaEnc =
10751072
dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(srcTy.getEncoding());

lib/Dialect/TritonNvidiaGPU/Transforms/CheckMatmulTwoCTAs.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ class TritonNvidiaGPUCheckMatmulTwoCTAPass
2828
Operation *firstMatmul = nullptr;
2929
bool firstTwoCTA = false;
3030

31-
WalkResult result = mod.walk([&](ttng::TCGen5MMAOp op) {
31+
// Walk all MMAv5 ops using the interface
32+
WalkResult result = mod.walk([&](ttng::MMAv5OpInterface op) -> WalkResult {
3233
bool currentTwoCTA = op.getTwoCtas();
3334
if (!firstMatmul) {
3435
firstMatmul = op;

lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ static bool isDistributedMultiCTAOp(Operation *op, bool isRead) {
4040
if (auto mma = dyn_cast<ttng::TCGen5MMAOp>(op)) {
4141
return mma.getTwoCtas();
4242
} else if (auto mmaScaled = dyn_cast<ttng::TCGen5MMAScaledOp>(op)) {
43-
// TODO: Change when we support scaled MMA with 2CTAs
44-
assert(!ttng::getModuleTwoCTAs(op->getParentOfType<ModuleOp>()) &&
45-
"Scaled MMA with 2CTAs not supported");
46-
return false;
43+
return mmaScaled.getTwoCtas();
4744
} else if (auto tma = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
4845
return tma.getMulticast();
4946
}

python/src/gluon_ir.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -904,13 +904,13 @@ void init_gluon_ir(py::module &&m) {
904904
[](GluonOpBuilder &self, Value a, Value b, Value acc, Value aScale,
905905
Value bScale, tt::ScaleDotElemType aType,
906906
tt::ScaleDotElemType bType, Value useAcc, Value pred,
907-
std::vector<Value> &mbarriers,
908-
std::vector<Value> &mbarrier_preds) {
907+
std::vector<Value> &mbarriers, std::vector<Value> &mbarrier_preds,
908+
bool two_ctas) {
909909
Value accDep;
910910
auto tokType = self.getBuilder().getType<ttg::AsyncTokenType>();
911911
self.create<ttng::TCGen5MMAScaledOp>(
912912
tokType, a, b, acc, accDep, aScale, bScale, aType, bType,
913-
useAcc, pred, mbarriers, mbarrier_preds);
913+
useAcc, pred, mbarriers, mbarrier_preds, two_ctas);
914914
})
915915
.def("create_tcgen05_commit",
916916
[](GluonOpBuilder &self, Value &barrier, Value &pred,

0 commit comments

Comments
 (0)