@@ -39,10 +39,8 @@ limitations under the License.
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/ir/enums.h"
 #include "shardy/dialect/sdy/ir/utils.h"
-#include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h"
 #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"
 #include "shardy/dialect/sdy/transforms/propagation/utils.h"
-#include "stablehlo/dialect/StablehloOps.h"
 
 namespace mlir {
 namespace sdy {
@@ -127,7 +125,7 @@ void insertExplicitReshardsOnOperand(
               mesh.getContext(), shardingRule.getOperandMapping(operandIndex),
               shardingRule.getFactorSizes(), mesh.name(), mesh.attr());
   auto reshardOp =
-      rewriter.create<ReshardOp>(operand.getLoc(), operand, newTensorSharding);
+      ReshardOp::create(rewriter, operand.getLoc(), operand, newTensorSharding);
   op->setOperand(operandIndex, reshardOp);
 }
133131
@@ -141,8 +139,8 @@ void insertExplicitReshardsOnResult(
           .createTensorShardingAttr(
               mesh.getContext(), shardingRule.getResultMapping(resultIndex),
               shardingRule.getFactorSizes(), mesh.name(), mesh.attr());
-  auto reshardOp = rewriter.create<ReshardOp>(
-      result.getLoc(), result,
+  auto reshardOp = ReshardOp::create(
+      rewriter, result.getLoc(), result,
       getOrCreateSharding(result, mesh.name(), /*closedIfMissing=*/true));
   rewriter.replaceAllUsesExcept(result, reshardOp, reshardOp);
   setSharding(result, newTensorSharding);
@@ -248,8 +246,8 @@ void insertAllReduces(Operation* op, const AxesPerFactor& commonAxesPerFactor,
     if (allReduceAxes.empty()) {
       continue;
     }
-    auto allReduceOp = rewriter.create<AllReduceOp>(
-        result.getLoc(), result, allReduceAxes, resultSharding);
+    auto allReduceOp = AllReduceOp::create(rewriter, result.getLoc(), result,
+                                           allReduceAxes, resultSharding);
     rewriter.replaceAllUsesExcept(result, allReduceOp, allReduceOp);
   }
 }
@@ -439,11 +437,10 @@ class FactorAxesCandidateBag {
     }
     for (const auto& [factorIndex, _] :
          tensorFactorSharding.factorIndexToSharding) {
-      int64_t candidateIndex = 0;
-      while (candidateIndex < size()) {
+      for (int64_t candidateIndex = 0; candidateIndex < size();
+           ++candidateIndex) {
         updateSourceTensorSizeAt(factorIndex, candidateIndex,
                                  localTensorSize);
-        candidateIndex++;
       }
     }
   }
@@ -738,13 +735,6 @@ std::optional<int64_t> findTensorIndexToPreferOnUnaryOperation(
              : lhs;
 }
 
-TensorShardingAttr getShardingOfTensorIndex(
-    const int64_t tensorIndex, ArrayRef<TensorShardingAttr> inShardings,
-    ArrayRef<TensorShardingAttr> outShardings, const int64_t numOperands) {
-  return tensorIndex < numOperands ? inShardings[tensorIndex]
-                                   : outShardings[tensorIndex - numOperands];
-}
-
 // Assumes that:
 // 1. Either tensor does not have factors that need replication.
 // 2. Both tensors have the same mesh but may have different device orders.
@@ -938,8 +928,9 @@ TensorShardingAttr insertAllReduceIfUnreducedToReplicated(
   }
   TensorShardingAttr allReduceSharding =
       sourceSharding.replaceUnreducedAxes(targetUnreducedAxes);
-  auto allReduceOp = rewriter.create<AllReduceOp>(
-      use.get().getLoc(), use.get(), allReduceAxes, allReduceSharding);
+  auto allReduceOp =
+      AllReduceOp::create(rewriter, use.get().getLoc(), use.get(),
+                          allReduceAxes, allReduceSharding);
   use.set(allReduceOp);
   return allReduceSharding;
 }
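
Most hunks above apply one mechanical change: migrating from the `rewriter.create<OpTy>(loc, ...)` builder method to the static `OpTy::create(rewriter, loc, ...)` form on the op class, the direction MLIR upstream is moving in. A minimal sketch of the pattern, for review context only: `buildReshard` is a hypothetical helper, not part of this commit, while `ReshardOp`, its argument list, and the headers are taken from the diff itself.

```cpp
#include "mlir/IR/PatternMatch.h"
#include "shardy/dialect/sdy/ir/dialect.h"

namespace mlir::sdy {
// Hypothetical helper (not from this commit) showing the new spelling
// used throughout: the op class's static `create` takes the builder first.
ReshardOp buildReshard(IRRewriter& rewriter, Value operand,
                       TensorShardingAttr sharding) {
  // Old spelling: rewriter.create<ReshardOp>(operand.getLoc(), operand,
  //                                          sharding);
  return ReshardOp::create(rewriter, operand.getLoc(), operand, sharding);
}
}  // namespace mlir::sdy
```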