Skip to content

Commit a6b9848

Browse files
ekayaaslancopybara-github
authored andcommitted
Add a pass to reorder functions in pre-order of the call graph.
This change introduces a new pass to reorder the functions within an MLIR module. The functions are arranged such that callers appear before their callees, following a pre-order traversal of the call graph. This ordering keeps the initial population of the worklist the same on the greedy rewriter driver during the propagation between inlined (named computations) and non-inlined (func/calls). PiperOrigin-RevId: 901743552
1 parent 58c5fcb commit a6b9848

7 files changed

Lines changed: 108 additions & 9 deletions

File tree

shardy/dialect/sdy/ir/utils.cc

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,16 +1012,29 @@ bool walkCalls(ModuleOp moduleOp, ProcessCallOpFn processCallOp,
10121012
return true;
10131013
}
10141014

1015-
void iterateFuncs(ModuleOp moduleOp, ProcessFuncOpFn processFuncOp) {
1015+
void iterateFuncs(ModuleOp moduleOp, ProcessFuncOpFn processFuncOp,
1016+
bool preOrder) {
10161017
CallGraph callGraph(moduleOp);
10171018
llvm::ReversePostOrderTraversal<const CallGraph*> rpo(&callGraph);
1018-
for (CallGraphNode* node : llvm::reverse(rpo)) {
1019-
if (node->isExternal()) {
1020-
continue;
1019+
if (preOrder) {
1020+
for (CallGraphNode* node : rpo) {
1021+
if (node->isExternal()) {
1022+
continue;
1023+
}
1024+
mlir::Region* region = node->getCallableRegion();
1025+
if (FuncOp funcOp = dyn_cast_or_null<FuncOp>(region->getParentOp())) {
1026+
processFuncOp(funcOp);
1027+
}
10211028
}
1022-
mlir::Region* region = node->getCallableRegion();
1023-
if (FuncOp funcOp = dyn_cast_or_null<FuncOp>(region->getParentOp())) {
1024-
processFuncOp(funcOp);
1029+
} else {
1030+
for (CallGraphNode* node : llvm::reverse(rpo)) {
1031+
if (node->isExternal()) {
1032+
continue;
1033+
}
1034+
mlir::Region* region = node->getCallableRegion();
1035+
if (FuncOp funcOp = dyn_cast_or_null<FuncOp>(region->getParentOp())) {
1036+
processFuncOp(funcOp);
1037+
}
10251038
}
10261039
}
10271040
}

shardy/dialect/sdy/ir/utils.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -671,9 +671,11 @@ bool walkCalls(ModuleOp moduleOp, ProcessCallOpFn processCallOp,
671671
// Iterates on the funcs and performs `processFuncOp` on funcs. Iterates on the
672672
// funcs and blocks in post order of the call graph by default, that is, the
673673
// functions are processed before their callers, and child blocks are processed
674-
// before their parents.
674+
// before their parents. Iterates funcs and blocks in pre order if `preOrder` is
675+
// true.
675676
using ProcessFuncOpFn = std::function<void(func::FuncOp)>;
676-
void iterateFuncs(ModuleOp moduleOp, ProcessFuncOpFn processFuncOp);
677+
void iterateFuncs(ModuleOp moduleOp, ProcessFuncOpFn processFuncOp,
678+
bool preOrder = false);
677679

678680
// Returns the reduction operation used in the scatter's update computation if
679681
// it is a recognized associative and commutative binary op applied to all

shardy/dialect/sdy/transforms/import/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ cc_library(
4242
"inline_meshes.cc",
4343
"lift_inlined_meshes.cc",
4444
"manual_axes_cleanup.cc",
45+
"pre_order_funcs.cc",
4546
"propagate_sharding_from_func_to_call.cc",
4647
"remove_size_one_axes.cc",
4748
"sharding_group_import.cc",

shardy/dialect/sdy/transforms/import/passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,16 @@ def ConstantOrScalarSplitterPass : Pass<"sdy-constant-or-scalar-splitter", "Modu
197197
let dependentDialects = ["mlir::sdy::SdyDialect"];
198198
}
199199

200+
def PreOrderFuncsPass : Pass<"sdy-pre-order-funcs", "ModuleOp"> {
201+
let summary = "Reorders functions in the module in pre-order of the call graph.";
202+
let description = [{
203+
Reorders functions in the module in pre-order of the call graph.
204+
This is useful when we inline func/calls, as propagation is top-down on blocks,
205+
and a pre-order iteration on funcs will emulate it.
206+
}];
207+
let dependentDialects = ["mlir::sdy::SdyDialect"];
208+
}
209+
200210
def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> {
201211
let summary = "Canonicalization and validation pass for sharding groups.";
202212
let description = [{
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright 2026 The Shardy Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <vector>
17+
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/IR/Block.h"
20+
#include "mlir/IR/BuiltinOps.h"
21+
#include "shardy/dialect/sdy/ir/utils.h"
22+
#include "shardy/dialect/sdy/transforms/import/passes.h" // IWYU pragma: keep
23+
24+
namespace mlir {
25+
namespace sdy {
26+
27+
#define GEN_PASS_DEF_PREORDERFUNCSPASS
28+
#include "shardy/dialect/sdy/transforms/import/passes.h.inc"
29+
30+
namespace {
31+
32+
struct PreOrderFuncsPass
33+
: public impl::PreOrderFuncsPassBase<PreOrderFuncsPass> {
34+
using PreOrderFuncsPassBase::PreOrderFuncsPassBase;
35+
36+
void runOnOperation() override {
37+
ModuleOp moduleOp = getOperation();
38+
std::vector<func::FuncOp> funcsInPreOrder;
39+
iterateFuncs(
40+
moduleOp,
41+
[&](func::FuncOp funcOp) { funcsInPreOrder.push_back(funcOp); },
42+
/*preOrder=*/true);
43+
mlir::Block& body = moduleOp.getBodyRegion().front();
44+
for (func::FuncOp funcOp : funcsInPreOrder) {
45+
funcOp->moveBefore(&body, body.end());
46+
}
47+
}
48+
};
49+
50+
} // namespace
51+
52+
} // namespace sdy
53+
} // namespace mlir
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: sdy_opt %s -sdy-pre-order-funcs | FileCheck %s
2+
3+
// CHECK: func.func @main
4+
// CHECK: func.func @func1
5+
// CHECK: func.func @func2
6+
7+
func.func @func2() {
8+
return
9+
}
10+
11+
func.func @func1() {
12+
func.call @func2() : () -> ()
13+
return
14+
}
15+
16+
func.func @main() {
17+
func.call @func1() : () -> ()
18+
return
19+
}

shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ void addPropagationPipeline(OpPassManager& pm, int& dumpIndex,
6464
optionsWithKeepShardingRules.keepShardingRules = true;
6565
// We intentionally don't increment the dump index here, since this pass
6666
// might dump 0 to multiple files, and will use a nested dump index.
67+
pm.addPass(createPreOrderFuncsPass());
6768
pm.addPass(createUserPriorityPropagationPass(optionsWithKeepShardingRules,
6869
dumpIndex));
6970
}

0 commit comments

Comments
 (0)