Skip to content

Commit 29eec94

Browse files
ZixuanJiang authored and copybara-github committed
Clean up includes in explicit reshards.
PiperOrigin-RevId: 804674568
1 parent c139cc7 commit 29eec94

3 files changed

Lines changed: 17 additions & 25 deletions

File tree

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,8 @@ limitations under the License.
3939
#include "shardy/dialect/sdy/ir/dialect.h"
4040
#include "shardy/dialect/sdy/ir/enums.h"
4141
#include "shardy/dialect/sdy/ir/utils.h"
42-
#include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h"
4342
#include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"
4443
#include "shardy/dialect/sdy/transforms/propagation/utils.h"
45-
#include "stablehlo/dialect/StablehloOps.h"
4644

4745
namespace mlir {
4846
namespace sdy {
@@ -127,7 +125,7 @@ void insertExplicitReshardsOnOperand(
127125
mesh.getContext(), shardingRule.getOperandMapping(operandIndex),
128126
shardingRule.getFactorSizes(), mesh.name(), mesh.attr());
129127
auto reshardOp =
130-
rewriter.create<ReshardOp>(operand.getLoc(), operand, newTensorSharding);
128+
ReshardOp::create(rewriter, operand.getLoc(), operand, newTensorSharding);
131129
op->setOperand(operandIndex, reshardOp);
132130
}
133131

@@ -141,8 +139,8 @@ void insertExplicitReshardsOnResult(
141139
.createTensorShardingAttr(
142140
mesh.getContext(), shardingRule.getResultMapping(resultIndex),
143141
shardingRule.getFactorSizes(), mesh.name(), mesh.attr());
144-
auto reshardOp = rewriter.create<ReshardOp>(
145-
result.getLoc(), result,
142+
auto reshardOp = ReshardOp::create(
143+
rewriter, result.getLoc(), result,
146144
getOrCreateSharding(result, mesh.name(), /*closedIfMissing=*/true));
147145
rewriter.replaceAllUsesExcept(result, reshardOp, reshardOp);
148146
setSharding(result, newTensorSharding);
@@ -248,8 +246,8 @@ void insertAllReduces(Operation* op, const AxesPerFactor& commonAxesPerFactor,
248246
if (allReduceAxes.empty()) {
249247
continue;
250248
}
251-
auto allReduceOp = rewriter.create<AllReduceOp>(
252-
result.getLoc(), result, allReduceAxes, resultSharding);
249+
auto allReduceOp = AllReduceOp::create(rewriter, result.getLoc(), result,
250+
allReduceAxes, resultSharding);
253251
rewriter.replaceAllUsesExcept(result, allReduceOp, allReduceOp);
254252
}
255253
}
@@ -439,11 +437,10 @@ class FactorAxesCandidateBag {
439437
}
440438
for (const auto& [factorIndex, _] :
441439
tensorFactorSharding.factorIndexToSharding) {
442-
int64_t candidateIndex = 0;
443-
while (candidateIndex < size()) {
440+
for (int64_t candidateIndex = 0; candidateIndex < size();
441+
++candidateIndex) {
444442
updateSourceTensorSizeAt(factorIndex, candidateIndex,
445443
localTensorSize);
446-
candidateIndex++;
447444
}
448445
}
449446
}
@@ -738,13 +735,6 @@ std::optional<int64_t> findTensorIndexToPreferOnUnaryOperation(
738735
: lhs;
739736
}
740737

741-
TensorShardingAttr getShardingOfTensorIndex(
742-
const int64_t tensorIndex, ArrayRef<TensorShardingAttr> inShardings,
743-
ArrayRef<TensorShardingAttr> outShardings, const int64_t numOperands) {
744-
return tensorIndex < numOperands ? inShardings[tensorIndex]
745-
: outShardings[tensorIndex - numOperands];
746-
}
747-
748738
// Assumes that:
749739
// 1. Either tensor does not have factors that need replication.
750740
// 2. Both tensors have the same mesh but may have different device orders.
@@ -938,8 +928,9 @@ TensorShardingAttr insertAllReduceIfUnreducedToReplicated(
938928
}
939929
TensorShardingAttr allReduceSharding =
940930
sourceSharding.replaceUnreducedAxes(targetUnreducedAxes);
941-
auto allReduceOp = rewriter.create<AllReduceOp>(
942-
use.get().getLoc(), use.get(), allReduceAxes, allReduceSharding);
931+
auto allReduceOp =
932+
AllReduceOp::create(rewriter, use.get().getLoc(), use.get(),
933+
allReduceAxes, allReduceSharding);
943934
use.set(allReduceOp);
944935
return allReduceSharding;
945936
}

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ limitations under the License.
1717
#define SHARDY_DIALECT_SDY_TRANSFORMS_EXPORT_EXPLICIT_RESHARD_UTIL_H_
1818

1919
#include <cassert>
20-
#include <utility>
20+
#include <cstdint>
21+
#include <optional>
2122

2223
#include "mlir/IR/MLIRContext.h"
2324
#include "mlir/IR/Operation.h"

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ limitations under the License.
2222
#include "llvm/ADT/DenseMap.h"
2323
#include "llvm/ADT/STLExtras.h"
2424
#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep
25-
#include "mlir/Dialect/Func/IR/FuncOps.h"
25+
#include "mlir/IR/BuiltinAttributes.h"
2626
#include "mlir/IR/BuiltinOps.h"
2727
#include "mlir/IR/OpDefinition.h"
2828
#include "mlir/IR/Operation.h"
@@ -71,8 +71,8 @@ void insertExplicitReshardsToTargetSharding(OpOperand& opOperand,
7171

7272
if (onFullVersion && shouldReshard(operandSharding, targetSharding)) {
7373
operand = opOperand.get();
74-
auto reshardOp = rewriter.create<ReshardOp>(
75-
operand.getLoc(), operand,
74+
auto reshardOp = ReshardOp::create(
75+
rewriter, operand.getLoc(), operand,
7676
targetSharding
7777
? targetSharding
7878
// Since it should reshard and `targetSharding` is empty,
@@ -255,8 +255,8 @@ void processDot(OpTy op, ArrayRef<TensorShardingAttr> inShardings,
255255
op.getContext(), shardingRule.getResultMapping(0),
256256
shardingRule.getFactorSizes(), mesh.name(), mesh.attr()));
257257
rewriter.setInsertionPointAfter(op);
258-
auto reshardOp = rewriter.create<ReshardOp>(op.getLoc(), op.getResult(),
259-
outShardings.front());
258+
auto reshardOp = ReshardOp::create(rewriter, op.getLoc(), op.getResult(),
259+
outShardings.front());
260260
rewriter.replaceAllUsesExcept(op.getResult(), reshardOp, reshardOp);
261261
}
262262

0 commit comments

Comments
 (0)