Skip to content

Commit 8973ec8

Browse files
ekayaaslancopybara-github
authored andcommitted
Push shardy unflattener up past SinkDataFlowEdges pass.
In the process, split out SinkFuncDataFlowEdges from SinkDataFlowEdges. This is because func arguments need to get their shardings from func data flow edges so that the shardy unflattener can properly deduplicate funcs based on their input/output shardings. PiperOrigin-RevId: 908391477
1 parent 03f2b7b commit 8973ec8

14 files changed

Lines changed: 590 additions & 531 deletions

shardy/dialect/sdy/transforms/export/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ cc_library(
4848
"resolve_permutation_factors.cc",
4949
"sharding_constraint_to_reshard.cc",
5050
"sink_data_flow_edges.cc",
51+
"sink_func_data_flow_edges.cc",
5152
"unflatten_call_graph.cc",
5253
"update_non_divisible_input_output_shardings.cc",
5354
],

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,16 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
4040
const ExportOptions& options) {
4141
InsertExplicitReshardsPassOptions passOptions;
4242
passOptions.enableFullVersion = options.enableInsertExplicitCollectives;
43-
passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls;
4443
pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));
4544
if (options.enableInsertExplicitCollectives) {
4645
pm.addPass(mlir::sdy::createSaveModuleOpPass(
4746
options.dumpDirectory, "after_explicit_reshards", dumpIndex++));
4847
addCanonicalizerPass(pm, kReshardLabel);
49-
pm.addPass(createUnflattenCallGraphPass(
50-
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
51-
// Keep a SymbolDCE after UnflattenCallGraph.
52-
pm.addPass(createSymbolDCEPass());
5348
pm.addNestedPass<func::FuncOp>(createReshardToCollectivesPass());
5449
// NOTE: ReshardToCollectives pass above generates all-slice collectives,
5550
// which during the canonicalizer below may be converted to reduce scatters
5651
// by potentially fusing with preceding all-reduces, which are inserted
5752
// during InsertExplicitReshards pass.
58-
} else {
59-
pm.addPass(createUnflattenCallGraphPass(
60-
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
61-
// Keep a SymbolDCE after UnflattenCallGraph.
62-
pm.addPass(createSymbolDCEPass());
6353
}
6454

6555
addCanonicalizerPass(pm, kCollectiveLabel);
@@ -82,16 +72,22 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
8272
void addExportPipeline(OpPassManager& pm, int& dumpIndex,
8373
const ExportOptions& options) {
8474
pm.addNestedPass<func::FuncOp>(createConstantOrScalarMergerPass());
85-
if (!options.avoidExportForPartitioning) {
86-
pm.addPass(createRemoveShardingGroupsPass());
87-
pm.addNestedPass<func::FuncOp>(createShardingConstraintToReshardPass());
88-
}
75+
pm.addNestedPass<func::FuncOp>(createSinkFuncDataFlowEdgesPass());
76+
pm.addPass(createUnflattenCallGraphPass(
77+
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
78+
// Keep a SymbolDCE after UnflattenCallGraph.
79+
pm.addPass(createSymbolDCEPass());
8980
pm.addNestedPass<func::FuncOp>(
9081
createSinkDataFlowEdgesPass(SinkDataFlowEdgesPassOptions{
9182
/*sinkDebugShardingOrigins=*/options.dumpShardingOrigins,
9283
/*sinkDebugPropagationEdgeSharding=*/options.dumpPropagationEdges}));
84+
if (!options.avoidExportForPartitioning) {
85+
pm.addPass(createRemoveShardingGroupsPass());
86+
pm.addNestedPass<func::FuncOp>(createShardingConstraintToReshardPass());
87+
}
9388
if (options.updateNonDivisibleInputOutputShardings) {
9489
pm.addPass(createUpdateNonDivisibleInputOutputShardingsPass());
90+
9591
pm.addPass(createRemoveSubAxesInInputOutputShardingsPass());
9692
}
9793
pm.addPass(createCloseShardingsPass());
@@ -106,11 +102,6 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
106102
// reshards/collectives.
107103
if (!options.avoidExportForPartitioning) {
108104
runShardyPartitioner(pm, dumpIndex, options);
109-
} else {
110-
pm.addPass(createUnflattenCallGraphPass(
111-
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
112-
// Keep a SymbolDCE after UnflattenCallGraph.
113-
pm.addPass(createSymbolDCEPass());
114105
}
115106
if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
116107
pm.addPass(createRemovePropagationDebugInfoPass());

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,10 +495,8 @@ struct InsertExplicitReshardsPass
495495
}
496496

497497
if (CallOp callOp = dyn_cast<CallOp>(op)) {
498-
if (!avoidReshardsOnCalls) {
499498
insertExplicitReshardsOnCallOp(callOp, rewriter, symbolTable,
500499
onFullVersion);
501-
}
502500
return;
503501
}
504502

shardy/dialect/sdy/transforms/export/passes.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def SinkDataFlowEdgesPass : Pass<"sdy-sink-data-flow-edges", "func::FuncOp"> {
4343
];
4444
}
4545

46+
def SinkFuncDataFlowEdgesPass : Pass<"sdy-sink-func-data-flow-edges", "func::FuncOp"> {
  let summary = "Sinks every `FuncDataFlowEdgeOp` into its input.";
  let description = [{
    Moves the sharding of each `FuncDataFlowEdgeOp` to its input and replaces
    the op with its input.
  }];
  let dependentDialects = ["mlir::sdy::SdyDialect"];
}
54+
4655
def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible-input-output-shardings", "ModuleOp"> {
4756
let summary = "Makes FuncOp inputs/outputs evenly sharded, removing any need for padding due to non-divisible shardings.";
4857
let description = [{
@@ -123,10 +132,6 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun
123132
Option<"enableFullVersion", "enable-full-version",
124133
"bool", /*default=*/"false",
125134
"Enable full version.">,
126-
Option<"avoidReshardsOnCalls",
127-
"avoid-reshards-on-calls",
128-
"bool", /*default=*/"false",
129-
"Avoid explicit reshards/collectives on calls.">
130135
];
131136
}
132137

shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,6 @@ struct SinkDataFlowEdgesPass
6767
rewriter.replaceOp(dataFlowEdgeOp, input);
6868
return WalkResult::skip();
6969
}
70-
if (isa<FuncDataFlowEdgeOp>(op)) {
71-
FuncDataFlowEdgeOp funcEdgeOp = cast<FuncDataFlowEdgeOp>(op);
72-
Value operand = funcEdgeOp.getOperand();
73-
Value result = funcEdgeOp.getResult();
74-
TensorShardingAttr operandSharding = getSharding(operand);
75-
if (TensorShardingAttr sharding = getSharding(result)) {
76-
setSharding(operand, sharding);
77-
} else if (operandSharding) {
78-
setSharding(operand,
79-
TensorShardingAttr::getFullyOpenLike(operandSharding));
80-
}
81-
rewriter.replaceOp(funcEdgeOp, operand);
82-
return WalkResult::skip();
83-
}
8470
auto shardableDataFlowOp = dyn_cast<ShardableDataFlowOpInterface>(op);
8571
if (!shardableDataFlowOp) {
8672
return WalkResult::advance();
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/* Copyright 2024 The Shardy Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <cassert>
17+
#include <memory> // IWYU pragma: keep
18+
#include <tuple>
19+
#include <utility>
20+
21+
#include "llvm/ADT/STLExtras.h"
22+
#include "llvm/ADT/SmallVector.h"
23+
#include "mlir/Dialect/Func/IR/FuncOps.h"
24+
#include "mlir/IR/Attributes.h"
25+
#include "mlir/IR/BuiltinAttributes.h"
26+
#include "mlir/IR/Operation.h"
27+
#include "mlir/IR/PatternMatch.h"
28+
#include "mlir/IR/Value.h"
29+
#include "mlir/IR/ValueRange.h"
30+
#include "mlir/IR/Visitors.h"
31+
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
32+
#include "mlir/Support/LLVM.h"
33+
#include "shardy/dialect/sdy/ir/constants.h"
34+
#include "shardy/dialect/sdy/ir/dialect.h"
35+
#include "shardy/dialect/sdy/ir/utils.h"
36+
#include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep
37+
#include "shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h"
38+
#include "shardy/dialect/sdy/transforms/propagation/utils.h"
39+
40+
namespace mlir {
41+
namespace sdy {
42+
43+
#define GEN_PASS_DEF_SINKFUNCDATAFLOWEDGESPASS
44+
#include "shardy/dialect/sdy/transforms/export/passes.h.inc"
45+
46+
namespace {
47+
48+
struct SinkFuncDataFlowEdgesPass
49+
: public impl::SinkFuncDataFlowEdgesPassBase<SinkFuncDataFlowEdgesPass> {
50+
using SinkFuncDataFlowEdgesPassBase::SinkFuncDataFlowEdgesPassBase;
51+
52+
void runOnOperation() final {
53+
func::FuncOp funcOp = getOperation();
54+
IRRewriter rewriter(funcOp);
55+
// Copy the sharding from data flow edges to the data flow ops.
56+
funcOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
57+
// Since we are doing the walk in preorder with a forward iterator, ops
58+
// are walked before their users and regions. Since `DataFlowEdgeOp` can
59+
// only appear inside the data flow op's region or as its user, we always
60+
// encounter the data flow op before their data flow edges. This means it
61+
// is safe to erase the `FuncDataFlowEdgeOp` at this point. We need the
62+
// skip at the end because it's a condition to erase the op. See the
63+
// documentation for `Operation::walk` for more details.
64+
if (isa<FuncDataFlowEdgeOp>(op)) {
65+
FuncDataFlowEdgeOp funcEdgeOp = cast<FuncDataFlowEdgeOp>(op);
66+
Value operand = funcEdgeOp.getOperand();
67+
Value result = funcEdgeOp.getResult();
68+
TensorShardingAttr operandSharding = getSharding(operand);
69+
if (TensorShardingAttr sharding = getSharding(result)) {
70+
setSharding(operand, sharding);
71+
} else if (operandSharding) {
72+
setSharding(operand,
73+
TensorShardingAttr::getFullyOpenLike(operandSharding));
74+
}
75+
rewriter.replaceOp(funcEdgeOp, operand);
76+
return WalkResult::skip();
77+
}
78+
return WalkResult::advance();
79+
});
80+
}
81+
};
82+
83+
} // namespace
84+
85+
} // namespace sdy
86+
} // namespace mlir

shardy/dialect/sdy/transforms/export/test/call_ops_avoid_reshards_on_calls_true.mlir

Lines changed: 0 additions & 126 deletions
This file was deleted.

0 commit comments

Comments
 (0)