Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions shardy/dialect/sdy/ir/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ inline const std::string kEmptyMeshSymbol = "empty_mesh";
// Attribute name for the original name of the func before flattening.
inline constexpr llvm::StringRef kOriginalFuncName = "sdy.original_func_name";

// Symbol name of the main func.
inline constexpr llvm::StringRef kMainFuncName = "main";

} // namespace sdy
} // namespace mlir

Expand Down
27 changes: 27 additions & 0 deletions shardy/dialect/sdy/ir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1115,5 +1115,32 @@ FuncOp getFuncOpOrDie(StringRef funcSymName, const SymbolTable& symbolTable) {
return funcOp;
}

// Builds a `TensorShardingPerValueAttr` that assigns every value in `values`
// a fully replicated, closed sharding on `meshOrRef`, using the rank of the
// tensor corresponding to each value.
TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
                                              Attribute meshOrRef) {
  auto* context = meshOrRef.getContext();
  SmallVector<TensorShardingAttr> shardings;
  shardings.reserve(values.size());
  for (mlir::Value value : values) {
    shardings.push_back(TensorShardingAttr::getFullyReplicated(
        context, mlir::sdy::getTensorRank(value), meshOrRef,
        /*isClosed=*/true));
  }
  return TensorShardingPerValueAttr::get(context, shardings);
}

// Returns the main func of `moduleOp`, dying if it cannot be found. When
// `useTheOneIfSingleFunc` is set and the module holds exactly one func, that
// func is returned regardless of its name.
FuncOp getMainFuncOrDie(ModuleOp moduleOp, SymbolTable& symbolTable,
                        bool useTheOneIfSingleFunc) {
  if (useTheOneIfSingleFunc) {
    auto funcOps = moduleOp.getOps<FuncOp>();
    auto first = funcOps.begin();
    // Exactly one func: a non-empty range whose second position is the end.
    if (first != funcOps.end() && std::next(first) == funcOps.end()) {
      return *first;
    }
  }
  FuncOp mainFunc = symbolTable.lookup<FuncOp>(kMainFuncName);
  SDY_CHECK(mainFunc) << "Failed to lookup function: " << kMainFuncName.str();
  return mainFunc;
}

} // namespace sdy
} // namespace mlir
12 changes: 12 additions & 0 deletions shardy/dialect/sdy/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,18 @@ mlir::func::FuncOp cloneFuncRecursively(func::FuncOp funcOp,
func::FuncOp getFuncOpOrDie(StringRef funcSymName,
const SymbolTable& symbolTable);

// Returns a `TensorShardingPerValueAttr` that, for each value in `values`,
// holds a fully replicated, closed sharding on the given `meshOrRef`, using
// the rank of the tensor corresponding to the value.
TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
Attribute meshOrRef);

// Returns the main func. Dies if there is no main func. If
// `useTheOneIfSingleFunc` is true, first checks whether the module has exactly
// one func and, if so, returns it as the main func. Useful for tests.
mlir::func::FuncOp getMainFuncOrDie(ModuleOp moduleOp, SymbolTable& symbolTable,
bool useTheOneIfSingleFunc = false);

} // namespace sdy
} // namespace mlir

Expand Down
93 changes: 93 additions & 0 deletions shardy/dialect/sdy/ir/utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,99 @@ TEST_F(UtilsTest, GetShardableValue_AsyncStartOp) {
EXPECT_EQ(getShardableValue(blockArg), operand);
}

TEST_F(UtilsTest, GetMainFuncOrDie_MainIsFirst) {
  // @main appears before the func it calls.
  auto moduleOp = mlir::parseSourceString<ModuleOp>(
      R"(module {
  func.func @main() {
    call @foo() : () -> ()
    return
  }
  func.func private @foo() {
    return
  }
})",
      &context);
  SymbolTable table(moduleOp.get());
  EXPECT_EQ(getMainFuncOrDie(moduleOp.get(), table).getName(), "main");
}

TEST_F(UtilsTest, GetMainFuncOrDie_MainIsLast) {
  // @main appears after the func it calls; lookup must still find it.
  auto moduleOp = mlir::parseSourceString<ModuleOp>(
      R"(module {
  func.func private @foo() {
    return
  }
  func.func @main() {
    call @foo() : () -> ()
    return
  }
})",
      &context);
  SymbolTable table(moduleOp.get());
  EXPECT_EQ(getMainFuncOrDie(moduleOp.get(), table).getName(), "main");
}

TEST_F(UtilsTest, GetMainFuncOrDie_MainIsOnly) {
  // A module whose only func is @main.
  auto moduleOp = mlir::parseSourceString<ModuleOp>(
      R"(module {
  func.func @main() {
    return
  }
})",
      &context);
  SymbolTable table(moduleOp.get());
  EXPECT_EQ(getMainFuncOrDie(moduleOp.get(), table).getName(), "main");
}

TEST_F(UtilsTest, GetMainFuncOrDie_SingleButNotMain) {
  // Without the single-func fallback, a lone non-main func is fatal.
  auto moduleOp = mlir::parseSourceString<ModuleOp>(
      R"(module {
  func.func @some() {
    return
  }
})",
      &context);
  SymbolTable table(moduleOp.get());
  ASSERT_DEATH(getMainFuncOrDie(moduleOp.get(), table),
               "Failed to lookup function: main");
}

// Verifies the single-func fallback: with `useTheOneIfSingleFunc` set, the
// only func in the module is returned even though it is not named "main".
TEST_F(UtilsTest, GetMainFuncOrDie_UseTheOneIfSingle_SingleButNotMain) {
  auto localModule = mlir::parseSourceString<ModuleOp>(
      "module {\n"
      "  func.func @some() {\n"
      "    return\n"
      "  }\n"
      "}",
      &context);
  SymbolTable symbolTable(localModule.get());
  // NOTE: the argument comment must match the parameter name
  // `useTheOneIfSingleFunc` (bugprone-argument-comment).
  EXPECT_THAT(getMainFuncOrDie(localModule.get(), symbolTable,
                               /*useTheOneIfSingleFunc=*/true)
                  .getName(),
              "some");
}

// Verifies that the single-func fallback does not apply when the module has
// more than one func: with no @main present, the lookup dies.
TEST_F(UtilsTest, GetMainFuncOrDie_UseTheOneIfSingle_NoMain) {
  auto localModule = mlir::parseSourceString<ModuleOp>(
      "module {\n"
      "  func.func private @foo() {\n"
      "    return\n"
      "  }\n"
      "  func.func @some() {\n"
      "    call @foo() : () -> ()\n"
      "    return\n"
      "  }\n"
      "}",
      &context);
  SymbolTable symbolTable(localModule.get());
  // NOTE: the argument comment must match the parameter name
  // `useTheOneIfSingleFunc` (bugprone-argument-comment).
  ASSERT_DEATH(getMainFuncOrDie(localModule.get(), symbolTable,
                                /*useTheOneIfSingleFunc=*/true),
               "Failed to lookup function: main");
}

} // namespace

} // namespace sdy
Expand Down
3 changes: 2 additions & 1 deletion shardy/dialect/sdy/transforms/common/propagation_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ struct PropagationOptions {
// auto-partitioner will be invoked after propagation of user-specified
// shardings.
bool enableAutoPartitioning = false;
// Whether to avoid explicit reshards/collectives on named computations.
// Whether to avoid explicit reshards/collectives on named computations/calls.
// TODO(enver): Rename to avoidReshardsOnCalls.
bool avoidReshardsOnNamedComputations = false;
// Whether to update axes with non-divisible input/output shardings.
bool updateNonDivisibleInputOutputShardings = true;
Expand Down
44 changes: 42 additions & 2 deletions shardy/dialect/sdy/transforms/export/export_named_computations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@ limitations under the License.
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "shardy/common/logging.h"
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep
#include "shardy/dialect/sdy/transforms/propagation/utils.h"

namespace mlir {
namespace sdy {
Expand All @@ -42,6 +46,18 @@ namespace {
using func::CallOp;
using func::FuncOp;

// For each value in `values`, if its (first) user is a `DataFlowEdgeOp`,
// replaces that edge op with the edge's input. A data flow edge is expected
// to be the unique user of its input, which is asserted here. Values with no
// users, or whose first user is not a data flow edge, are left untouched.
void removeDataFlowEdges(ValueRange values, IRRewriter& rewriter) {
  for (Value value : values) {
    if (value.use_empty()) {
      continue;
    }
    auto edgeOp = dyn_cast<DataFlowEdgeOp>(*value.user_begin());
    if (!edgeOp) {
      continue;
    }
    SDY_CHECK(value.hasOneUse());
    rewriter.replaceOp(edgeOp, edgeOp.getInput());
  }
}

struct NamedComputationWithCount {
NamedComputationOp namedComputationOp;
int64_t callSiteCount;
Expand All @@ -66,6 +82,8 @@ StringAttr createFuncOp(
inlineRegionAndConvertTerminatorOp<func::ReturnOp>(
namedComputationOp.getBody(), funcOp.getBody());

removeDataFlowEdges(funcOp.getArguments(), rewriter);

// Copy the input shardings to the func.
if (inShardings.has_value()) {
for (auto [i, sharding] : llvm::enumerate(inShardings->getShardings())) {
Expand Down Expand Up @@ -95,15 +113,35 @@ TensorShardingPerValueAttr getFullyClosedLike(

void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
Block& moduleBlock = moduleOp.getRegion().front();
MLIRContext* ctx = moduleOp.getContext();
IRRewriter rewriter(moduleOp);

// NOTE: The walk needs to be in post order, which is the default order, to
// account for nested named computations.
SmallVector<Value> callResults;
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
IRRewriter rewriter(namedComputationOp);
rewriter.setInsertionPointToEnd(&moduleBlock);

// Propagate the shardings from the data flow edges to argument shardings.
ArrayRef<BlockArgument> blockArgOwners =
namedComputationOp.getBody().getArguments();
if (SmallVector<TensorShardingAttr> blockArgShardings =
getShardingsFromDataFlowEdges(blockArgOwners);
!blockArgShardings.empty()) {
namedComputationOp.setInShardingsAttr(
TensorShardingPerValueAttr::get(ctx, blockArgShardings));
}
std::optional<TensorShardingPerValueAttr> inShardings =
namedComputationOp.getInShardings();

// Propagate the shardings from the data flow edges to result shardings.
ResultRange resultOwners = namedComputationOp.getResults();
if (SmallVector<TensorShardingAttr> resultShardings =
getShardingsFromDataFlowEdges(resultOwners);
!resultShardings.empty()) {
namedComputationOp.setOutShardingsAttr(
TensorShardingPerValueAttr::get(ctx, resultShardings));
}
std::optional<TensorShardingPerValueAttr> outShardings =
namedComputationOp.getOutShardings();

Expand All @@ -117,6 +155,7 @@ void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
auto callOp = rewriter.replaceOpWithNewOp<CallOp>(
namedComputationOp, namedComputationOp.getResultTypes(), funcSymName,
namedComputationOp.getOperands());
llvm::append_range(callResults, callOp.getResults());
callOp->setAttrs(callOpAttrs);
FuncOp funcOp = symbolTable.lookup<FuncOp>(funcSymName);
// Copy the func output shardings to the call op.
Expand All @@ -128,6 +167,7 @@ void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
: getFullyClosedLike(funcResultShardings));
}
});
removeDataFlowEdges(callResults, rewriter);
}

struct ExportNamedComputationsPass
Expand All @@ -138,8 +178,8 @@ struct ExportNamedComputationsPass
void runOnOperation() final {
ModuleOp moduleOp = getOperation();
SymbolTableCollection symbolTableCollection;

SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp);

exportNamedComputations(moduleOp, symbolTable);
}
};
Expand Down
6 changes: 1 addition & 5 deletions shardy/dialect/sdy/transforms/export/export_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,8 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
const ExportOptions& options) {
InsertExplicitReshardsPassOptions passOptions;
passOptions.enableFullVersion = options.enableInsertExplicitCollectives;
passOptions.avoidReshardsOnNamedComputations =
options.avoidReshardsOnNamedComputations;
passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls;
pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));
pm.addPass(createExportNamedComputationsPass());
if (options.enableInsertExplicitCollectives) {
pm.addPass(mlir::sdy::createSaveModuleOpPass(
options.dumpDirectory, "after_explicit_reshards", dumpIndex++));
Expand Down Expand Up @@ -98,8 +96,6 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
// reshards/collectives.
if (!options.avoidExportForPartitioning) {
runShardyPartitioner(pm, dumpIndex, options);
} else {
pm.addPass(createExportNamedComputationsPass());
}
if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
pm.addPass(createRemovePropagationDebugInfoPass());
Expand Down
63 changes: 44 additions & 19 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ namespace sdy {

namespace {

using func::CallOp;
using func::FuncOp;

void insertExplicitReshardsToTargetSharding(OpOperand& opOperand,
TensorShardingAttr targetSharding,
IRRewriter& rewriter,
Expand Down Expand Up @@ -102,22 +105,10 @@ void insertExplicitReshardsOnFuncReturn(Operation* op, func::FuncOp& funcOp,
}
}

void insertExplicitReshardsOnDataFlowOp(
ShardableDataFlowOpInterface& op, IRRewriter& rewriter,
const SymbolTable& symbolTable, const bool onFullVersion,
const bool avoidReshardsOnNamedComputations) {
if (isa<NamedComputationOp>(op) && avoidReshardsOnNamedComputations) {
for (Value owner : op.getOpResultEdgeOwners()) {
for (OpOperand* sourceOpOperand : op.getEdgeSources(owner)) {
insertExplicitReshardsToTargetSharding(
*sourceOpOperand,
/*targetSharding=*/op.getEdgeOwnerSharding(owner), rewriter,
symbolTable,
/*insertAfterOperand=*/true, onFullVersion);
}
}
return;
}
void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
IRRewriter& rewriter,
const SymbolTable& symbolTable,
const bool onFullVersion) {
for (Value owner : llvm::concat<Value>(op.getOpResultEdgeOwners(),
op.getBlockArgumentEdgeOwners())) {
TensorShardingAttr ownerSharding = op.transformTargetSharding(
Expand All @@ -132,6 +123,33 @@ void insertExplicitReshardsOnDataFlowOp(
}
}

// Inserts explicit reshards on the operands of `callOp` towards the shardings
// of the corresponding arguments of the callee func.
//
// If the callee has no argument shardings, every argument is treated as fully
// replicated and closed on a mesh derived from the call operands' shardings;
// if no such mesh exists, nothing is inserted.
//
// NOTE(review): assumes `callOp`'s callee resolves to a FuncOp in
// `symbolTable` — `funcOp` is passed to `getFuncArgShardings` without a null
// check; confirm callers only pass calls with resolvable callees.
void insertExplicitReshardsOnCallOp(CallOp callOp, IRRewriter& rewriter,
                                    const SymbolTable& symbolTable,
                                    const bool onFullVersion) {
  FuncOp funcOp = symbolTable.lookup<FuncOp>(callOp.getCallee());
  TensorShardingPerValueAttr funcArgShardings =
      mlir::sdy::getFuncArgShardings(funcOp, symbolTable);
  if (!funcArgShardings) {
    // No shardings on the callee arguments; look for a mesh among the
    // shardings of the call operands instead.
    mlir::Attribute meshOrRef = getMeshOrRef(
        callOp.getNumOperands(), symbolTable,
        [&](int64_t i) { return getSharding(callOp.getOperand(i)); });
    // Return without inserting reshards as neither func arguments nor call
    // operands have a sharding with non-maximal mesh.
    if (!meshOrRef) {
      return;
    }
    // Treat every callee argument as fully replicated (closed) on the
    // discovered mesh.
    funcArgShardings = getFullyClosedLike(callOp.getOperands(), meshOrRef);
  }
  // Any reshard ops are created immediately before the call.
  rewriter.setInsertionPoint(callOp);
  // `zip_equal` dies if the number of shardings and operands differ.
  for (auto [funcArgSharding, sourceOpOperand] : llvm::zip_equal(
           funcArgShardings.getShardings(), callOp->getOpOperands())) {
    insertExplicitReshardsToTargetSharding(
        sourceOpOperand,
        /*targetSharding=*/funcArgSharding, rewriter, symbolTable,
        /*insertAfterOperand=*/true, onFullVersion);
  }
}

// Reshard the result of a dot operation if all the following hold:
//
// 1. LHS and RHS have fully compatible shardings.
Expand Down Expand Up @@ -382,7 +400,7 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
}
// To avoid copies of the same functions with mismatching shardings on the
// arguments onto multiple callsites.
if (isa<NamedComputationOp>(op)) {
if (isa<func::CallOp>(op)) {
return true;
}

Expand Down Expand Up @@ -472,8 +490,15 @@ struct InsertExplicitReshardsPass
// TODO(enver): Prefer resharding the owner when multiple sources are
// sharded in the same way.
insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter,
symbolTable, onFullVersion,
avoidReshardsOnNamedComputations);
symbolTable, onFullVersion);
return;
}

if (CallOp callOp = dyn_cast<CallOp>(op)) {
if (!avoidReshardsOnCalls) {
insertExplicitReshardsOnCallOp(callOp, rewriter, symbolTable,
onFullVersion);
}
return;
}

Expand Down
Loading
Loading