@@ -31,6 +31,7 @@ limitations under the License.
3131#include " mlir/IR/Value.h"
3232#include " mlir/Pass/Pass.h" // IWYU pragma: keep
3333#include " mlir/Support/LLVM.h"
34+ #include " shardy/common/logging.h"
3435#include " shardy/dialect/sdy/ir/dialect.h"
3536#include " shardy/dialect/sdy/ir/enums.h"
3637#include " shardy/dialect/sdy/ir/utils.h"
@@ -319,6 +320,34 @@ std::optional<Mesh> getMesh(ArrayRef<TensorShardingAttr> inShardings,
319320 /* defaultMesh=*/ Mesh (meshAttr, *meshName));
320321}
321322
323+ void insertAllReduceOnOpIfUnreducedToReplicated (
324+ Operation* op, IRRewriter& rewriter, const SymbolTable& symbolTable) {
325+ if (op->getResults ().empty ()) {
326+ auto operandHasUnreducedAxes = [&](OpOperand& operand) {
327+ TensorShardingAttr sharding = getSharding (operand.get ());
328+ return sharding && !sharding.getUnreducedAxes ().empty ();
329+ };
330+ SDY_CHECK (!llvm::any_of (op->getOpOperands (), operandHasUnreducedAxes))
331+ << " Some operands has unreduced axes but the operation has no "
332+ " results. " ;
333+ return ;
334+ }
335+
336+ // For each operand that has unreduced axes, insert an all-reduce if
337+ // any of the unreduced axes isn't unreduced in the target sharding.
338+ //
339+ // We assume all results of an op should have the same unreduced axes,
340+ // so we look at the first result.
341+ rewriter.setInsertionPoint (op);
342+ for (OpOperand& operand : op->getOpOperands ()) {
343+ if (TensorShardingAttr inSharding = getSharding (operand.get ())) {
344+ insertAllReduceIfUnreducedToReplicated (operand, inSharding,
345+ getSharding (op->getResult (0 )),
346+ symbolTable, rewriter);
347+ }
348+ }
349+ }
350+
322351struct InsertExplicitReshardsPass
323352 : public impl::InsertExplicitReshardsPassBase<InsertExplicitReshardsPass> {
324353 using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase;
@@ -361,20 +390,7 @@ struct InsertExplicitReshardsPass
361390 return ;
362391 }
363392
364- // For each operand that has unreduced axes, insert an all-reduce if
365- // any of the unreduced axes isn't unreduced in the target sharding.
366- //
367- // We assume all results of an op should have the same unreduced axes,
368- // so we look at the first result.
369- TensorShardingAttr outSharding =
370- op->getResults ().empty () ? nullptr : getSharding (op->getResult (0 ));
371- rewriter.setInsertionPoint (op);
372- for (OpOperand& operand : op->getOpOperands ()) {
373- if (TensorShardingAttr inSharding = getSharding (operand.get ())) {
374- insertAllReduceIfUnreducedToReplicated (
375- operand, inSharding, outSharding, symbolTable, rewriter);
376- }
377- }
393+ insertAllReduceOnOpIfUnreducedToReplicated (op, rewriter, symbolTable);
378394
379395 // NOTE: Creating a sharding rule requires data flow edges are present.
380396 OpShardingRuleAttr shardingRule =
0 commit comments