Skip to content

Commit 86114ae

Browse files
Keep reshards that have non-equivalent input and output meshes.
The InsertExplicitReshards pass does not insert reshards with non-equivalent input and output meshes. Still, it is possible for the reshard-to-collectives pass to encounter them due to user sharding constraints. PiperOrigin-RevId: 813793497
1 parent b75252b commit 86114ae

2 files changed

Lines changed: 64 additions & 20 deletions

File tree

shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,19 @@ class CollectiveInserter {
13141314
AxisToDimAndIndex outAxisToDimAndIndex;
13151315
};
13161316

1317+
bool isEquivalentOnMesh(TensorShardingAttr inSharding,
1318+
TensorShardingAttr outSharding, ReshardOp reshardOp) {
1319+
if (isFullyReplicated(inSharding) || isFullyReplicated(outSharding)) {
1320+
return true;
1321+
}
1322+
if (inSharding.getMeshName() == outSharding.getMeshName()) {
1323+
return true;
1324+
}
1325+
MeshAttr inMesh = inSharding.getMesh(reshardOp);
1326+
MeshAttr outMesh = outSharding.getMesh(reshardOp);
1327+
return inMesh.equals(outMesh, /*ignoreDeviceOrder=*/true);
1328+
}
1329+
13171330
class ReshardPattern : public OpConversionPattern<ReshardOp> {
13181331
public:
13191332
using OpConversionPattern::OpConversionPattern;
@@ -1334,20 +1347,19 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
13341347
rewriter.replaceOp(op, adaptor.getInput());
13351348
return success();
13361349
}
1337-
MeshAttr inMesh = inSharding.getMesh(op);
13381350
if (inSharding.getMeshName() != outSharding.getMeshName()) {
1339-
MeshAttr outMesh = outSharding.getMesh(op);
1340-
// TODO(enver): Use MeshAttr::equals method instead.
1341-
if (outMesh.getAxes() != inMesh.getAxes() ||
1342-
inMesh.getDeviceIds() == outMesh.getDeviceIds() ||
1343-
(inSharding.isFullyReplicated() &&
1344-
outSharding.isFullyReplicated())) {
1345-
// We currently only support a reshard between different meshes if
1346-
// they have the same axes and different device ids, and at least one
1347-
// of the sharding isn't fully replicated.
1348-
return rewriter.notifyMatchFailure(
1349-
op, [](Diagnostic& diag) { diag << "Incompatible meshes"; });
1350-
}
1351+
if (outSharding.isFullyReplicated()) {
1352+
// TODO(enver): Hard fail if out sharding has unreduced axes.
1353+
outSharding = TensorShardingAttr::getFullyClosedLike(inSharding);
1354+
}
1355+
} else {
1356+
if (!isEquivalentOnMesh(inSharding, outSharding, op)) {
1357+
// We currently only support a reshard between different meshes if
1358+
// they have the same axes and different device ids, and at least one
1359+
// of the sharding isn't fully replicated.
1360+
return rewriter.notifyMatchFailure(
1361+
op, [](Diagnostic& diag) { diag << "Incompatible meshes"; });
1362+
}
13511363
}
13521364

13531365
// TODO(tomnatan): we should verify that the operand of ReshardOp has a
@@ -1370,13 +1382,12 @@ struct ReshardToCollectivesPass
13701382
target = std::make_shared<ConversionTarget>(*context);
13711383
target->addLegalOp<AllGatherOp, AllSliceOp, AllToAllOp,
13721384
CollectivePermuteOp>();
1373-
if (keepRedundantReshards) {
1374-
target->addDynamicallyLegalOp<ReshardOp>([](ReshardOp op) {
1375-
return isEquivalent(getSharding(op.getInput()), op.getSharding());
1376-
});
1377-
} else {
1378-
target->addIllegalOp<ReshardOp>();
1379-
}
1385+
target->addDynamicallyLegalOp<ReshardOp>([&](ReshardOp op) {
1386+
TensorShardingAttr inSharding = getSharding(op.getInput());
1387+
TensorShardingAttr outSharding = op.getSharding();
1388+
return (keepRedundantReshards && isEquivalent(inSharding, outSharding)) ||
1389+
!isEquivalentOnMesh(inSharding, outSharding, op);
1390+
});
13801391

13811392
RewritePatternSet patternsInternal(context);
13821393
patternsInternal.add<ReshardPattern>(context);

shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ sdy.mesh @empty_mesh = <[]>
2121
sdy.mesh @empty_mesh_another = <[]>
2222

2323

24+
2425
// CHECK-LABEL: func @redundant_reshard_fully_replicated
2526
func.func @redundant_reshard_fully_replicated(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{}, {}]>}) -> tensor<16x8xf32> {
2627
// CHECK-NEXT: return %arg0
@@ -70,6 +71,38 @@ func.func @redundant_reshard(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.shardin
7071
return %0 : tensor<16x8xf32>
7172
}
7273

74+
// CHECK-LABEL: func @reshard_from_sharded_to_fully_replicated_same_meshes
75+
func.func @reshard_from_sharded_to_fully_replicated_same_meshes(%arg0 : tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> tensor<24x8xf32> {
76+
// CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"x"}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]>
77+
// CHECK-NEXT: return %[[ALL_GATHER]]
78+
%0 = sdy.reshard %arg0 <@mesh1d_6, [{}, {}]> : tensor<24x8xf32>
79+
return %0 : tensor<24x8xf32>
80+
}
81+
82+
// CHECK-LABEL: func @reshard_from_sharded_to_fully_replicated_different_meshes
83+
func.func @reshard_from_sharded_to_fully_replicated_different_meshes(%arg0 : tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> tensor<24x8xf32> {
84+
// CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"x"}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]>
85+
// CHECK-NEXT: return %[[ALL_GATHER]]
86+
%0 = sdy.reshard %arg0 <@mesh2d_2x3, [{}, {}]> : tensor<24x8xf32>
87+
return %0 : tensor<24x8xf32>
88+
}
89+
90+
// CHECK-LABEL: func @reshard_from_sharded_to_fully_replicated_different_meshes_with_different_device_counts
91+
func.func @reshard_from_sharded_to_fully_replicated_different_meshes_with_different_device_counts(%arg0 : tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> tensor<24x8xf32> {
92+
// CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"x"}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]>
93+
// CHECK-NEXT: return %[[ALL_GATHER]]
94+
%0 = sdy.reshard %arg0 <@mesh2d, [{}, {}]> : tensor<24x8xf32>
95+
return %0 : tensor<24x8xf32>
96+
}
97+
98+
// CHECK-LABEL: func @reshard_from_sharded_to_sharded_different_meshes
99+
func.func @reshard_from_sharded_to_sharded_different_meshes(%arg0 : tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> (tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d_2x3, [{"x"}, {}]>}) {
100+
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]>
101+
// CHECK-NEXT: return %[[RESHARD]]
102+
%0 = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]> : tensor<24x8xf32>
103+
return %0 : tensor<24x8xf32>
104+
}
105+
73106
// CHECK-LABEL: func @all_gather_single_axis
74107
func.func @all_gather_single_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> {
75108
// CHECK-NEXT: sdy.all_gather [{}, {"x"}] %arg0 out_sharding=<@mesh2d, [{"y"}, {}]>

0 commit comments

Comments
 (0)