Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,10 @@ SmallVector<int64_t> getTensorSizes(Operation* op) {
return tensorSizes;
}

namespace {

// Returns reduction axes that are the union of all axes on reduction factors.
// The result axes are not necessarily canonicalized.
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
OpShardingRuleAttr shardingRule) {
SmallVector<AxisRefAttr> reductionAxes;
Expand All @@ -813,6 +817,43 @@ SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
return reductionAxes;
}

// Computes, for each reduction factor of `op`, the common sharding axes shared
// by all operands that map that factor, and stores them into
// `commonAxesPerFactor`.
//
// Precondition (assumed, not checked): the results of `op` have unreduced
// axes.
//
// `commonAxesPerFactor` is reset to `shardingRule.getNumFactors()` entries;
// only the entries corresponding to reduction factors are populated, all
// other entries are left empty.
//
// Hard fails (via SDY_CHECK) if two operands shard the same reduction factor
// with different axes, i.e. the reduction factors do not have compatible
// shardings.
void populateCommonAxesPerReductionFactorOrDie(
    Operation* op, const ShardingProjection& shardingProjection,
    OpShardingRuleAttr shardingRule, AxesPerFactor& commonAxesPerFactor) {
  // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
  // factors, and simplify the following logic.
  commonAxesPerFactor = AxesPerFactor(shardingRule.getNumFactors());
  for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
    // We only iterate operands since reduction factors are not in results.
    bool seen = false;
    // Reference into the output slot; written once the first operand that
    // maps this factor is found, then compared against for all later ones.
    SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
    for (const TensorFactorShardings& tensorFactorSharding :
         shardingProjection.getOperands()) {
      if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
              getFactorSharding(tensorFactorSharding, reductionFactor)) {
        SmallVector<AxisRefAttr> factorShardingVector =
            llvm::to_vector(*factorSharding);
        if (seen) {
          // Every subsequent operand mapping this reduction factor must
          // shard it with exactly the same axes as the first one seen.
          SDY_CHECK(factorShardingVector == commonAxes)
              << "For the operation " << op
              << ", the result has unreduced axes while the operand has "
                 "incompatible sharding along reduction factors.";
        } else {
          commonAxes = factorShardingVector;
          seen = true;
        }
      }
    }
  }
}
} // namespace

TensorShardingAttr insertAllReduceIfUnreducedToReplicated(
OpOperand& use, TensorShardingAttr sourceSharding,
TensorShardingAttr userSharding, const SymbolTable& symbolTable,
Expand Down Expand Up @@ -869,11 +910,26 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(Value value) {
return getUnreducedAxes(getSharding(value));
}

void insertAllReducesForReductionFactors(Operation* op,
ArrayRef<AxisRefAttr> reductionAxes,
const Mesh& mesh,
IRRewriter& rewriter) {
if (reductionAxes.empty() || op->getResults().empty()) {
void insertAllReducesForReductionFactors(
Operation* op, const ShardingProjection& shardingProjection,
AxesPerFactor& commonAxesPerFactor, OpShardingRuleAttr shardingRule,
const Mesh& mesh, IRRewriter& rewriter, const bool onFullVersion) {
if (op->getResults().empty()) {
return;
}

if (!onFullVersion && getUnreducedAxes(op->getResult(0)).empty()) {
return;
}
if (commonAxesPerFactor.empty()) {
// At this point, there are unreduced axes on results.
populateCommonAxesPerReductionFactorOrDie(
op, shardingProjection, shardingRule, commonAxesPerFactor);
}

SmallVector<AxisRefAttr> reductionAxes =
getReductionAxes(commonAxesPerFactor, shardingRule);
if (reductionAxes.empty()) {
return;
}

Expand Down
28 changes: 16 additions & 12 deletions shardy/dialect/sdy/transforms/export/explicit_reshards_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(Value value);
// Returns a concatenated array of operand and result tensor sizes.
SmallVector<int64_t> getTensorSizes(Operation* op);

// Returns reduction axes that are the union of all axes on reduction factors.
// The result axes are not necessarily canonicalized.
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
OpShardingRuleAttr shardingRule);

// Returns true iff any tensor factor sharding has non-empty overflow axes.
bool hasOverflowAxes(const ShardingProjection& shardingProjection);

Expand Down Expand Up @@ -147,19 +142,28 @@ void insertExplicitReshards(Operation* op,
OpShardingRuleAttr shardingRule,
const SymbolTable& symbolTable, const Mesh& mesh);

// Inserts an `sdy.all-reduce` for each result of `op` if `reductionAxes`
// is non-empty. Assume the followings:
// Inserts an `sdy.all-reduce` for each result of `op`.
//
// Assumes the following:
// - All op results have the same unreduced axes.
// - All op results have the same mesh as `mesh`, ignoring device id orders.
void insertAllReducesForReductionFactors(Operation* op,
ArrayRef<AxisRefAttr> reductionAxes,
const Mesh& mesh,
IRRewriter& rewriter);
// - If `commonAxesPerFactor` is nonempty, op has compatible shardings.
//
// If `onFullVersion` is false, all-reduces are inserted only when the op
// results have some unreduced axes.
//
// Hard fails if the reduction factors do not have compatible shardings, and op
// results have unreduced axes.
void insertAllReducesForReductionFactors(
Operation* op, const ShardingProjection& shardingProjection,
AxesPerFactor& commonAxesPerFactor, OpShardingRuleAttr shardingRule,
const Mesh& mesh, IRRewriter& rewriter, bool onFullVersion);

// Finds common factor axes on the operands and results of `op` so that the
// sharding of `op` is compatible with its sharding rule.
//
// Refer to the documentation of `InsertExplicitReshardsPass` for more details.
// Refer to the documentation of `InsertExplicitReshardsPass` for more
// details.
//
// Assumes the following:
// - All op results have the same unreduced axes.
Expand Down
109 changes: 42 additions & 67 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,31 +400,25 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
// - All op results have the same unreduced axes.
// - If the op has no results, none of the operands has unreduced axes.
// - Operand and result meshes are the same ignoring device id order.
// - There are no overflow axes.
//
// Returns the union of axes along all the reduction factors which may not be
// canonicalized.
SmallVector<AxisRefAttr> processOp(Operation* op,
ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings,
IRRewriter& rewriter,
const SymbolTable& symbolTable,
OpShardingRuleAttr shardingRule,
const Mesh& mesh, const bool onFullVersion) {
ShardingProjection shardingProjection = ShardingProjection::build(
inShardings, outShardings, shardingRule, mesh.attr(),
/*closedIfMissing=*/true);

// Return without inserting reshards if any factor sharding has overflow
// axes. This case is not handled yet.
// TODO(enver): Handle the case when factor shardings have overflow axes.
if (hasOverflowAxes(shardingProjection)) {
return {};
}

//
// Guarantees to return non-empty `AxesPerFactor` if `onFullVersion` is true.
AxesPerFactor processOp(Operation* op, ShardingProjection& shardingProjection,
ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings,
IRRewriter& rewriter, const SymbolTable& symbolTable,
OpShardingRuleAttr shardingRule, const Mesh& mesh,
const bool onFullVersion) {
// Checks if factors are sharded the same way across operands and results.
AxesPerFactor commonAxesPerFactor =
getCompatibleFactorShardings(shardingProjection, shardingRule);

// TODO(b/446833985): Return common axes factors also when the sharding
// projection have overflow axes.
if (onFullVersion) {
// Checks if factors are sharded the same way across operands and results.
AxesPerFactor commonAxesPerFactor =
getCompatibleFactorShardings(shardingProjection, shardingRule);
// Find compatible shardings if it is not already compatible.
if (commonAxesPerFactor.empty()) {
commonAxesPerFactor =
Expand All @@ -443,49 +437,19 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
updateTensorShardings, rewriter, shardingRule,
symbolTable, mesh);

return getReductionAxes(commonAxesPerFactor, shardingRule);
} else {
TypeSwitch<Operation*>(op)
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
shardingRule, mesh);
})
.Case<stablehlo::DotGeneralOp>(
[&](stablehlo::DotGeneralOp dotGeneralOp) {
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
symbolTable, shardingRule, mesh);
});
}

TypeSwitch<Operation*>(op)
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
shardingRule, mesh);
})
.Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
symbolTable, shardingRule, mesh);
});

if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
return {};
}

// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
// factors, and simplify the following logic.
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
// We only iterate operands since reduction factors are not in results.
bool seen = false;
SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
for (const TensorFactorShardings& tensorFactorSharding :
shardingProjection.getOperands()) {
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
getFactorSharding(tensorFactorSharding, reductionFactor)) {
if (seen) {
SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
<< "For the operation " << op
<< ", the result has unreduced axes while the operand has "
"incompatible sharding along reduction factors.";
} else {
axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
seen = true;
}
}
}
axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
}
return axesAlongAllReductionFactors;
return commonAxesPerFactor;
}

struct InsertExplicitReshardsPass
Expand Down Expand Up @@ -544,11 +508,22 @@ struct InsertExplicitReshardsPass
return;
}

SmallVector<AxisRefAttr> reductionAxes =
processOp(op, inShardings, outShardings, rewriter, symbolTable,
shardingRule, *mesh, onFullVersion);
ShardingProjection shardingProjection = ShardingProjection::build(
inShardings, outShardings, shardingRule, mesh->attr(),
/*closedIfMissing=*/true);
// Return without inserting reshards if any factor sharding has overflow
// axes. This case is not handled yet.
// TODO(enver): Handle the case when factor shardings have overflow axes.
if (hasOverflowAxes(shardingProjection)) {
return;
}
AxesPerFactor commonAxesPerFactor =
processOp(op, shardingProjection, inShardings, outShardings, rewriter,
symbolTable, shardingRule, *mesh, onFullVersion);
// TODO(b/440055868): Insert a reshard from unreduced to replicated axes.
insertAllReducesForReductionFactors(op, reductionAxes, *mesh, rewriter);
insertAllReducesForReductionFactors(op, shardingProjection,
commonAxesPerFactor, shardingRule,
*mesh, rewriter, onFullVersion);

// TODO(enver): Remove sharding rules from ops.
});
Expand Down
Loading