Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@ cc_library(
deps = [
":constants",
"//xla/hlo/ir:hlo_sharding",
"//xla/hlo/ir:mesh_and_axis",
"//xla/hlo/ir:named_sharding",
"//xla/hlo/translate/hlo_to_mhlo:hlo_utils",
"//xla/mlir_hlo:hlo_dialect_registration",
"//xla/service/spmd/shardy/extensions:mhlo_extensions",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
Expand Down
3 changes: 3 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ cc_library(
hdrs = ["import_shardy_attrs.h"],
visibility = internal_visibility([":friends"]),
deps = [
":dedup_meshes",
"//xla/hlo/ir:hlo_sharding",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/translate/hlo_to_mhlo:hlo_utils",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
Expand All @@ -72,6 +74,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@shardy//shardy/dialect/sdy/transforms/import:passes",
"@stablehlo//:stablehlo_ops",
],
)
Expand Down
176 changes: 130 additions & 46 deletions xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,13 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "shardy/dialect/sdy/transforms/import/passes.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/sdy_round_trip/dedup_meshes.h"
#include "xla/service/spmd/shardy/utils.h"

namespace xla {
Expand Down Expand Up @@ -83,6 +86,13 @@ using ::mlir::stablehlo::CustomCallOp;

namespace stablehlo = ::mlir::stablehlo;

// Parses an `HloSharding` from its textual representation stored in
// `sharding` (e.g. a round-tripped `mhlo.sharding` string attribute).
// CHECK-fails if the string is not a valid HLO sharding, since reaching this
// point with a malformed sharding indicates a bug in whatever produced the
// attribute rather than a recoverable condition.
HloSharding parseShardingFromString(StringAttr sharding) {
  absl::StatusOr<xla::HloSharding> hloSharding =
      xla::ParseSharding(sharding.str());
  // Stream the offending string and the parse status into the CHECK so a
  // failure is actionable instead of an anonymous crash.
  CHECK(hloSharding.ok()) << "Failed to parse sharding string: \""
                          << sharding.str()
                          << "\"; status: " << hloSharding.status();
  // Move out of the StatusOr to avoid copying the sharding.
  return *std::move(hloSharding);
}

CustomCallOp dynCastX64CombineCustomCall(Operation* op) {
auto customCallOp = mlir::dyn_cast<CustomCallOp>(op);
if (!customCallOp || customCallOp.getCallTargetName() != "X64Combine") {
Expand Down Expand Up @@ -177,48 +187,97 @@ void handleFuncResultSharding(CustomCallOp funcResultSharding, FuncOp funcOp,
// the module was exported from Shardy and we are now round-tripping back.
// This should happen after the meshes were created from the `ModuleOp` attrs
// (see `SdyRoundTripImportShardyAttrsPass`).
void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter,
bool enableHloShardingV3) {
// Copy over the argument shardings, but not the result shardings yet.
// We need to wait until after we've converted all the Operations before
// copying the result shardings.
for (auto [argNum, argType] : llvm::enumerate(funcOp.getArgumentTypes())) {
funcOp.removeArgAttr(argNum, kXlaShardingAttr);
// Attempt to extract the TensorShardingAttr from the frontend attributes of
// the function argument/result.
if (DictionaryAttr dictAttr = getFuncArgFrontendAttrs(funcOp, argNum)) {
if (auto sharding = parseStringAttr<TensorShardingAttr>(
dictAttr,
xla::ToStringRef(HloSharding::kShardingFrontendAttrName))) {
funcOp.setArgAttr(argNum, kShardingAttr, sharding);
removeFrontendAttribute(
funcOp, xla::ToStringRef(HloSharding::kShardingFrontendAttrName),
argNum);
if (enableHloShardingV3) {
if (auto oldSharding =
funcOp.getArgAttrOfType<StringAttr>(argNum, kXlaShardingAttr)) {
funcOp.setArgAttr(
argNum, kShardingAttr,
convertToSdyShardingAttr(parseShardingFromString(oldSharding),
argType, funcOp.getContext()));
}
} else {
// Attempt to extract the TensorShardingAttr from the frontend attributes
// of the function argument/result.
if (DictionaryAttr dictAttr = getFuncArgFrontendAttrs(funcOp, argNum)) {
if (auto sharding = parseStringAttr<TensorShardingAttr>(
dictAttr,
xla::ToStringRef(HloSharding::kShardingFrontendAttrName))) {
funcOp.setArgAttr(argNum, kShardingAttr, sharding);
removeFrontendAttribute(
funcOp, xla::ToStringRef(HloSharding::kShardingFrontendAttrName),
argNum);
}
}
}
funcOp.removeArgAttr(argNum, kXlaShardingAttr);
}

// Due to `SdyRoundTripExportShardingsPass` keeping `mhlo.sharding`s, remove
// them purely for cleanliness of the module.
for (int64_t resNum = 0; resNum < funcOp.getNumResults(); ++resNum) {
if (enableHloShardingV3) {
// Result shardings only need to be handled in HloShardingV3 case where we
// don't use FuncResultSharding custom call.
if (auto oldSharding = funcOp.getResultAttrOfType<StringAttr>(
resNum, kXlaShardingAttr)) {
funcOp.setResultAttr(
resNum, kShardingAttr,
convertToSdyShardingAttr(parseShardingFromString(oldSharding),
funcOp.getResultTypes()[resNum],
funcOp.getContext()));
}
}
funcOp.removeResultAttr(
resNum, StringAttr::get(funcOp.getContext(), kXlaShardingAttr));
}

// Extract the round-tripped shardy attributes from the operations.
funcOp.front().walk([&](Operation* op) {
op->removeAttr(kXlaShardingAttr);
if (!enableHloShardingV3) {
op->removeAttr(kXlaShardingAttr);
}
DictionaryAttr dictAttr = getFrontendAttrs(op);
if (!dictAttr) {
// No frontend attributes to import.
if (!enableHloShardingV3 && !dictAttr) {
return;
}

// Import sharding rules.
if (dictAttr) {
if (auto shardingRuleAttr = parseStringAttr<OpShardingRuleAttr>(
dictAttr, kShardingRuleRoundTripAttr)) {
op->setAttr(kShardingRuleAttr, shardingRuleAttr);
removeFrontendAttribute(op, kShardingRuleRoundTripAttr);
}
}

// No shardings to import.
if (enableHloShardingV3 &&
!op->getAttrOfType<StringAttr>(kXlaShardingAttr)) {
return;
}

// `SendOp`, `RecvOp`, and `AfterAllOp` can have a sharding when doing TPU
// callbacks through JAX.
if (mlir::isa<stablehlo::SendOp, stablehlo::RecvOp, stablehlo::AfterAllOp>(
op)) {
if (auto sharding = parseStringAttr<TensorShardingPerValueAttr>(
dictAttr,
xla::ToStringRef(HloSharding::kShardingFrontendAttrName))) {
op->setAttr(kShardingAttr, sharding);
if (enableHloShardingV3) {
auto sharding = op->getAttrOfType<StringAttr>(kXlaShardingAttr);
op->setAttr(kShardingAttr, convertToSdySharding(
parseShardingFromString(sharding),
op->getResultTypes(), op->getContext()));
} else {
if (auto sharding = parseStringAttr<TensorShardingPerValueAttr>(
dictAttr,
xla::ToStringRef(HloSharding::kShardingFrontendAttrName))) {
op->setAttr(kShardingAttr, sharding);
}
}
}
// NOTE: we are only setting the sharding on known custom-calls. For any
Expand All @@ -229,27 +288,35 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
// round-trip b/w HLO and MLIR after SDY propagation.
if (auto customCallOp = mlir::dyn_cast<CustomCallOp>(op)) {
StringRef targetName = customCallOp.getCallTargetName();
if (targetName == kFuncResultShardingTargetName) {
if (targetName == kFuncResultShardingTargetName && !enableHloShardingV3) {
handleFuncResultSharding(customCallOp, funcOp, dictAttr, rewriter);
return;
}
if (targetName == kShardingCustomCallTargetName ||
isPythonCallbackCustomCall(customCallOp)) {
customCallOp->setAttr(
kShardingAttr,
parseStringAttr<TensorShardingPerValueAttr>(
dictAttr,
xla::ToStringRef(HloSharding::kShardingFrontendAttrName)));
if (enableHloShardingV3) {
auto sharding =
customCallOp->getAttrOfType<StringAttr>(kXlaShardingAttr);
customCallOp->setAttr(
kShardingAttr,
convertToSdySharding(parseShardingFromString(sharding),
customCallOp->getResultTypes(),
customCallOp->getContext()));
} else {
customCallOp->setAttr(
kShardingAttr,
parseStringAttr<TensorShardingPerValueAttr>(
dictAttr,
xla::ToStringRef(HloSharding::kShardingFrontendAttrName)));
}
}
}
removeFrontendAttribute(
op, xla::ToStringRef(HloSharding::kShardingFrontendAttrName));

// Import sharding rules.
if (auto shardingRuleAttr = parseStringAttr<OpShardingRuleAttr>(
dictAttr, kShardingRuleRoundTripAttr)) {
op->setAttr(kShardingRuleAttr, shardingRuleAttr);
removeFrontendAttribute(op, kShardingRuleRoundTripAttr);
op->removeAttr(kXlaShardingAttr);

if (!enableHloShardingV3) {
removeFrontendAttribute(
op, xla::ToStringRef(HloSharding::kShardingFrontendAttrName));
}
});
}
Expand Down Expand Up @@ -297,27 +364,33 @@ class SdyRoundTripImportShardyAttrsPass
void runOnOperation() final {
ModuleOp moduleOp = getOperation();

IRRewriter rewriter(moduleOp);
SymbolTable symbolTable(moduleOp);

// We can use the saved string attributes to restore the original mesh and
// value shardings with the original mesh axis names and priorities on the
// sharding. If there is no `kMeshesRoundTripAttr`, there were no meshes in
// the original Shardy model.
std::optional<DictionaryAttr> meshesAttr =
tryGetFrontendAttr<DictionaryAttr>(moduleOp, kMeshesRoundTripAttr);
mlir::ArrayRef<NamedAttribute> sdyMeshes =
meshesAttr.has_value() ? meshesAttr.value().getValue()
: mlir::ArrayRef<NamedAttribute>();

IRRewriter rewriter(moduleOp);
// Insert the meshes before any functions.
rewriter.setInsertionPointToStart(moduleOp.getBody());
SymbolTable symbolTable(moduleOp);
for (NamedAttribute mesh : sdyMeshes) {
auto meshAttr = mlir::cast<MeshAttr>(mesh.getValue());
symbolTable.insert(mlir::sdy::MeshOp::create(rewriter, moduleOp.getLoc(),
mesh.getName(), meshAttr));
if (!enableHloShardingV3) {
// Insert the meshes before any functions.
rewriter.setInsertionPointToStart(moduleOp.getBody());
std::optional<DictionaryAttr> meshesAttr =
tryGetFrontendAttr<DictionaryAttr>(moduleOp, kMeshesRoundTripAttr);
mlir::ArrayRef<NamedAttribute> sdyMeshes =
meshesAttr.has_value() ? meshesAttr.value().getValue()
: mlir::ArrayRef<NamedAttribute>();

for (NamedAttribute mesh : sdyMeshes) {
auto meshAttr = mlir::cast<MeshAttr>(mesh.getValue());
symbolTable.insert(mlir::sdy::MeshOp::create(
rewriter, moduleOp.getLoc(), mesh.getName(), meshAttr));
}
removeFrontendAttribute(moduleOp, kMeshesRoundTripAttr);
}
removeFrontendAttribute(moduleOp, kMeshesRoundTripAttr);

// TODO (b/485486745): Remove kInTupleShardings and kOutTupleShardings
// frontend attributes added directly at tf2xla level
if (FuncOp mainFunc = moduleOp.lookupSymbol<FuncOp>("main")) {
auto argShardingSetter = [](FuncOp funcOp, int64_t argNum,
TensorShardingAttr argSharding) {
Expand All @@ -341,7 +414,18 @@ class SdyRoundTripImportShardyAttrsPass
}

for (auto funcOp : moduleOp.getOps<FuncOp>()) {
convertShardyAttrs(funcOp, rewriter);
convertShardyAttrs(funcOp, rewriter, enableHloShardingV3);
}

if (enableHloShardingV3) {
// Lift inlined meshes, as meshes are inlined in HloShardingV3 and
// therefore in sdy shardings generated from conversion.
mlir::PassManager pm(moduleOp.getContext());
pm.addPass(mlir::sdy::createLiftInlinedMeshesPass());
pm.addPass(createSdyRoundTripDedupMeshesPass());
if (mlir::failed(pm.run(moduleOp))) {
signalPassFailure();
}
}
}

Expand Down
Loading
Loading