Skip to content

Commit 2dd4be9

Browse files
ekayaaslancopybara-github
authored andcommitted
Add utility to walk funcs.
Also simplify the utility to walk calls. Add few more utilities: - getOriginalFuncName - getFuncOpOrDie PiperOrigin-RevId: 900202519
1 parent 56477a9 commit 2dd4be9

File tree

3 files changed

+82
-51
lines changed

3 files changed

+82
-51
lines changed

shardy/dialect/sdy/ir/utils.cc

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,11 @@ StringAttr getOriginalFuncName(FuncOp funcOp) {
912912
return funcOp.getSymNameAttr();
913913
}
914914

915+
StringAttr getOriginalFuncName(CallOp callOp, const SymbolTable& symbolTable) {
916+
FuncOp funcOp = getFuncOpOrDie(callOp.getCallee(), symbolTable);
917+
return getOriginalFuncName(funcOp);
918+
}
919+
915920
mlir::Attribute getMeshOrRef(
916921
int64_t numElements, const SymbolTable& symbolTable,
917922
std::function<TensorShardingAttr(int64_t)> getSharding) {
@@ -974,9 +979,8 @@ TensorShardingPerValueAttr getFuncResultShardings(
974979
return TensorShardingPerValueAttr::get(funcOp.getContext(), resultShardings);
975980
}
976981

977-
std::optional<FuncOp> walkCalls(ModuleOp moduleOp,
978-
ProcessCallOpFn processCallOp, bool preOrder) {
979-
FuncOp mainFuncOp;
982+
bool walkCalls(ModuleOp moduleOp, ProcessCallOpFn processCallOp,
983+
bool preOrder) {
980984
CallGraph callGraph(moduleOp);
981985
llvm::ReversePostOrderTraversal<const CallGraph*> rpo(&callGraph);
982986
if (preOrder) { // Iterate pre-order.
@@ -985,16 +989,11 @@ std::optional<FuncOp> walkCalls(ModuleOp moduleOp,
985989
continue;
986990
}
987991
mlir::Region* region = node->getCallableRegion();
988-
// The first func is the main one as it is in pre-order.
989-
if (auto funcOp = dyn_cast_or_null<FuncOp>(region->getParentOp());
990-
!mainFuncOp && funcOp) {
991-
mainFuncOp = funcOp;
992-
}
993992
if (region
994993
->walk<WalkOrder::PreOrder>(
995994
[&](CallOp callOp) { return processCallOp(callOp); })
996995
.wasInterrupted()) {
997-
return std::nullopt;
996+
return false;
998997
}
999998
}
1000999
} else {
@@ -1004,25 +1003,27 @@ std::optional<FuncOp> walkCalls(ModuleOp moduleOp,
10041003
continue;
10051004
}
10061005
mlir::Region* region = node->getCallableRegion();
1007-
// The last func is the main one as it is in post-order.
1008-
if (auto funcOp = dyn_cast_or_null<FuncOp>(region->getParentOp());
1009-
funcOp) {
1010-
mainFuncOp = funcOp;
1011-
}
10121006
if (region->walk([&](CallOp callOp) { return processCallOp(callOp); })
10131007
.wasInterrupted()) {
1014-
return std::nullopt;
1008+
return false;
10151009
}
10161010
}
10171011
}
1018-
return mainFuncOp;
1012+
return true;
10191013
}
10201014

1021-
FuncOp walkCallsOrDie(ModuleOp moduleOp, ProcessCallOpFn processCallOp,
1022-
bool preOrder) {
1023-
auto mainFuncOp = walkCalls(moduleOp, processCallOp, preOrder);
1024-
SDY_CHECK(mainFuncOp);
1025-
return *mainFuncOp;
1015+
void iterateFuncs(ModuleOp moduleOp, ProcessFuncOpFn processFuncOp) {
1016+
CallGraph callGraph(moduleOp);
1017+
llvm::ReversePostOrderTraversal<const CallGraph*> rpo(&callGraph);
1018+
for (CallGraphNode* node : llvm::reverse(rpo)) {
1019+
if (node->isExternal()) {
1020+
continue;
1021+
}
1022+
mlir::Region* region = node->getCallableRegion();
1023+
if (FuncOp funcOp = dyn_cast_or_null<FuncOp>(region->getParentOp())) {
1024+
processFuncOp(funcOp);
1025+
}
1026+
}
10261027
}
10271028

10281029
Operation* getCommonSupportedReductionOp(stablehlo::ScatterOp scatter) {
@@ -1108,5 +1109,11 @@ FuncOp cloneFuncRecursively(FuncOp funcOp, SymbolTable& symbolTable) {
11081109
return clonedFuncOp;
11091110
}
11101111

1112+
FuncOp getFuncOpOrDie(StringRef funcSymName, const SymbolTable& symbolTable) {
1113+
FuncOp funcOp = symbolTable.lookup<FuncOp>(funcSymName);
1114+
SDY_CHECK(funcOp) << "Failed to lookup function: " << funcSymName.str();
1115+
return funcOp;
1116+
}
1117+
11111118
} // namespace sdy
11121119
} // namespace mlir

shardy/dialect/sdy/ir/utils.h

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,11 @@ DenseIntElementsAttr getReplicaGroups(AxisRefListAttr reductionAxesAttr,
642642
// case there is no such attribute attached, create one on the name of `funcOp`.
643643
StringAttr getOriginalFuncName(func::FuncOp funcOp);
644644

645+
// Gets `kOriginalFuncName` attribute attached to the func of `callOp`. In case
646+
// there is no such attribute attached, create one on the name of the func.
647+
StringAttr getOriginalFuncName(func::CallOp callOp,
648+
const SymbolTable& symbolTable);
649+
645650
// Returns the shardings for the arguments of `funcOp`, with fully replicated
646651
// shardings for empty shardings on `funcOp`.
647652
TensorShardingPerValueAttr getFuncArgShardings(func::FuncOp funcOp,
@@ -657,16 +662,18 @@ TensorShardingPerValueAttr getFuncResultShardings(
657662
// functions are processed before their callers, and child blocks are processed
658663
// before their parents. Iterates calls and blocks in pre order if `preOrder` is
659664
// true, that is, the functions are processed after their callers, and child
660-
// blocks are processed after their parents. Returns nullopt if the walk was
661-
// interrupted, returns the main func otherwise.
665+
// blocks are processed after their parents. Returns false if the walk was
666+
// interrupted, returns true otherwise.
662667
using ProcessCallOpFn = std::function<mlir::WalkResult(func::CallOp)>;
663-
std::optional<func::FuncOp> walkCalls(ModuleOp moduleOp,
664-
ProcessCallOpFn processCallOp,
665-
bool preOrder = false);
666-
// Walks calls as in `walkCalls` and returns the main func. Dies if the the walk
667-
// is interrupted, or otherwise it can not identify the main function.
668-
func::FuncOp walkCallsOrDie(ModuleOp moduleOp, ProcessCallOpFn processCallOp,
669-
bool preOrder = false);
668+
bool walkCalls(ModuleOp moduleOp, ProcessCallOpFn processCallOp,
669+
bool preOrder = false);
670+
671+
// Iterates on the funcs and performs `processFuncOp` on funcs. Iterates on the
672+
// funcs and blocks in post order of the call graph by default, that is, the
673+
// functions are processed before their callers, and child blocks are processed
674+
// before their parents.
675+
using ProcessFuncOpFn = std::function<void(func::FuncOp)>;
676+
void iterateFuncs(ModuleOp moduleOp, ProcessFuncOpFn processFuncOp);
670677

671678
// Returns the reduction operation used in the scatter's update computation if
672679
// it is a recognized associative and commutative binary op applied to all
@@ -678,6 +685,11 @@ Operation* getCommonSupportedReductionOp(stablehlo::ScatterOp scatter);
678685
mlir::func::FuncOp cloneFuncRecursively(func::FuncOp funcOp,
679686
SymbolTable& symbolTable);
680687

688+
// Returns the funcOp on `funcSymName`. Dies if the func does not exist on the
689+
// `symbolTable`.
690+
func::FuncOp getFuncOpOrDie(StringRef funcSymName,
691+
const SymbolTable& symbolTable);
692+
681693
} // namespace sdy
682694
} // namespace mlir
683695

shardy/dialect/sdy/ir/utils_test.cc

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -647,47 +647,59 @@ TEST_F(UtilsTest, WalkCalls_Interrupted) {
647647
EXPECT_THAT(calledFuncs, ElementsAre("baz"));
648648
}
649649

650-
TEST_F(UtilsTest, WalkCallsPostOrder_ReturnsMainFunc) {
650+
TEST_F(UtilsTest, IterateFuncs_ThreeCalls) {
651651
auto localModule = mlir::parseSourceString<ModuleOp>(
652652
"module {\n"
653-
" func.func private @bar()\n"
654-
" func.func private @foo() {\n"
655-
" call @bar() : () -> ()\n"
653+
" func.func private @bar() {\n"
654+
" call @baz() : () -> ()\n"
655+
" return\n"
656+
" }\n"
657+
" func.func private @baz() {\n"
656658
" return\n"
657659
" }\n"
658-
" func.func @mymain() {\n"
660+
" func.func @main() {\n"
659661
" call @foo() : () -> ()\n"
660662
" return\n"
661663
" }\n"
664+
" func.func private @foo() {\n"
665+
" call @bar() : () -> ()\n"
666+
" return\n"
667+
" }\n"
662668
"}",
663669
&context);
664-
EXPECT_THAT(
665-
walkCalls(localModule.get(),
666-
[&](func::CallOp callOp) { return WalkResult::advance(); })
667-
->getName(),
668-
"mymain");
670+
std::vector<std::string> iteratedFuncs;
671+
iterateFuncs(localModule.get(), [&](func::FuncOp funcOp) {
672+
iteratedFuncs.push_back(funcOp.getName().str());
673+
});
674+
EXPECT_THAT(iteratedFuncs, ElementsAre("baz", "bar", "foo", "main"));
669675
}
670676

671-
TEST_F(UtilsTest, WalkCallsPreOrder_ReturnsMainFunc) {
677+
TEST_F(UtilsTest, IterateFuncs_Triangle) {
672678
auto localModule = mlir::parseSourceString<ModuleOp>(
673679
"module {\n"
674-
" func.func private @bar()\n"
675-
" func.func private @foo() {\n"
676-
" call @bar() : () -> ()\n"
680+
" func.func private @bar() {\n"
681+
" call @baz() : () -> ()\n"
677682
" return\n"
678683
" }\n"
679-
" func.func @mymain() {\n"
684+
" func.func private @baz() {\n"
685+
" return\n"
686+
" }\n"
687+
" func.func @main() {\n"
680688
" call @foo() : () -> ()\n"
689+
" call @bar() : () -> ()\n"
690+
" return\n"
691+
" }\n"
692+
" func.func private @foo() {\n"
693+
" call @bar() : () -> ()\n"
681694
" return\n"
682695
" }\n"
683696
"}",
684697
&context);
685-
EXPECT_THAT(walkCalls(
686-
localModule.get(),
687-
[&](func::CallOp callOp) { return WalkResult::advance(); },
688-
/*preOrder=*/true)
689-
->getName(),
690-
"mymain");
698+
std::vector<std::string> iteratedFuncs;
699+
iterateFuncs(localModule.get(), [&](func::FuncOp funcOp) {
700+
iteratedFuncs.push_back(funcOp.getName().str());
701+
});
702+
EXPECT_THAT(iteratedFuncs, ElementsAre("baz", "bar", "foo", "main"));
691703
}
692704

693705
TEST_F(UtilsTest, GetShardableValue_AsyncStartOp) {

0 commit comments

Comments
 (0)