@@ -24,12 +24,16 @@ limitations under the License.
2424#include " mlir/IR/OperationSupport.h"
2525#include " mlir/IR/PatternMatch.h"
2626#include " mlir/IR/SymbolTable.h"
27+ #include " mlir/IR/Value.h"
28+ #include " mlir/IR/ValueRange.h"
2729#include " mlir/Support/LLVM.h"
2830#include " mlir/Transforms/DialectConversion.h"
31+ #include " shardy/common/logging.h"
2932#include " shardy/dialect/sdy/ir/constants.h"
3033#include " shardy/dialect/sdy/ir/dialect.h"
3134#include " shardy/dialect/sdy/ir/utils.h"
3235#include " shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep
36+ #include " shardy/dialect/sdy/transforms/propagation/utils.h"
3337
3438namespace mlir {
3539namespace sdy {
@@ -42,6 +46,18 @@ namespace {
4246using func::CallOp;
4347using func::FuncOp;
4448
49+ void removeDataFlowEdges (ValueRange values, IRRewriter& rewriter) {
50+ for (Value value : values) {
51+ if (value.use_empty ()) {
52+ continue ;
53+ }
54+ if (auto dataFlowEdgeOp = dyn_cast<DataFlowEdgeOp>(*value.user_begin ())) {
55+ SDY_CHECK (value.hasOneUse ());
56+ rewriter.replaceOp (dataFlowEdgeOp, dataFlowEdgeOp.getInput ());
57+ }
58+ }
59+ }
60+
4561struct NamedComputationWithCount {
4662 NamedComputationOp namedComputationOp;
4763 int64_t callSiteCount;
@@ -66,6 +82,8 @@ StringAttr createFuncOp(
6682 inlineRegionAndConvertTerminatorOp<func::ReturnOp>(
6783 namedComputationOp.getBody (), funcOp.getBody ());
6884
85+ removeDataFlowEdges (funcOp.getArguments (), rewriter);
86+
6987 // Copy the input shardings to the func.
7088 if (inShardings.has_value ()) {
7189 for (auto [i, sharding] : llvm::enumerate (inShardings->getShardings ())) {
@@ -95,15 +113,35 @@ TensorShardingPerValueAttr getFullyClosedLike(
95113
96114void exportNamedComputations (ModuleOp moduleOp, SymbolTable& symbolTable) {
97115 Block& moduleBlock = moduleOp.getRegion ().front ();
116+ MLIRContext* ctx = moduleOp.getContext ();
117+ IRRewriter rewriter (moduleOp);
98118
99119 // NOTE: The walk needs to be in post order, which is the default order, to
100120 // account for nested named computations.
121+ SmallVector<Value> callResults;
101122 moduleOp.walk ([&](NamedComputationOp namedComputationOp) {
102- IRRewriter rewriter (namedComputationOp);
103123 rewriter.setInsertionPointToEnd (&moduleBlock);
104124
125+ // Propagate the shardings from the data flow edges to argument shardings.
126+ ArrayRef<BlockArgument> blockArgOwners =
127+ namedComputationOp.getBody ().getArguments ();
128+ if (SmallVector<TensorShardingAttr> blockArgShardings =
129+ getShardingsFromDataFlowEdges (blockArgOwners);
130+ !blockArgShardings.empty ()) {
131+ namedComputationOp.setInShardingsAttr (
132+ TensorShardingPerValueAttr::get (ctx, blockArgShardings));
133+ }
105134 std::optional<TensorShardingPerValueAttr> inShardings =
106135 namedComputationOp.getInShardings ();
136+
137+ // Propagate the shardings from the data flow edges to result shardings.
138+ ResultRange resultOwners = namedComputationOp.getResults ();
139+ if (SmallVector<TensorShardingAttr> resultShardings =
140+ getShardingsFromDataFlowEdges (resultOwners);
141+ !resultShardings.empty ()) {
142+ namedComputationOp.setOutShardingsAttr (
143+ TensorShardingPerValueAttr::get (ctx, resultShardings));
144+ }
107145 std::optional<TensorShardingPerValueAttr> outShardings =
108146 namedComputationOp.getOutShardings ();
109147
@@ -117,6 +155,7 @@ void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
117155 auto callOp = rewriter.replaceOpWithNewOp <CallOp>(
118156 namedComputationOp, namedComputationOp.getResultTypes (), funcSymName,
119157 namedComputationOp.getOperands ());
158+ llvm::append_range (callResults, callOp.getResults ());
120159 callOp->setAttrs (callOpAttrs);
121160 FuncOp funcOp = symbolTable.lookup <FuncOp>(funcSymName);
122161 // Copy the func output shardings to the call op.
@@ -128,6 +167,7 @@ void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
128167 : getFullyClosedLike (funcResultShardings));
129168 }
130169 });
170+ removeDataFlowEdges (callResults, rewriter);
131171}
132172
133173struct ExportNamedComputationsPass
@@ -138,8 +178,8 @@ struct ExportNamedComputationsPass
138178 void runOnOperation () final {
139179 ModuleOp moduleOp = getOperation ();
140180 SymbolTableCollection symbolTableCollection;
141-
142181 SymbolTable& symbolTable = symbolTableCollection.getSymbolTable (moduleOp);
182+
143183 exportNamedComputations (moduleOp, symbolTable);
144184 }
145185};
0 commit comments