|
4 | 4 | #include <numeric> |
5 | 5 | #include <utility> |
6 | 6 |
|
7 | | -#include "mlir/Dialect/UB/IR/UBOps.h" |
8 | 7 | #include "mlir/IR/DialectImplementation.h" |
9 | 8 | #include "mlir/IR/OpImplementation.h" |
10 | 9 | #include "mlir/IR/OperationSupport.h" |
@@ -4051,135 +4050,6 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, |
4051 | 4050 | << " which is expected only on `module` or `tt.func` ops"; |
4052 | 4051 | } |
4053 | 4052 |
|
4054 | | - // Verify that all ops in a tt.warp_specialize op have partition ids |
4055 | | - if (attr.getName() == "tt.warp_specialize") { |
4056 | | - if (!isa<scf::ForOp>(op)) { |
4057 | | - return op->emitOpError("has unexpected attribute ") |
4058 | | - << attr.getName() << " which is expected only on `scf.for` ops"; |
4059 | | - } |
4060 | | - Operation *failedOp = nullptr; |
4061 | | - op->walk([&](Operation *childOp) { |
4062 | | - if (!childOp->hasAttr(kPartitionAttrName)) { |
4063 | | - failedOp = childOp; |
4064 | | - WalkResult::interrupt(); |
4065 | | - } |
4066 | | - }); |
4067 | | - if (failedOp) { |
4068 | | - return failedOp->emitOpError("does not have expected attribute ") |
4069 | | - << kPartitionAttrName |
4070 | | - << " which is expected on all child ops of an op with " |
4071 | | - "attribute `tt.warp_specialize`"; |
4072 | | - } |
4073 | | - } |
4074 | | - |
4075 | | - // Verify that partition id lists are non-empty, sorted and have no duplicates |
4076 | | - auto verifyPartitionIds = |
4077 | | - [&](const ArrayRef<int> &partitionIds) -> LogicalResult { |
4078 | | - SetVector<int> idSet; |
4079 | | - for (auto id : partitionIds) { |
4080 | | - if (idSet.contains(id)) |
4081 | | - return op->emitOpError("has duplicated partition ids in attribute ") |
4082 | | - << attr.getName(); |
4083 | | - idSet.insert(id); |
4084 | | - } |
4085 | | - if (idSet.empty()) |
4086 | | - return op->emitOpError("has no partition ids in attribute ") |
4087 | | - << attr.getName(); |
4088 | | - auto ids = idSet.takeVector(); |
4089 | | - SmallVector<int> sortedIds(ids.begin(), ids.end()); |
4090 | | - std::sort(sortedIds.begin(), sortedIds.end()); |
4091 | | - if (ids != sortedIds) |
4092 | | - return op->emitOpError("partition ids not in sorted order in attribute ") |
4093 | | - << attr.getName(); |
4094 | | - return success(); |
4095 | | - }; |
4096 | | - |
4097 | | - if (attr.getName() == kPartitionAttrName) { |
4098 | | - auto result = verifyPartitionIds( |
4099 | | - cast<DenseI32ArrayAttr>(attr.getValue()).asArrayRef()); |
4100 | | - if (failed(result)) |
4101 | | - return result; |
4102 | | - } |
4103 | | - if (attr.getName() == kPartitionOutputsAttrName) { |
4104 | | - auto arrayAttr = cast<ArrayAttr>(attr.getValue()); |
4105 | | - for (auto idx = 0; idx < arrayAttr.size(); idx++) { |
4106 | | - auto result = verifyPartitionIds( |
4107 | | - cast<DenseI32ArrayAttr>(arrayAttr[idx]).asArrayRef()); |
4108 | | - if (failed(result)) |
4109 | | - return result; |
4110 | | - } |
4111 | | - } |
4112 | | - |
4113 | | - // Verify that op partitions include partitions of all child ops |
4114 | | - if (attr.getName() == kPartitionAttrName && op->getNumRegions() != 0) { |
4115 | | - SetVector<int> expectedIds; |
4116 | | - for (auto ®ion : op->getRegions()) { |
4117 | | - for (auto &block : region.getBlocks()) { |
4118 | | - for (auto &childOp : block.getOperations()) { |
4119 | | - if (isa<scf::YieldOp, ub::PoisonOp>(childOp)) { |
4120 | | - // yield ops and ub.poison do not need partition ids |
4121 | | - continue; |
4122 | | - } |
4123 | | - if (!childOp.hasAttr(kPartitionAttrName)) |
4124 | | - return childOp.emitOpError("does not have expected attribute ") |
4125 | | - << kPartitionAttrName |
4126 | | - << " which is expected for ops whose parent has partitions"; |
4127 | | - auto ids = getPartitionIds(&childOp); |
4128 | | - expectedIds.insert(ids.begin(), ids.end()); |
4129 | | - } |
4130 | | - } |
4131 | | - } |
4132 | | - auto partitionIds = getPartitionIds(op); |
4133 | | - for (auto id : expectedIds) { |
4134 | | - if (!partitionIds.contains(id)) { |
4135 | | - return op->emitOpError("partition ids in attr ") |
4136 | | - << attr.getName() |
4137 | | - << " does not contain partition ids of all child ops"; |
4138 | | - } |
4139 | | - } |
4140 | | - } |
4141 | | - |
4142 | | - if (attr.getName() == kPartitionOutputsAttrName) { |
4143 | | - if (!isa<scf::ForOp, scf::IfOp, triton::ReduceOp>(op)) |
4144 | | - return op->emitOpError("has unexpected attribute ") << attr.getName(); |
4145 | | - |
4146 | | - // Verify that number of output partitions matches number of For/If results |
4147 | | - size_t numResults = 0; |
4148 | | - if (isa<scf::ForOp>(op)) { |
4149 | | - numResults = cast<scf::ForOp>(op).getResults().size(); |
4150 | | - } else if (isa<scf::IfOp>(op)) { |
4151 | | - numResults = cast<scf::IfOp>(op).getResults().size(); |
4152 | | - } else { |
4153 | | - numResults = cast<triton::ReduceOp>(op).getResults().size(); |
4154 | | - } |
4155 | | - |
4156 | | - if (cast<ArrayAttr>(attr.getValue()).size() != numResults) { |
4157 | | - return op->emitOpError("does not have expected number of output " |
4158 | | - "partition sets in attr ") |
4159 | | - << attr.getName() << "; should match number of results"; |
4160 | | - } |
4161 | | - |
4162 | | - // Verify that union of op output partitions is a subset of op partitions |
4163 | | - if (!op->hasAttr(kPartitionAttrName)) |
4164 | | - return op->emitOpError("does not have expected attribute ") |
4165 | | - << kPartitionAttrName << " which is expected for ops with attr " |
4166 | | - << kPartitionOutputsAttrName; |
4167 | | - auto partitionIds = getPartitionIds(op); |
4168 | | - |
4169 | | - SetVector<int> outputPartitionIdsUnion; |
4170 | | - for (auto outputPartitionIds : getPartitionOutputs(op)) { |
4171 | | - outputPartitionIdsUnion.insert(outputPartitionIds.begin(), |
4172 | | - outputPartitionIds.end()); |
4173 | | - } |
4174 | | - if (!std::all_of(outputPartitionIdsUnion.begin(), |
4175 | | - outputPartitionIdsUnion.end(), |
4176 | | - [&](int id) { return partitionIds.contains(id); })) { |
4177 | | - return op->emitOpError("partition ids in attr ") |
4178 | | - << kPartitionAttrName |
4179 | | - << " must be the union of all partition ids in " << attr.getName(); |
4180 | | - } |
4181 | | - } |
4182 | | - |
4183 | 4053 | return success(); |
4184 | 4054 | } |
4185 | 4055 |
|
@@ -4414,57 +4284,6 @@ SmallVector<int64_t> triton::gpu::getTMABlockShape( |
4414 | 4284 | mode); |
4415 | 4285 | } |
4416 | 4286 |
|
4417 | | -SetVector<int> triton::gpu::getPartitionIds(Operation *op) { |
4418 | | - auto attrs = op->getAttr(kPartitionAttrName); |
4419 | | - SmallVector<int> partitionIds; |
4420 | | - for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef()) { |
4421 | | - partitionIds.push_back(id); |
4422 | | - } |
4423 | | - std::sort(partitionIds.begin(), partitionIds.end()); |
4424 | | - return SetVector<int>(partitionIds.begin(), partitionIds.end()); |
4425 | | -} |
4426 | | - |
4427 | | -SmallVector<SetVector<int>, 4> triton::gpu::getPartitionOutputs(Operation *op) { |
4428 | | - SmallVector<SetVector<int>, 4> partitionOutputsIds; |
4429 | | - if (op->getNumResults() == 0) { |
4430 | | - return partitionOutputsIds; |
4431 | | - } |
4432 | | - assert(op->hasAttr(kPartitionOutputsAttrName)); |
4433 | | - auto arrayAttr = cast<ArrayAttr>(op->getAttr(kPartitionOutputsAttrName)); |
4434 | | - for (auto attr : arrayAttr) { |
4435 | | - auto ids = cast<DenseI32ArrayAttr>(attr).asArrayRef(); |
4436 | | - partitionOutputsIds.push_back(SetVector<int>(ids.begin(), ids.end())); |
4437 | | - } |
4438 | | - return partitionOutputsIds; |
4439 | | -} |
4440 | | - |
4441 | | -SetVector<int> triton::gpu::getPartitionIds(OpOperand *use) { |
4442 | | - auto owner = use->getOwner(); |
4443 | | - if (isa<scf::YieldOp>(owner)) { |
4444 | | - return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()]; |
4445 | | - } else if (scf::ForOp forOp = dyn_cast<scf::ForOp>(owner)) { |
4446 | | - int idx = use->getOperandNumber() - forOp.getNumControlOperands(); |
4447 | | - return idx >= 0 ? getPartitionOutputs(owner)[idx] : getPartitionIds(forOp); |
4448 | | - } else { |
4449 | | - return getPartitionIds(owner); |
4450 | | - } |
4451 | | -} |
4452 | | - |
4453 | | -bool triton::gpu::hasPartition(Operation *op) { |
4454 | | - return op && op->hasAttr(kPartitionAttrName); |
4455 | | -} |
4456 | | - |
4457 | | -bool triton::gpu::hasWarpSpecializeTag(Operation *op) { |
4458 | | - return op && op->hasAttr(kWarpSpecializeTagAttrName); |
4459 | | -} |
4460 | | - |
4461 | | -std::optional<int> triton::gpu::getWarpSpecializeTag(Operation *op) { |
4462 | | - if (hasWarpSpecializeTag(op)) { |
4463 | | - return cast<IntegerAttr>(op->getAttr(kWarpSpecializeTagAttrName)).getInt(); |
4464 | | - } |
4465 | | - return std::nullopt; |
4466 | | -} |
4467 | | - |
4468 | 4287 | PaddedSharedEncodingAttr triton::gpu::getPaddedEncoding(Attribute encoding) { |
4469 | 4288 | if (!encoding) |
4470 | 4289 | return nullptr; |
|
0 commit comments