Skip to content

Commit ad97da3

Browse files
Hard fail for mesh mismatches also when target unreduced axes is empty as it will add an all-reduce.
PiperOrigin-RevId: 805827132
1 parent b258a07 commit ad97da3

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -916,8 +916,9 @@ TensorShardingAttr insertAllReduceIfUnreducedToReplicated(
916916
if (userSharding) {
917917
targetUnreducedAxes = userSharding.getUnreducedAxes();
918918
// TODO(enver): Support the case the meshes differ only on device orders.
919-
SDY_CHECK(targetUnreducedAxes.empty() ||
920-
mesh.equals(userSharding.getMesh(symbolTable)))
919+
// NOTE: At this point, it is guaranteed that source unreduced axes is
920+
// non-empty.
921+
SDY_CHECK(mesh.equals(userSharding.getMesh(symbolTable)))
921922
<< "source and user shardings have different meshes for unreduced "
922923
"axes.";
923924
}

0 commit comments

Comments
 (0)