Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,64 +101,6 @@ bool inputHasUnreducedAxes(CollectiveTy collective) {
return false;
}

// Computes the replica groups for a reduction over `reductionAxesAttr`.
//
// The devices of `mesh` are laid out in an array with one dimension per
// ordered (sub-)axis. That array is transposed so that the non-reduction axes
// come first (in mesh order) followed by the reduction axes (in the order
// they appear in `reductionAxesAttr`), and is then flattened. The flattened
// ids are returned as an i64 tensor attribute of shape
// [numGroups, groupSize], where `groupSize` is the product of the sizes of
// the reduction axes and `numGroups` is the mesh's total size divided by it.
//
// For example, given:
//
// - reductionAxesAttr = ["y"]
// - mesh = ["x"=2, "y"=2]
//
// Returns `[[0, 1], [2, 3]]`.
mlir::DenseIntElementsAttr getReplicaGroups(
    sdy::AxisRefListAttr reductionAxesAttr, MeshAttr mesh,
    OpBuilder& rewriter) {
  ArrayRef<AxisRefAttr> reductionAxes = reductionAxesAttr.getValue();

  // Remember where each reduction axis sits in the list and accumulate the
  // number of devices that end up in a single group.
  llvm::SmallDenseMap<AxisRefAttr, int64_t> reductionPosition;
  reductionPosition.reserve(reductionAxes.size());
  int64_t groupSize = 1;
  for (int64_t pos = 0, numReduction = reductionAxes.size();
       pos < numReduction; ++pos) {
    AxisRefAttr reductionAxis = reductionAxes[pos];
    reductionPosition[reductionAxis] = pos;
    groupSize *= reductionAxis.getSize(mesh);
  }
  int64_t totalSize = mesh.getTotalSize();
  int64_t numGroups = totalSize / groupSize;

  SmallVector<AxisRefAttr> orderedAxes =
      getOrderedAxisRefs(reductionAxesAttr, mesh);

  SmallVector<int64_t> dims;
  dims.reserve(orderedAxes.size());
  SmallVector<int64_t> perm(orderedAxes.size());

  // Non-reduction axes fill the leading slots of the permutation in mesh
  // order; reduction axes fill the trailing slots in reduction order.
  int64_t numNonReduction = orderedAxes.size() - reductionAxes.size();
  int64_t nextNonReductionSlot = 0;
  for (int64_t meshPos = 0, numAxes = orderedAxes.size(); meshPos < numAxes;
       ++meshPos) {
    AxisRefAttr meshAxis = orderedAxes[meshPos];
    dims.push_back(meshAxis.getSize(mesh));
    auto found = reductionPosition.find(meshAxis);
    int64_t slot = found != reductionPosition.end()
                       ? numNonReduction + found->second
                       : nextNonReductionSlot++;
    perm[slot] = meshPos;
  }

  // TODO(b/410040098): output V2 if possible, and maybe canonicalize.

  Array<int64_t> deviceArray(dims);
  if (!mesh.getDeviceIds().empty()) {
    // The mesh lists its device ids explicitly.
    deviceArray.SetValues(mesh.getDeviceIds());
  } else {
    deviceArray.FillIota(0);
  }
  deviceArray.TransposeDimensions(perm);
  deviceArray.Reshape({totalSize});

  auto groupsType = RankedTensorType::get({numGroups, groupSize},
                                          rewriter.getIntegerType(64));
  return mlir::DenseIntElementsAttr::get(groupsType,
                                         llvm::to_vector(deviceArray));
}

// Creates a manual computation, with all axes in `mesh` as manual, and
// populates its body via the `bodyPopulator` function.
ManualComputationOp createFullyManualComputation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,8 @@ HloSharding convertToHloSharding(
}

// We will add all axes and let canonicalization merge adjacent axes.
SmallVector<AxisRefAttr> meshAxisRefs = getOrderedAxisRefs(sdySharding, mesh);
SmallVector<AxisRefAttr> meshAxisRefs =
mlir::sdy::getOrderedAxisRefs(sdySharding, mesh);
SmallVector<int64_t> reshapeDims(meshAxisRefs.size());
SmallVector<int> transposePerm(meshAxisRefs.size());

Expand Down
53 changes: 0 additions & 53 deletions xla/service/spmd/shardy/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,59 +300,6 @@ std::string duplicateShardingsAtIndices(
TensorShardingPerValueAttr::get(context.get(), newShardings));
}

// Returns the axes of `mesh` in mesh order, where an axis is split into
// sub-axes at every sub-axis boundary referenced by `shardingOrAxisList`.
//
// `shardingOrAxisList` is either a `TensorShardingAttr` (its dimension
// sharding axes and unreduced axes are inspected) or an `AxisRefListAttr`.
// Axes with no referenced sub-axis are returned as full axes.
SmallVector<AxisRefAttr> getOrderedAxisRefs(Attribute shardingOrAxisList,
                                            MeshAttr mesh) {
  // For each mesh axis, the boundaries (pre-sizes) at which it must be split.
  // A map vector keeps the axes in mesh order. Every axis starts with the two
  // trivial boundaries: 1 and its full size.
  llvm::MapVector<StringRef, SmallVector<int64_t>> boundariesPerAxis;
  boundariesPerAxis.reserve(mesh.getAxes().size());
  for (MeshAxisAttr meshAxis : mesh.getAxes()) {
    SmallVector<int64_t>& boundaries = boundariesPerAxis[meshAxis.getName()];
    boundaries.push_back(1);
    boundaries.push_back(meshAxis.getSize());
  }

  // Records both ends of a sub-axis as boundaries of its axis. Full axes add
  // nothing. Duplicates are removed later.
  auto recordSubAxis = [&](AxisRefAttr axisRef) {
    SubAxisInfoAttr subAxisInfo = axisRef.getSubAxisInfo();
    if (!subAxisInfo) {
      return;
    }
    SmallVector<int64_t>& boundaries = boundariesPerAxis[axisRef.getName()];
    boundaries.push_back(subAxisInfo.getPreSize());
    boundaries.push_back(subAxisInfo.getNextPreSize());
  };

  if (auto sharding = mlir::dyn_cast<TensorShardingAttr>(shardingOrAxisList)) {
    for (DimensionShardingAttr dimSharding : sharding.getDimShardings()) {
      for (AxisRefAttr axisRef : dimSharding.getAxes()) {
        recordSubAxis(axisRef);
      }
    }
    for (AxisRefAttr axisRef : sharding.getUnreducedAxes()) {
      recordSubAxis(axisRef);
    }
  } else {
    for (AxisRefAttr axisRef :
         mlir::cast<AxisRefListAttr>(shardingOrAxisList).getValue()) {
      recordSubAxis(axisRef);
    }
  }

  SmallVector<AxisRefAttr> orderedRefs;
  mlir::MLIRContext* context = mesh.getContext();
  for (auto& [name, boundaries] : boundariesPerAxis) {
    if (boundaries.size() != 2) {
      // The axis is split: each pair of adjacent distinct boundaries
      // delimits one sub-axis.
      llvm::sort(boundaries);
      boundaries.erase(llvm::unique(boundaries), boundaries.end());
      for (size_t next = 1; next < boundaries.size(); ++next) {
        int64_t preSize = boundaries[next - 1];
        int64_t subAxisSize = boundaries[next] / preSize;
        orderedRefs.push_back(AxisRefAttr::get(
            context, name, SubAxisInfoAttr::get(context, preSize, subAxisSize)));
      }
    } else {
      // Only the trivial boundaries were recorded: keep the full axis.
      orderedRefs.push_back(AxisRefAttr::get(context, name));
    }
  }

  return orderedRefs;
}

namespace {

// Check if the func result is meant for Shardy.
Expand Down
8 changes: 0 additions & 8 deletions xla/service/spmd/shardy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,6 @@ std::string duplicateShardingsAtIndices(
mlir::StringRef shardingsFrontendAttr,
const llvm::BitVector& indicesToDuplicate);

// Returns all axes or sub-axes in the `mesh`, sorted by their order in the
// mesh.
//
// An axis is split into sub-axes only where `shardingOrAxisList` references a
// sub-axis of it; the axes considered include unreduced axes but not
// replicated axes. Axes that are not split are returned as full axes.
//
// For example, given mesh <"x"=2, "y"=16, "z"=4> and axis refs
// [{"x"}, {"y":2(2)}], this returns
// ["x", "y":1(2), "y":2(2), "y":4(4), "z"].
mlir::SmallVector<mlir::sdy::AxisRefAttr> getOrderedAxisRefs(
mlir::Attribute shardingOrAxisList, mlir::sdy::MeshAttr mesh);

// Returns true if the module has at least one GSPMD attribute or op, like an
// `mhlo.sharding` attribute or `Sharding` custom call.
// TODO(b/420837831): delete this once we don't fall back to GSPMD.
Expand Down
Loading