[mlir][shard] Hardening sharding propagation and partitioning #183028

Merged
fschlimb merged 2 commits into llvm:main from fschlimb:harden-shard
Feb 24, 2026

Conversation

@fschlimb (Contributor)

Simple fixes to avoid failures.

@llvmbot (Member) commented Feb 24, 2026

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

Simple fixes to avoid failures.


Full diff: https://github.com/llvm/llvm-project/pull/183028.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+5-2)
  • (modified) mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp (+4)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 8b73bdd7ea60b..74fc59c4f6d14 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -532,7 +532,8 @@ shardedBlockArgumentTypes(Block &block,
       block.getArguments(), std::back_inserter(res),
       [&symbolTableCollection](BlockArgument arg) {
         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
+        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
+            rankedTensorArg.use_empty()) {
           return arg.getType();
         }
 
@@ -666,7 +667,8 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
 static LogicalResult checkFullyAnnotated(Block &block) {
   for (const BlockArgument &arg : block.getArguments()) {
     auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
+    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
+        rankedTensorArg.getNumUses() < 1)
       continue;
 
     if (rankedTensorArg.getNumUses() > 1)
@@ -674,6 +676,7 @@ static LogicalResult checkFullyAnnotated(Block &block) {
              << "Cannot partition: expected a single use for block argument "
              << arg.getArgNumber() << " in block "
              << block.computeBlockNumber();
+
     Operation *useOp = *rankedTensorArg.getUsers().begin();
     auto shardOp = dyn_cast<ShardOp>(useOp);
     if (!shardOp)
diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index f954131ed7910..cff02d4f03143 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -379,6 +379,10 @@ struct ShardingPropagation
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
 
+    // Nothing to propagate if there is no sharding annotation in the block.
+    if (block.getOps<shard::ShardOp>().empty())
+      return;
+
     auto traverse = [&](auto &&range, OpBuilder &builder,
                         const char *order) -> bool {
       for (Operation &op : range) {

Copilot AI (Contributor) left a comment


Pull request overview

This PR adds defensive checks to prevent failures in MLIR's Shard dialect transformation passes. The changes add early returns and guards to handle edge cases where sharding operations or uses are absent, preventing potential assertion failures or crashes during sharding propagation and partitioning.

Changes:

  • Added early return in ShardingPropagation when no ShardOp annotations exist in the block
  • Added checks to skip unused block arguments in both shardedBlockArgumentTypes and checkFullyAnnotated functions
  • Minor formatting improvement with blank line addition

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

  • mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp — Adds an early return when the block contains no ShardOp operations, avoiding unnecessary traversal
  • mlir/lib/Dialect/Shard/Transforms/Partition.cpp — Adds defensive checks to handle unused block arguments in type deduction and validation logic, plus formatting cleanup

@tkarna (Contributor) left a comment


LGTM

@fschlimb fschlimb merged commit f1bf37e into llvm:main Feb 24, 2026
10 checks passed
@fschlimb
Copy link
Contributor Author

Will enable llvm/lighthouse#57

tudinhh pushed a commit to tudinhh/llvm-project that referenced this pull request Feb 26, 2026
HendrikHuebner pushed a commit to HendrikHuebner/llvm-project that referenced this pull request Mar 10, 2026