Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions mlir/lib/Dialect/Shard/Transforms/Partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ shardedBlockArgumentTypes(Block &block,
block.getArguments(), std::back_inserter(res),
[&symbolTableCollection](BlockArgument arg) {
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
rankedTensorArg.use_empty()) {
return arg.getType();
}

Expand Down Expand Up @@ -666,14 +667,16 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
static LogicalResult checkFullyAnnotated(Block &block) {
for (const BlockArgument &arg : block.getArguments()) {
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
rankedTensorArg.use_empty())
continue;

if (rankedTensorArg.getNumUses() > 1)
return emitError(block.getParent()->getLoc())
<< "Cannot partition: expected a single use for block argument "
<< arg.getArgNumber() << " in block "
<< block.computeBlockNumber();

Operation *useOp = *rankedTensorArg.getUsers().begin();
auto shardOp = dyn_cast<ShardOp>(useOp);
if (!shardOp)
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ struct ShardingPropagation
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});

// Nothing to propagate if there is no sharding annotation in the block.
if (block.getOps<shard::ShardOp>().empty())
return;

auto traverse = [&](auto &&range, OpBuilder &builder,
const char *order) -> bool {
for (Operation &op : range) {
Expand Down
Loading