Skip to content

Commit 0973248

Browse files
Reverts c139cc7
PiperOrigin-RevId: 804823355
1 parent 29eec94 commit 0973248

2 files changed

Lines changed: 31 additions & 14 deletions

File tree

shardy/dialect/sdy/transforms/export/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ cc_library(
4949
":explicit_reshards_util",
5050
":passes_inc",
5151
"//shardy/common:file_utils",
52+
"//shardy/common:logging",
5253
"//shardy/dialect/sdy/ir:axis_list_ref",
5354
"//shardy/dialect/sdy/ir:dialect",
5455
"//shardy/dialect/sdy/transforms/common:op_properties",

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
#include "mlir/IR/Value.h"
3232
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
3333
#include "mlir/Support/LLVM.h"
34+
#include "shardy/common/logging.h"
3435
#include "shardy/dialect/sdy/ir/dialect.h"
3536
#include "shardy/dialect/sdy/ir/enums.h"
3637
#include "shardy/dialect/sdy/ir/utils.h"
@@ -319,6 +320,34 @@ std::optional<Mesh> getMesh(ArrayRef<TensorShardingAttr> inShardings,
319320
/*defaultMesh=*/Mesh(meshAttr, *meshName));
320321
}
321322

323+
void insertAllReduceOnOpIfUnreducedToReplicated(
324+
Operation* op, IRRewriter& rewriter, const SymbolTable& symbolTable) {
325+
if (op->getResults().empty()) {
326+
auto operandHasUnreducedAxes = [&](OpOperand& operand) {
327+
TensorShardingAttr sharding = getSharding(operand.get());
328+
return sharding && !sharding.getUnreducedAxes().empty();
329+
};
330+
SDY_CHECK(!llvm::any_of(op->getOpOperands(), operandHasUnreducedAxes))
331+
<< "Some operands has unreduced axes but the operation has no "
332+
"results. ";
333+
return;
334+
}
335+
336+
// For each operand that has unreduced axes, insert an all-reduce if
337+
// any of the unreduced axes isn't unreduced in the target sharding.
338+
//
339+
// We assume all results of an op should have the same unreduced axes,
340+
// so we look at the first result.
341+
rewriter.setInsertionPoint(op);
342+
for (OpOperand& operand : op->getOpOperands()) {
343+
if (TensorShardingAttr inSharding = getSharding(operand.get())) {
344+
insertAllReduceIfUnreducedToReplicated(operand, inSharding,
345+
getSharding(op->getResult(0)),
346+
symbolTable, rewriter);
347+
}
348+
}
349+
}
350+
322351
struct InsertExplicitReshardsPass
323352
: public impl::InsertExplicitReshardsPassBase<InsertExplicitReshardsPass> {
324353
using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase;
@@ -361,20 +390,7 @@ struct InsertExplicitReshardsPass
361390
return;
362391
}
363392

364-
// For each operand that has unreduced axes, insert an all-reduce if
365-
// any of the unreduced axes isn't unreduced in the target sharding.
366-
//
367-
// We assume all results of an op should have the same unreduced axes,
368-
// so we look at the first result.
369-
TensorShardingAttr outSharding =
370-
op->getResults().empty() ? nullptr : getSharding(op->getResult(0));
371-
rewriter.setInsertionPoint(op);
372-
for (OpOperand& operand : op->getOpOperands()) {
373-
if (TensorShardingAttr inSharding = getSharding(operand.get())) {
374-
insertAllReduceIfUnreducedToReplicated(
375-
operand, inSharding, outSharding, symbolTable, rewriter);
376-
}
377-
}
393+
insertAllReduceOnOpIfUnreducedToReplicated(op, rewriter, symbolTable);
378394

379395
// NOTE: Creating a sharding rule requires data flow edges are present.
380396
OpShardingRuleAttr shardingRule =

0 commit comments

Comments
 (0)