Skip to content

Commit 8324fad

Browse files
authored
Fix constraints in isExpensiveCat (#9995)
isExpensiveCat does not reflect the constraint we have in lowering, which is that the number of unique result elements per thread must be equal to the total number of unique operand elements per thread. This means that we can sometimes fold `CatOp` into layout conversions that have destination layouts that violate this requirement. Rename it to `isLegalCatEncoding` to reflect that it is actually a correctness requirement, and update it to reflect the actual constraint.
1 parent f43dff6 commit 8324fad

File tree

5 files changed

+70
-30
lines changed

5 files changed

+70
-30
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
260260
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
261261
bool kContig);
262262

263-
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
263+
// Return true if \p cat would be valid with result encoding \p targetEncoding.
264+
bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding);
264265

265266
// Return true if a view between the two types cannot be implemented as a no-op.
266267
bool isExpensiveView(Type srcType, Type dstType);

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -409,16 +409,27 @@ SmallVector<unsigned> orderPerDimImpl(const LinearLayout &ll,
409409
return order.takeVector();
410410
}
411411

412-
bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
413-
// If the new elements per thread is less than the old one, we will need to
414-
// do convert encoding that goes through shared memory anyway. So we
415-
// consider it as expensive.
416-
RankedTensorType tensorTy = cat.getType();
417-
auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy);
418-
auto shape = tensorTy.getShape();
419-
auto newTotalElemsPerThread =
420-
gpu::getTotalElemsPerThread(targetEncoding, shape);
421-
return newTotalElemsPerThread < totalElemsPerThread;
412+
static int64_t getNumNonBroadcastRegisters(ArrayRef<int64_t> shape,
413+
Attribute encoding) {
414+
auto kReg = StringAttr::get(encoding.getContext(), "register");
415+
auto strippedLayout =
416+
toLinearLayout(shape, encoding).removeZeroBasesAlongDim(kReg);
417+
return strippedLayout.getInDimSize(kReg);
418+
}
419+
420+
static int64_t getNumNonBroadcastRegisters(RankedTensorType tensorType) {
421+
return getNumNonBroadcastRegisters(tensorType.getShape(),
422+
tensorType.getEncoding());
423+
}
424+
425+
bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding) {
426+
// Cat lowering concatenates the operands' unique register values. So the
427+
// number of unique register values in the result must be equal to those in
428+
// the operands.
429+
int64_t operandRegs = getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2;
430+
int64_t resultRegs =
431+
getNumNonBroadcastRegisters(cat.getType().getShape(), targetEncoding);
432+
return resultRegs == operandRegs;
422433
}
423434

424435
static LogicalResult

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ struct CanonicalizeConvertFromConvert
355355

356356
// cvt(cat) -> cat
357357
if (auto cat = dyn_cast<CatOp>(arg)) {
358-
if (isExpensiveCat(cat, op.getType().getEncoding()))
358+
if (!isLegalCatEncoding(cat, op.getType().getEncoding()))
359359
return failure();
360360

361361
rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -604,26 +604,10 @@ bool isExpensiveLoadOrStore(Operation *op) {
604604
return true;
605605
}
606606

607-
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
608-
if (!op)
609-
return true;
610-
if (isa<triton::LoadOp, triton::StoreOp>(op))
611-
return isExpensiveLoadOrStore(op);
612-
if (isa<triton::CatOp>(op))
613-
return triton::gpu::isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
614-
if (isa<triton::gpu::AsyncCopyGlobalToLocalOp, triton::AtomicRMWOp,
615-
triton::AtomicCASOp, triton::DotOp>(op))
616-
return true;
617-
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
618-
op))
619-
return true;
620-
return false;
621-
}
622-
623607
bool canUseResultEncoding(Operation *op, Attribute targetEncoding) {
624608
if (isa<triton::CatOp>(op))
625-
return !triton::gpu::isExpensiveCat(cast<triton::CatOp>(op),
626-
targetEncoding);
609+
return triton::gpu::isLegalCatEncoding(cast<triton::CatOp>(op),
610+
targetEncoding);
627611
if (auto convert = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
628612
if (mlir::isa<triton::gpu::NvidiaMmaEncodingAttr>(targetEncoding)) {
629613
auto srcEncoding = convert.getSrc().getType().getEncoding();

test/TritonGPU/combine.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4176,3 +4176,47 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 1 : i32}
41764176
tt.return %o : tensor<1x2x2xi32, #dst>
41774177
}
41784178
}
4179+
4180+
// -----
4181+
4182+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
4183+
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
4184+
#linear = #ttg.linear<{register = [[1], [16]], lane = [[0], [0], [2], [4], [8]], warp = [[0], [0]], block = []}>
4185+
4186+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
4187+
// CHECK-LABEL: @cat_incompatible_target_keeps_convert
4188+
tt.func public @cat_incompatible_target_keeps_convert(%out: !tt.ptr<i32>) {
4189+
%lhs = arith.constant dense<0> : tensor<16xi32, #blocked>
4190+
%rhs = arith.constant dense<1> : tensor<16xi32, #blocked>
4191+
// CHECK: %[[CAT:[^ ]+]] = tt.cat
4192+
%cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2>
4193+
// CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]]
4194+
%cvt = ttg.convert_layout %cat {allocation.offset = 0 : i32} : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear>
4195+
%ptr = tt.splat %out : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>, #linear>
4196+
// CHECK: tt.store {{.*}}, %[[CVT]]
4197+
tt.store %ptr, %cvt : tensor<32x!tt.ptr<i32>, #linear>
4198+
tt.return
4199+
}
4200+
}
4201+
4202+
// -----
4203+
4204+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
4205+
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
4206+
#linear_bcast = #ttg.linear<{register = [[0]], lane = [[1], [2], [4], [8], [16]], warp = [[0], [0]], block = []}>
4207+
4208+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
4209+
// CHECK-LABEL: @cat_target_adds_broadcasting_keeps_convert
4210+
tt.func public @cat_target_adds_broadcasting_keeps_convert(%out: !tt.ptr<i32>) {
4211+
%lhs = arith.constant dense<0> : tensor<16xi32, #blocked>
4212+
%rhs = arith.constant dense<1> : tensor<16xi32, #blocked>
4213+
// CHECK: %[[CAT:[^ ]+]] = tt.cat
4214+
%cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2>
4215+
// CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]]
4216+
%cvt = ttg.convert_layout %cat : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear_bcast>
4217+
%ptr = tt.splat %out : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>, #linear_bcast>
4218+
// CHECK: tt.store {{.*}}, %[[CVT]]
4219+
tt.store %ptr, %cvt : tensor<32x!tt.ptr<i32>, #linear_bcast>
4220+
tt.return
4221+
}
4222+
}

0 commit comments

Comments (0)