Skip to content

Commit 2403602

Browse files
ekayaaslan authored and copybara-github committed
Push shardy outliner up past ConstantOrScalarMerger pass.
ConstantOrScalarMerger combines constants and constant-likes that are sharded the same way. It is not critical for final performance, but it is important for keeping the graph smaller. Hence, unlike ConstantOrScalarSplitter, it does not take NamedComputations/Funcs/Calls as parts of constant expressions.

PiperOrigin-RevId: 900633218
1 parent 356dda9 commit 2403602

27 files changed

+701
-252
lines changed

shardy/dialect/sdy/ir/constants.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ inline const std::string kEmptyMeshSymbol = "empty_mesh";
8686
// Attribute name for the original name of the func before flattening.
8787
inline constexpr llvm::StringRef kOriginalFuncName = "sdy.original_func_name";
8888

89+
// Attribute name of the main func.
90+
inline constexpr llvm::StringRef kMainFuncName = "main";
91+
8992
} // namespace sdy
9093
} // namespace mlir
9194

shardy/dialect/sdy/ir/utils.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,5 +1115,32 @@ FuncOp getFuncOpOrDie(StringRef funcSymName, const SymbolTable& symbolTable) {
11151115
return funcOp;
11161116
}
11171117

1118+
TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
1119+
Attribute meshOrRef) {
1120+
SmallVector<TensorShardingAttr> resultShardings;
1121+
resultShardings.reserve(values.size());
1122+
for (mlir::Value value : values) {
1123+
resultShardings.push_back(TensorShardingAttr::getFullyReplicated(
1124+
meshOrRef.getContext(), mlir::sdy::getTensorRank(value), meshOrRef,
1125+
/*isClosed=*/true));
1126+
}
1127+
return TensorShardingPerValueAttr::get(meshOrRef.getContext(),
1128+
resultShardings);
1129+
}
1130+
1131+
// Returns the main func. Dies if there is no main func.
1132+
FuncOp getMainFuncOrDie(ModuleOp moduleOp, SymbolTable& symbolTable,
1133+
bool useTheOneIfSingleFunc) {
1134+
if (useTheOneIfSingleFunc) {
1135+
auto funcOps = moduleOp.getOps<FuncOp>();
1136+
if (std::distance(funcOps.begin(), funcOps.end()) == 1) {
1137+
return *funcOps.begin();
1138+
}
1139+
}
1140+
FuncOp funcOp = symbolTable.lookup<FuncOp>(kMainFuncName);
1141+
SDY_CHECK(funcOp) << "Failed to lookup function: " << kMainFuncName.str();
1142+
return funcOp;
1143+
}
1144+
11181145
} // namespace sdy
11191146
} // namespace mlir

shardy/dialect/sdy/ir/utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,18 @@ mlir::func::FuncOp cloneFuncRecursively(func::FuncOp funcOp,
690690
func::FuncOp getFuncOpOrDie(StringRef funcSymName,
691691
const SymbolTable& symbolTable);
692692

693+
// Returns a `TensorShardingPerValueAttr` on the shardings of the `values`. If
694+
// the sharding of a value is null, it creates a fully closed sharding for it on
695+
// the given `meshOrRef` and the rank of the tensor corresponding to the value.
696+
TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
697+
Attribute meshOrRef);
698+
699+
// Returns the main func. Dies if there is no main func. If
700+
// `useTheOneIfSingleFunc` is true, then first check if the module has only one
701+
// func, and assume it as the main func. Useful for tests.
702+
mlir::func::FuncOp getMainFuncOrDie(ModuleOp moduleOp, SymbolTable& symbolTable,
703+
bool useTheOneIfSingleFunc = false);
704+
693705
} // namespace sdy
694706
} // namespace mlir
695707

shardy/dialect/sdy/ir/utils_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ TEST_F(UtilsTest,
173173
std::nullopt);
174174
}
175175

176+
176177
TEST_F(UtilsTest, GetCommonMeshName_AllEmptyShardings) {
177178
EXPECT_EQ(getCommonMeshName(TensorShardingAttr(), TensorShardingAttr(),
178179
getSymbolTable(), /*ignoreDeviceIds=*/false),

shardy/dialect/sdy/transforms/common/propagation_options.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ struct PropagationOptions {
4848
// auto-partitioner will be invoked after propagation of user-specified
4949
// shardings.
5050
bool enableAutoPartitioning = false;
51-
// Whether to avoid explicit reshards/collectives on named computations.
51+
// Whether to avoid explicit reshards/collectives on named computations/calls.
52+
// TODO(enver): Rename to avoidReshardsOnCalls.
5253
bool avoidReshardsOnNamedComputations = false;
5354
// Whether to update axes with non-divisible input/output shardings.
5455
bool updateNonDivisibleInputOutputShardings = true;

shardy/dialect/sdy/transforms/export/export_named_computations.cc

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,16 @@ limitations under the License.
2424
#include "mlir/IR/OperationSupport.h"
2525
#include "mlir/IR/PatternMatch.h"
2626
#include "mlir/IR/SymbolTable.h"
27+
#include "mlir/IR/Value.h"
28+
#include "mlir/IR/ValueRange.h"
2729
#include "mlir/Support/LLVM.h"
2830
#include "mlir/Transforms/DialectConversion.h"
31+
#include "shardy/common/logging.h"
2932
#include "shardy/dialect/sdy/ir/constants.h"
3033
#include "shardy/dialect/sdy/ir/dialect.h"
3134
#include "shardy/dialect/sdy/ir/utils.h"
3235
#include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep
36+
#include "shardy/dialect/sdy/transforms/propagation/utils.h"
3337

3438
namespace mlir {
3539
namespace sdy {
@@ -42,6 +46,18 @@ namespace {
4246
using func::CallOp;
4347
using func::FuncOp;
4448

49+
void removeDataFlowEdges(ValueRange values, IRRewriter& rewriter) {
50+
for (Value value : values) {
51+
if (value.use_empty()) {
52+
continue;
53+
}
54+
if (auto dataFlowEdgeOp = dyn_cast<DataFlowEdgeOp>(*value.user_begin())) {
55+
SDY_CHECK(value.hasOneUse());
56+
rewriter.replaceOp(dataFlowEdgeOp, dataFlowEdgeOp.getInput());
57+
}
58+
}
59+
}
60+
4561
struct NamedComputationWithCount {
4662
NamedComputationOp namedComputationOp;
4763
int64_t callSiteCount;
@@ -66,6 +82,8 @@ StringAttr createFuncOp(
6682
inlineRegionAndConvertTerminatorOp<func::ReturnOp>(
6783
namedComputationOp.getBody(), funcOp.getBody());
6884

85+
removeDataFlowEdges(funcOp.getArguments(), rewriter);
86+
6987
// Copy the input shardings to the func.
7088
if (inShardings.has_value()) {
7189
for (auto [i, sharding] : llvm::enumerate(inShardings->getShardings())) {
@@ -95,15 +113,35 @@ TensorShardingPerValueAttr getFullyClosedLike(
95113

96114
void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
97115
Block& moduleBlock = moduleOp.getRegion().front();
116+
MLIRContext* ctx = moduleOp.getContext();
117+
IRRewriter rewriter(moduleOp);
98118

99119
// NOTE: The walk needs to be in post order, which is the default order, to
100120
// account for nested named computations.
121+
SmallVector<Value> callResults;
101122
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
102-
IRRewriter rewriter(namedComputationOp);
103123
rewriter.setInsertionPointToEnd(&moduleBlock);
104124

125+
// Propagate the shardings from the data flow edges to argument shardings.
126+
ArrayRef<BlockArgument> blockArgOwners =
127+
namedComputationOp.getBody().getArguments();
128+
if (SmallVector<TensorShardingAttr> blockArgShardings =
129+
getShardingsFromDataFlowEdges(blockArgOwners);
130+
!blockArgShardings.empty()) {
131+
namedComputationOp.setInShardingsAttr(
132+
TensorShardingPerValueAttr::get(ctx, blockArgShardings));
133+
}
105134
std::optional<TensorShardingPerValueAttr> inShardings =
106135
namedComputationOp.getInShardings();
136+
137+
// Propagate the shardings from the data flow edges to result shardings.
138+
ResultRange resultOwners = namedComputationOp.getResults();
139+
if (SmallVector<TensorShardingAttr> resultShardings =
140+
getShardingsFromDataFlowEdges(resultOwners);
141+
!resultShardings.empty()) {
142+
namedComputationOp.setOutShardingsAttr(
143+
TensorShardingPerValueAttr::get(ctx, resultShardings));
144+
}
107145
std::optional<TensorShardingPerValueAttr> outShardings =
108146
namedComputationOp.getOutShardings();
109147

@@ -117,6 +155,7 @@ void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
117155
auto callOp = rewriter.replaceOpWithNewOp<CallOp>(
118156
namedComputationOp, namedComputationOp.getResultTypes(), funcSymName,
119157
namedComputationOp.getOperands());
158+
llvm::append_range(callResults, callOp.getResults());
120159
callOp->setAttrs(callOpAttrs);
121160
FuncOp funcOp = symbolTable.lookup<FuncOp>(funcSymName);
122161
// Copy the func output shardings to the call op.
@@ -128,6 +167,7 @@ void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
128167
: getFullyClosedLike(funcResultShardings));
129168
}
130169
});
170+
removeDataFlowEdges(callResults, rewriter);
131171
}
132172

133173
struct ExportNamedComputationsPass
@@ -138,8 +178,8 @@ struct ExportNamedComputationsPass
138178
void runOnOperation() final {
139179
ModuleOp moduleOp = getOperation();
140180
SymbolTableCollection symbolTableCollection;
141-
142181
SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp);
182+
143183
exportNamedComputations(moduleOp, symbolTable);
144184
}
145185
};

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,8 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
4040
const ExportOptions& options) {
4141
InsertExplicitReshardsPassOptions passOptions;
4242
passOptions.enableFullVersion = options.enableInsertExplicitCollectives;
43-
passOptions.avoidReshardsOnNamedComputations =
44-
options.avoidReshardsOnNamedComputations;
43+
passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls;
4544
pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));
46-
pm.addPass(createExportNamedComputationsPass());
4745
if (options.enableInsertExplicitCollectives) {
4846
pm.addPass(mlir::sdy::createSaveModuleOpPass(
4947
options.dumpDirectory, "after_explicit_reshards", dumpIndex++));
@@ -72,6 +70,7 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
7270

7371
void addExportPipeline(OpPassManager& pm, int& dumpIndex,
7472
const ExportOptions& options) {
73+
pm.addPass(createExportNamedComputationsPass());
7574
pm.addNestedPass<func::FuncOp>(createConstantOrScalarMergerPass());
7675
if (!options.avoidExportForPartitioning) {
7776
pm.addPass(createRemoveShardingGroupsPass());
@@ -98,8 +97,6 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
9897
// reshards/collectives.
9998
if (!options.avoidExportForPartitioning) {
10099
runShardyPartitioner(pm, dumpIndex, options);
101-
} else {
102-
pm.addPass(createExportNamedComputationsPass());
103100
}
104101
if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
105102
pm.addPass(createRemovePropagationDebugInfoPass());

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ namespace sdy {
5151

5252
namespace {
5353

54+
using func::CallOp;
55+
using func::FuncOp;
56+
5457
void insertExplicitReshardsToTargetSharding(OpOperand& opOperand,
5558
TensorShardingAttr targetSharding,
5659
IRRewriter& rewriter,
@@ -102,22 +105,10 @@ void insertExplicitReshardsOnFuncReturn(Operation* op, func::FuncOp& funcOp,
102105
}
103106
}
104107

105-
void insertExplicitReshardsOnDataFlowOp(
106-
ShardableDataFlowOpInterface& op, IRRewriter& rewriter,
107-
const SymbolTable& symbolTable, const bool onFullVersion,
108-
const bool avoidReshardsOnNamedComputations) {
109-
if (isa<NamedComputationOp>(op) && avoidReshardsOnNamedComputations) {
110-
for (Value owner : op.getOpResultEdgeOwners()) {
111-
for (OpOperand* sourceOpOperand : op.getEdgeSources(owner)) {
112-
insertExplicitReshardsToTargetSharding(
113-
*sourceOpOperand,
114-
/*targetSharding=*/op.getEdgeOwnerSharding(owner), rewriter,
115-
symbolTable,
116-
/*insertAfterOperand=*/true, onFullVersion);
117-
}
118-
}
119-
return;
120-
}
108+
void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
109+
IRRewriter& rewriter,
110+
const SymbolTable& symbolTable,
111+
const bool onFullVersion) {
121112
for (Value owner : llvm::concat<Value>(op.getOpResultEdgeOwners(),
122113
op.getBlockArgumentEdgeOwners())) {
123114
TensorShardingAttr ownerSharding = op.transformTargetSharding(
@@ -132,6 +123,33 @@ void insertExplicitReshardsOnDataFlowOp(
132123
}
133124
}
134125

126+
void insertExplicitReshardsOnCallOp(CallOp callOp, IRRewriter& rewriter,
127+
const SymbolTable& symbolTable,
128+
const bool onFullVersion) {
129+
FuncOp funcOp = symbolTable.lookup<FuncOp>(callOp.getCallee());
130+
TensorShardingPerValueAttr funcArgShardings =
131+
mlir::sdy::getFuncArgShardings(funcOp, symbolTable);
132+
if (!funcArgShardings) {
133+
mlir::Attribute meshOrRef = getMeshOrRef(
134+
callOp.getNumOperands(), symbolTable,
135+
[&](int64_t i) { return getSharding(callOp.getOperand(i)); });
136+
// Return without inserting reshards as neither func arguments nor call
137+
// operands have a sharding with non-maximal mesh.
138+
if (!meshOrRef) {
139+
return;
140+
}
141+
funcArgShardings = getFullyClosedLike(callOp.getOperands(), meshOrRef);
142+
}
143+
rewriter.setInsertionPoint(callOp);
144+
for (auto [funcArgSharding, sourceOpOperand] : llvm::zip_equal(
145+
funcArgShardings.getShardings(), callOp->getOpOperands())) {
146+
insertExplicitReshardsToTargetSharding(
147+
sourceOpOperand,
148+
/*targetSharding=*/funcArgSharding, rewriter, symbolTable,
149+
/*insertAfterOperand=*/true, onFullVersion);
150+
}
151+
}
152+
135153
// Reshard the result of a dot operation if all the following hold:
136154
//
137155
// 1. LHS and RHS have fully compatible shardings.
@@ -382,7 +400,7 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
382400
}
383401
// To avoid copies of the same functions with mismatching shardings on the
384402
// arguments onto multiple callsites.
385-
if (isa<NamedComputationOp>(op)) {
403+
if (isa<func::CallOp>(op)) {
386404
return true;
387405
}
388406

@@ -472,8 +490,15 @@ struct InsertExplicitReshardsPass
472490
// TODO(enver): Prefer resharding the owner when multiple sources are
473491
// sharded in the same way.
474492
insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter,
475-
symbolTable, onFullVersion,
476-
avoidReshardsOnNamedComputations);
493+
symbolTable, onFullVersion);
494+
return;
495+
}
496+
497+
if (CallOp callOp = dyn_cast<CallOp>(op)) {
498+
if (!avoidReshardsOnCalls) {
499+
insertExplicitReshardsOnCallOp(callOp, rewriter, symbolTable,
500+
onFullVersion);
501+
}
477502
return;
478503
}
479504

shardy/dialect/sdy/transforms/export/passes.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ struct ExportOptions : public PassPipelineOptions<ExportOptions> {
7676
llvm::cl::desc("Sink sdy.propagation_edges attr."),
7777
llvm::cl::init(false)};
7878

79-
Option<bool> avoidReshardsOnNamedComputations{
80-
*this, "avoid-reshards-on-named-computations",
81-
llvm::cl::desc("Avoid inserting explicit reshards/collectives for named "
82-
"computations."),
79+
Option<bool> avoidReshardsOnCalls{
80+
*this, "avoid-reshards-on-calls",
81+
llvm::cl::desc(
82+
"Avoid inserting explicit reshards/collectives for calls."),
8383
llvm::cl::init(false)};
8484

8585
Option<bool> updateNonDivisibleInputOutputShardings{

shardy/dialect/sdy/transforms/export/passes.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun
128128
Option<"enableFullVersion", "enable-full-version",
129129
"bool", /*default=*/"false",
130130
"Enable full version.">,
131-
Option<"avoidReshardsOnNamedComputations",
132-
"avoid-reshards-on-named-computations",
131+
Option<"avoidReshardsOnCalls",
132+
"avoid-reshards-on-calls",
133133
"bool", /*default=*/"false",
134-
"Avoid explicit reshards/collectives on named computations.">
134+
"Avoid explicit reshards/collectives on calls.">
135135
];
136136
}
137137

0 commit comments

Comments (0)