Skip to content

Commit 3be1a23

Browse files
authored
Limit warp specialization partition attrs to WS pass (#10058)
1 parent e9352e2 commit 3be1a23

File tree

14 files changed

+330
-203
lines changed

14 files changed

+330
-203
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
5151
constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
5252
constexpr static char AttrTargetName[] = "ttg.target";
5353
constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
54-
// FIXME: rename to match above
55-
constexpr static char kPartitionAttrName[] = "ttg.partition";
56-
constexpr static char kPartitionOutputsAttrName[] = "ttg.partition.outputs";
57-
constexpr static char kPartitionStagesAttrName[] = "ttg.partition.stages";
58-
constexpr static char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag";
5954

6055
// Find the contextual number of warps on which this operation is executed.
6156
int lookupNumWarps(Operation *op);
@@ -335,13 +330,6 @@ LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
335330
ShapedType dstTy);
336331
// Verify a memory allocation operation.
337332
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy);
338-
339-
SetVector<int> getPartitionIds(Operation *op);
340-
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
341-
SetVector<int> getPartitionIds(OpOperand *use);
342-
bool hasPartition(Operation *op);
343-
bool hasWarpSpecializeTag(Operation *op);
344-
std::optional<int> getWarpSpecializeTag(Operation *op);
345333
/// Returns the size in bytes of a scalar type when stored in shared memory.
346334
size_t getSharedMemorySize(Type type);
347335

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <numeric>
55
#include <utility>
66

7-
#include "mlir/Dialect/UB/IR/UBOps.h"
87
#include "mlir/IR/DialectImplementation.h"
98
#include "mlir/IR/OpImplementation.h"
109
#include "mlir/IR/OperationSupport.h"
@@ -4051,135 +4050,6 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
40514050
<< " which is expected only on `module` or `tt.func` ops";
40524051
}
40534052

4054-
// Verify that all ops in a tt.warp_specialize op have partition ids
4055-
if (attr.getName() == "tt.warp_specialize") {
4056-
if (!isa<scf::ForOp>(op)) {
4057-
return op->emitOpError("has unexpected attribute ")
4058-
<< attr.getName() << " which is expected only on `scf.for` ops";
4059-
}
4060-
Operation *failedOp = nullptr;
4061-
op->walk([&](Operation *childOp) {
4062-
if (!childOp->hasAttr(kPartitionAttrName)) {
4063-
failedOp = childOp;
4064-
WalkResult::interrupt();
4065-
}
4066-
});
4067-
if (failedOp) {
4068-
return failedOp->emitOpError("does not have expected attribute ")
4069-
<< kPartitionAttrName
4070-
<< " which is expected on all child ops of an op with "
4071-
"attribute `tt.warp_specialize`";
4072-
}
4073-
}
4074-
4075-
// Verify that partition id lists are non-empty, sorted and have no duplicates
4076-
auto verifyPartitionIds =
4077-
[&](const ArrayRef<int> &partitionIds) -> LogicalResult {
4078-
SetVector<int> idSet;
4079-
for (auto id : partitionIds) {
4080-
if (idSet.contains(id))
4081-
return op->emitOpError("has duplicated partition ids in attribute ")
4082-
<< attr.getName();
4083-
idSet.insert(id);
4084-
}
4085-
if (idSet.empty())
4086-
return op->emitOpError("has no partition ids in attribute ")
4087-
<< attr.getName();
4088-
auto ids = idSet.takeVector();
4089-
SmallVector<int> sortedIds(ids.begin(), ids.end());
4090-
std::sort(sortedIds.begin(), sortedIds.end());
4091-
if (ids != sortedIds)
4092-
return op->emitOpError("partition ids not in sorted order in attribute ")
4093-
<< attr.getName();
4094-
return success();
4095-
};
4096-
4097-
if (attr.getName() == kPartitionAttrName) {
4098-
auto result = verifyPartitionIds(
4099-
cast<DenseI32ArrayAttr>(attr.getValue()).asArrayRef());
4100-
if (failed(result))
4101-
return result;
4102-
}
4103-
if (attr.getName() == kPartitionOutputsAttrName) {
4104-
auto arrayAttr = cast<ArrayAttr>(attr.getValue());
4105-
for (auto idx = 0; idx < arrayAttr.size(); idx++) {
4106-
auto result = verifyPartitionIds(
4107-
cast<DenseI32ArrayAttr>(arrayAttr[idx]).asArrayRef());
4108-
if (failed(result))
4109-
return result;
4110-
}
4111-
}
4112-
4113-
// Verify that op partitions include partitions of all child ops
4114-
if (attr.getName() == kPartitionAttrName && op->getNumRegions() != 0) {
4115-
SetVector<int> expectedIds;
4116-
for (auto &region : op->getRegions()) {
4117-
for (auto &block : region.getBlocks()) {
4118-
for (auto &childOp : block.getOperations()) {
4119-
if (isa<scf::YieldOp, ub::PoisonOp>(childOp)) {
4120-
// yield ops and ub.poison do not need partition ids
4121-
continue;
4122-
}
4123-
if (!childOp.hasAttr(kPartitionAttrName))
4124-
return childOp.emitOpError("does not have expected attribute ")
4125-
<< kPartitionAttrName
4126-
<< " which is expected for ops whose parent has partitions";
4127-
auto ids = getPartitionIds(&childOp);
4128-
expectedIds.insert(ids.begin(), ids.end());
4129-
}
4130-
}
4131-
}
4132-
auto partitionIds = getPartitionIds(op);
4133-
for (auto id : expectedIds) {
4134-
if (!partitionIds.contains(id)) {
4135-
return op->emitOpError("partition ids in attr ")
4136-
<< attr.getName()
4137-
<< " does not contain partition ids of all child ops";
4138-
}
4139-
}
4140-
}
4141-
4142-
if (attr.getName() == kPartitionOutputsAttrName) {
4143-
if (!isa<scf::ForOp, scf::IfOp, triton::ReduceOp>(op))
4144-
return op->emitOpError("has unexpected attribute ") << attr.getName();
4145-
4146-
// Verify that number of output partitions matches number of For/If results
4147-
size_t numResults = 0;
4148-
if (isa<scf::ForOp>(op)) {
4149-
numResults = cast<scf::ForOp>(op).getResults().size();
4150-
} else if (isa<scf::IfOp>(op)) {
4151-
numResults = cast<scf::IfOp>(op).getResults().size();
4152-
} else {
4153-
numResults = cast<triton::ReduceOp>(op).getResults().size();
4154-
}
4155-
4156-
if (cast<ArrayAttr>(attr.getValue()).size() != numResults) {
4157-
return op->emitOpError("does not have expected number of output "
4158-
"partition sets in attr ")
4159-
<< attr.getName() << "; should match number of results";
4160-
}
4161-
4162-
// Verify that union of op output partitions is a subset of op partitions
4163-
if (!op->hasAttr(kPartitionAttrName))
4164-
return op->emitOpError("does not have expected attribute ")
4165-
<< kPartitionAttrName << " which is expected for ops with attr "
4166-
<< kPartitionOutputsAttrName;
4167-
auto partitionIds = getPartitionIds(op);
4168-
4169-
SetVector<int> outputPartitionIdsUnion;
4170-
for (auto outputPartitionIds : getPartitionOutputs(op)) {
4171-
outputPartitionIdsUnion.insert(outputPartitionIds.begin(),
4172-
outputPartitionIds.end());
4173-
}
4174-
if (!std::all_of(outputPartitionIdsUnion.begin(),
4175-
outputPartitionIdsUnion.end(),
4176-
[&](int id) { return partitionIds.contains(id); })) {
4177-
return op->emitOpError("partition ids in attr ")
4178-
<< kPartitionAttrName
4179-
<< " must be the union of all partition ids in " << attr.getName();
4180-
}
4181-
}
4182-
41834053
return success();
41844054
}
41854055

@@ -4414,57 +4284,6 @@ SmallVector<int64_t> triton::gpu::getTMABlockShape(
44144284
mode);
44154285
}
44164286

4417-
SetVector<int> triton::gpu::getPartitionIds(Operation *op) {
4418-
auto attrs = op->getAttr(kPartitionAttrName);
4419-
SmallVector<int> partitionIds;
4420-
for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef()) {
4421-
partitionIds.push_back(id);
4422-
}
4423-
std::sort(partitionIds.begin(), partitionIds.end());
4424-
return SetVector<int>(partitionIds.begin(), partitionIds.end());
4425-
}
4426-
4427-
SmallVector<SetVector<int>, 4> triton::gpu::getPartitionOutputs(Operation *op) {
4428-
SmallVector<SetVector<int>, 4> partitionOutputsIds;
4429-
if (op->getNumResults() == 0) {
4430-
return partitionOutputsIds;
4431-
}
4432-
assert(op->hasAttr(kPartitionOutputsAttrName));
4433-
auto arrayAttr = cast<ArrayAttr>(op->getAttr(kPartitionOutputsAttrName));
4434-
for (auto attr : arrayAttr) {
4435-
auto ids = cast<DenseI32ArrayAttr>(attr).asArrayRef();
4436-
partitionOutputsIds.push_back(SetVector<int>(ids.begin(), ids.end()));
4437-
}
4438-
return partitionOutputsIds;
4439-
}
4440-
4441-
SetVector<int> triton::gpu::getPartitionIds(OpOperand *use) {
4442-
auto owner = use->getOwner();
4443-
if (isa<scf::YieldOp>(owner)) {
4444-
return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()];
4445-
} else if (scf::ForOp forOp = dyn_cast<scf::ForOp>(owner)) {
4446-
int idx = use->getOperandNumber() - forOp.getNumControlOperands();
4447-
return idx >= 0 ? getPartitionOutputs(owner)[idx] : getPartitionIds(forOp);
4448-
} else {
4449-
return getPartitionIds(owner);
4450-
}
4451-
}
4452-
4453-
bool triton::gpu::hasPartition(Operation *op) {
4454-
return op && op->hasAttr(kPartitionAttrName);
4455-
}
4456-
4457-
bool triton::gpu::hasWarpSpecializeTag(Operation *op) {
4458-
return op && op->hasAttr(kWarpSpecializeTagAttrName);
4459-
}
4460-
4461-
std::optional<int> triton::gpu::getWarpSpecializeTag(Operation *op) {
4462-
if (hasWarpSpecializeTag(op)) {
4463-
return cast<IntegerAttr>(op->getAttr(kWarpSpecializeTagAttrName)).getInt();
4464-
}
4465-
return std::nullopt;
4466-
}
4467-
44684287
PaddedSharedEncodingAttr triton::gpu::getPaddedEncoding(Attribute encoding) {
44694288
if (!encoding)
44704289
return nullptr;

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "PartitionAttrs.h"
12
#include "mlir/Dialect/Arith/Transforms/Passes.h"
23
#include "mlir/IR/BuiltinOps.h"
34
#include "mlir/Pass/Pass.h"
@@ -23,6 +24,25 @@ namespace mlir::triton::gpu {
2324
} // namespace mlir::triton::gpu
2425

2526
namespace {
27+
struct VerifyWarpSpecializationPartitions
28+
: PassWrapper<VerifyWarpSpecializationPartitions, OperationPass<ModuleOp>> {
29+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
30+
VerifyWarpSpecializationPartitions)
31+
32+
void runOnOperation() override {
33+
WalkResult result = getOperation().walk([&](scf::ForOp loop) {
34+
if (!loop->hasAttr(kPartitionStagesAttrName))
35+
return WalkResult::advance();
36+
if (failed(verifyPartitionedLoop(loop))) {
37+
signalPassFailure();
38+
return WalkResult::interrupt();
39+
}
40+
return WalkResult::advance();
41+
});
42+
(void)result;
43+
}
44+
};
45+
2646
struct AutomaticWarpSpecialization
2747
: triton::gpu::impl::TritonGPUAutomaticWarpSpecializationBase<
2848
AutomaticWarpSpecialization> {
@@ -57,20 +77,38 @@ void multiBufferTMADescriptors(ModuleOp mod, int numStages) {
5777
}
5878
}
5979

80+
void clearInternalWarpSpecializationAttrs(ModuleOp mod) {
81+
mod.walk([](Operation *op) {
82+
op->removeAttr(kPartitionAttrName);
83+
op->removeAttr(kPartitionOutputsAttrName);
84+
op->removeAttr(kPartitionStagesAttrName);
85+
op->removeAttr(kWarpSpecializeTagAttrName);
86+
});
87+
}
88+
89+
std::unique_ptr<Pass> createVerifyWarpSpecializationPartitionsPass() {
90+
return std::make_unique<VerifyWarpSpecializationPartitions>();
91+
}
92+
6093
} // namespace
6194

6295
void AutomaticWarpSpecialization::runOnOperation() {
6396
OpPassManager pm;
64-
pm.addPass(createTritonGPUPartitionScheduling());
65-
pm.addPass(createNVWSHoistTmemStore());
66-
pm.addPass(createNVWSInsertAref());
67-
pm.addPass(createNVWSInsertTmemAref());
97+
auto addPassWithPartitionVerifier = [&](std::unique_ptr<Pass> pass) {
98+
pm.addPass(std::move(pass));
99+
pm.addPass(createVerifyWarpSpecializationPartitionsPass());
100+
};
101+
102+
addPassWithPartitionVerifier(createTritonGPUPartitionScheduling());
103+
addPassWithPartitionVerifier(createNVWSHoistTmemStore());
104+
addPassWithPartitionVerifier(createNVWSInsertAref());
105+
addPassWithPartitionVerifier(createNVWSInsertTmemAref());
68106
// `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
69107
// FIXME: Re-enable integer range analysis once it is fixed.
70108
// pm.addPass(arith::createIntRangeOptimizationsPass());
71-
pm.addPass(createSCCPPass());
72-
pm.addPass(createCSEPass());
73-
pm.addPass(createNVWSLowerAref({numStages}));
109+
addPassWithPartitionVerifier(createSCCPPass());
110+
addPassWithPartitionVerifier(createCSEPass());
111+
addPassWithPartitionVerifier(createNVWSLowerAref({numStages}));
74112
pm.addPass(createTritonGPUPartitionLoops());
75113
pm.addPass(createNVWSLowerWarpGroup());
76114
pm.addPass(createTritonGPUScheduleLoops());
@@ -80,4 +118,5 @@ void AutomaticWarpSpecialization::runOnOperation() {
80118
// Multi-buffer TMA descriptors. We cannot rely on SWP to do it, to support
81119
// desc updates in nested loops.
82120
multiBufferTMADescriptors(getOperation(), numStages);
121+
clearInternalWarpSpecializationAttrs(getOperation());
83122
}

0 commit comments

Comments
 (0)