@@ -1314,6 +1314,19 @@ class CollectiveInserter {
13141314 AxisToDimAndIndex outAxisToDimAndIndex;
13151315};
13161316
1317+ bool isEquivalentOnMesh (TensorShardingAttr inSharding,
1318+ TensorShardingAttr outSharding, ReshardOp reshardOp) {
1319+ if (isFullyReplicated (inSharding) || isFullyReplicated (outSharding)) {
1320+ return true ;
1321+ }
1322+ if (inSharding.getMeshName () == outSharding.getMeshName ()) {
1323+ return true ;
1324+ }
1325+ MeshAttr inMesh = inSharding.getMesh (reshardOp);
1326+ MeshAttr outMesh = outSharding.getMesh (reshardOp);
1327+ return inMesh.equals (outMesh, /* ignoreDeviceOrder=*/ true );
1328+ }
1329+
13171330class ReshardPattern : public OpConversionPattern <ReshardOp> {
13181331 public:
13191332 using OpConversionPattern::OpConversionPattern;
@@ -1334,20 +1347,19 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
13341347 rewriter.replaceOp (op, adaptor.getInput ());
13351348 return success ();
13361349 }
1337- MeshAttr inMesh = inSharding.getMesh (op);
13381350 if (inSharding.getMeshName () != outSharding.getMeshName ()) {
1339- MeshAttr outMesh = outSharding.getMesh (op);
1340- // TODO(enver): Use MeshAttr::equals method instead .
1341- if (outMesh. getAxes () != inMesh. getAxes () ||
1342- inMesh. getDeviceIds () == outMesh. getDeviceIds () ||
1343- (inSharding. isFullyReplicated () &&
1344- outSharding. isFullyReplicated () )) {
1345- // We currently only support a reshard between different meshes if
1346- // they have the same axes and different device ids, and at least one
1347- // of the sharding isn't fully replicated.
1348- return rewriter.notifyMatchFailure (
1349- op, [](Diagnostic& diag) { diag << " Incompatible meshes" ; });
1350- }
1351+ if ( outSharding.isFullyReplicated ()) {
1352+ // TODO(enver): Hard fail if out sharding has unreduced axes .
1353+ outSharding = TensorShardingAttr::getFullyClosedLike (inSharding);
1354+ }
1355+ } else {
1356+ if (! isEquivalentOnMesh (inSharding, outSharding, op )) {
1357+ // We currently only support a reshard between different meshes if
1358+ // they have the same axes and different device ids, and at least one
1359+ // of the sharding isn't fully replicated.
1360+ return rewriter.notifyMatchFailure (
1361+ op, [](Diagnostic& diag) { diag << " Incompatible meshes" ; });
1362+ }
13511363 }
13521364
13531365 // TODO(tomnatan): we should verify that the operand of ReshardOp has a
@@ -1370,13 +1382,12 @@ struct ReshardToCollectivesPass
13701382 target = std::make_shared<ConversionTarget>(*context);
13711383 target->addLegalOp <AllGatherOp, AllSliceOp, AllToAllOp,
13721384 CollectivePermuteOp>();
1373- if (keepRedundantReshards) {
1374- target->addDynamicallyLegalOp <ReshardOp>([](ReshardOp op) {
1375- return isEquivalent (getSharding (op.getInput ()), op.getSharding ());
1376- });
1377- } else {
1378- target->addIllegalOp <ReshardOp>();
1379- }
1385+ target->addDynamicallyLegalOp <ReshardOp>([&](ReshardOp op) {
1386+ TensorShardingAttr inSharding = getSharding (op.getInput ());
1387+ TensorShardingAttr outSharding = op.getSharding ();
1388+ return (keepRedundantReshards && isEquivalent (inSharding, outSharding)) ||
1389+ !isEquivalentOnMesh (inSharding, outSharding, op);
1390+ });
13801391
13811392 RewritePatternSet patternsInternal (context);
13821393 patternsInternal.add <ReshardPattern>(context);
0 commit comments