diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 8b73bdd7ea60b..9c5880e0c3b64 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -532,7 +532,8 @@ shardedBlockArgumentTypes(Block &block,
       block.getArguments(), std::back_inserter(res),
       [&symbolTableCollection](BlockArgument arg) {
         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
+        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
+            rankedTensorArg.use_empty()) {
           return arg.getType();
         }
@@ -660,20 +661,22 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
 }

 // Check if the block args are correctly annotated with sharding information:
-// - non-tensor and 0d-tensor args are ignored
+// - non-tensor, 0d-tensor and unused args are ignored
 // - each tensor arg must have exactly one use, which must be a shard.shard
-// operation
+//   operation
 static LogicalResult checkFullyAnnotated(Block &block) {
   for (const BlockArgument &arg : block.getArguments()) {
     auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
+    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
+        rankedTensorArg.use_empty())
       continue;
-    if (rankedTensorArg.getNumUses() > 1)
+    if (!rankedTensorArg.hasOneUse())
       return emitError(block.getParent()->getLoc())
              << "Cannot partition: expected a single use for block argument "
              << arg.getArgNumber() << " in block " << block.computeBlockNumber();
+
     Operation *useOp = *rankedTensorArg.getUsers().begin();
     auto shardOp = dyn_cast<ShardOp>(useOp);
     if (!shardOp)
diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index f954131ed7910..cff02d4f03143 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -379,6 +379,10 @@ struct ShardingPropagation
       shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
     });

+    // Nothing to propagate if there is no sharding annotation in the block.
+    if (block.getOps<ShardOp>().empty())
+      return;
+
     auto traverse = [&](auto &&range, OpBuilder &builder,
                         const char *order) -> bool {
       for (Operation &op : range) {
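For illustration, a minimal sketch of the kind of input this change unblocks. The grid name, shapes, and split axes below are invented for the example, and the assembly format is assumed to follow the mesh dialect's syntax that the shard dialect inherited; the diff itself does not contain this IR. The point is that `%arg1` is a ranked tensor with no uses: previously `checkFullyAnnotated` required every ranked-tensor block argument to have exactly one use feeding a `shard.shard` op, so an unused argument could not satisfy the check; with this patch it is simply skipped, and `shardedBlockArgumentTypes` leaves its type unchanged.

```mlir
// Hypothetical input (names and shapes are illustrative only).
shard.grid @grid0(shape = 2)

func.func @f(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // %arg0 is annotated as required: its single use is a shard.shard op.
  %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
  // %arg1 has no uses; after this change it needs no annotation.
  return %0 : tensor<4x8xf32>
}
```

The `ShardingPropagation` hunk follows the same spirit: if a block contains no `ShardOp` at all, there is nothing to propagate, so the pass returns early instead of traversing the block.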