Skip to content

Commit f21cb36

Browse files
bixia1 authored and Google-ML-Automation committed
Use the utility routines from mlir::sdy.
PiperOrigin-RevId: 874245224
1 parent 06f2977 commit f21cb36

File tree

4 files changed: +2 additions, -120 deletions

4 files changed: +2 additions, -120 deletions

xla/service/spmd/shardy/stablehlo_round_trip/export_manual_reduction_collectives.cc

Lines changed: 0 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -101,64 +101,6 @@ bool inputHasUnreducedAxes(CollectiveTy collective) {
101101
return false;
102102
}
103103

104-
// Builds the replica groups for `reductionAxesAttr`.
105-
//
106-
// For example, given:
107-
//
108-
// - reductionAxesAttr = ["y"]
109-
// - mesh = ["x"=2, "y"=2]
110-
//
111-
// Returns `[[0, 1], [2, 3]]`.
112-
mlir::DenseIntElementsAttr getReplicaGroups(
113-
sdy::AxisRefListAttr reductionAxesAttr, MeshAttr mesh,
114-
OpBuilder& rewriter) {
115-
SmallVector<AxisRefAttr> meshAxisRefs =
116-
getOrderedAxisRefs(reductionAxesAttr, mesh);
117-
118-
ArrayRef<AxisRefAttr> reductionAxes = reductionAxesAttr.getValue();
119-
int64_t groupSize = 1;
120-
llvm::SmallDenseMap<AxisRefAttr, int64_t> axisRefToReductionIndex;
121-
axisRefToReductionIndex.reserve(reductionAxes.size());
122-
for (auto [index, axis] : llvm::enumerate(reductionAxes)) {
123-
groupSize *= axis.getSize(mesh);
124-
axisRefToReductionIndex[axis] = index;
125-
}
126-
int64_t totalSize = mesh.getTotalSize();
127-
int64_t numGroups = totalSize / groupSize;
128-
129-
SmallVector<int64_t> transposePerm(meshAxisRefs.size());
130-
SmallVector<int64_t> reshapeDims;
131-
reshapeDims.reserve(meshAxisRefs.size());
132-
133-
int64_t nonReductionIndex = 0;
134-
int64_t nonReductionCount = meshAxisRefs.size() - reductionAxes.size();
135-
for (auto [meshIndex, axis] : llvm::enumerate(meshAxisRefs)) {
136-
reshapeDims.push_back(axis.getSize(mesh));
137-
auto reductionIndexIt = axisRefToReductionIndex.find(axis);
138-
if (reductionIndexIt == axisRefToReductionIndex.end()) {
139-
// Axis is not a reduction axis.
140-
transposePerm[nonReductionIndex++] = meshIndex;
141-
} else {
142-
transposePerm[nonReductionCount + reductionIndexIt->second] = meshIndex;
143-
}
144-
}
145-
146-
// TODO(b/410040098): output V2 if possible, and maybe canonicalize.
147-
148-
Array<int64_t> array(reshapeDims);
149-
if (mesh.getDeviceIds().empty()) {
150-
array.FillIota(0);
151-
} else {
152-
array.SetValues(mesh.getDeviceIds());
153-
}
154-
array.TransposeDimensions(transposePerm);
155-
array.Reshape({totalSize});
156-
auto replicaGroupsType = RankedTensorType::get({numGroups, groupSize},
157-
rewriter.getIntegerType(64));
158-
return mlir::DenseIntElementsAttr::get(replicaGroupsType,
159-
llvm::to_vector(array));
160-
}
161-
162104
// Creates a manual computation, with all axes in `mesh` as manual, and
163105
// populates its body via the `bodyPopulator` function.
164106
ManualComputationOp createFullyManualComputation(

xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -425,7 +425,8 @@ HloSharding convertToHloSharding(
425425
}
426426

427427
// We will add all axes and let canonicalization merge adjacent axes.
428-
SmallVector<AxisRefAttr> meshAxisRefs = getOrderedAxisRefs(sdySharding, mesh);
428+
SmallVector<AxisRefAttr> meshAxisRefs =
429+
mlir::sdy::getOrderedAxisRefs(sdySharding, mesh);
429430
SmallVector<int64_t> reshapeDims(meshAxisRefs.size());
430431
SmallVector<int> transposePerm(meshAxisRefs.size());
431432

xla/service/spmd/shardy/utils.cc

Lines changed: 0 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -300,59 +300,6 @@ std::string duplicateShardingsAtIndices(
300300
TensorShardingPerValueAttr::get(context.get(), newShardings));
301301
}
302302

303-
SmallVector<AxisRefAttr> getOrderedAxisRefs(Attribute shardingOrAxisList,
304-
MeshAttr mesh) {
305-
// We use a map vector to maintain the order of mesh axes.
306-
llvm::MapVector<StringRef, SmallVector<int64_t>> axisNameToPreSizes;
307-
axisNameToPreSizes.reserve(mesh.getAxes().size());
308-
for (MeshAxisAttr meshAxis : mesh.getAxes()) {
309-
SmallVector<int64_t>& preSizes = axisNameToPreSizes[meshAxis.getName()];
310-
preSizes.push_back(1);
311-
preSizes.push_back(meshAxis.getSize());
312-
}
313-
314-
auto consumeAxisRefList = [&](ArrayRef<AxisRefAttr> axisRefs) {
315-
for (AxisRefAttr axisRef : axisRefs) {
316-
// Add sub-axis pre-sizes to `axisNameToPreSizes`. We'll dedup later.
317-
if (axisRef.getSubAxisInfo()) {
318-
SmallVector<int64_t>& preSizes = axisNameToPreSizes[axisRef.getName()];
319-
preSizes.push_back(axisRef.getSubAxisInfo().getPreSize());
320-
preSizes.push_back(axisRef.getSubAxisInfo().getNextPreSize());
321-
}
322-
}
323-
};
324-
325-
if (auto sharding = mlir::dyn_cast<TensorShardingAttr>(shardingOrAxisList)) {
326-
for (DimensionShardingAttr dimSharding : sharding.getDimShardings()) {
327-
consumeAxisRefList(dimSharding.getAxes());
328-
}
329-
consumeAxisRefList(sharding.getUnreducedAxes());
330-
} else {
331-
consumeAxisRefList(
332-
mlir::cast<AxisRefListAttr>(shardingOrAxisList).getValue());
333-
}
334-
335-
SmallVector<AxisRefAttr> axisRefs;
336-
mlir::MLIRContext* ctx = mesh.getContext();
337-
for (auto& [axisName, preSizes] : axisNameToPreSizes) {
338-
if (preSizes.size() == 2) {
339-
// Full axis
340-
axisRefs.push_back(AxisRefAttr::get(ctx, axisName));
341-
continue;
342-
}
343-
llvm::sort(preSizes);
344-
preSizes.erase(llvm::unique(preSizes), preSizes.end());
345-
for (int64_t i = 0; i < preSizes.size() - 1; ++i) {
346-
int64_t preSize = preSizes[i];
347-
int64_t size = preSizes[i + 1] / preSize;
348-
axisRefs.push_back(AxisRefAttr::get(
349-
ctx, axisName, SubAxisInfoAttr::get(ctx, preSize, size)));
350-
}
351-
}
352-
353-
return axisRefs;
354-
}
355-
356303
namespace {
357304

358305
// Check if the func result is meant for Shardy.

xla/service/spmd/shardy/utils.h

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -146,14 +146,6 @@ std::string duplicateShardingsAtIndices(
146146
mlir::StringRef shardingsFrontendAttr,
147147
const llvm::BitVector& indicesToDuplicate);
148148

149-
// Return all axes or sub-axes in the `mesh`, such that sub-axes are derived
150-
// from `shardingOrAxisList` (including unreduced axes but not replicated)
151-
// and sorted by their order in the mesh. For example, given mesh <"x"=2,
152-
// "y"=16, "z"=4> and axis refs [{"x"}, {"y":2(2)}], we would return ["x",
153-
// "y":1(2), "y":2(2), "y":4(4), "z"].
154-
mlir::SmallVector<mlir::sdy::AxisRefAttr> getOrderedAxisRefs(
155-
mlir::Attribute shardingOrAxisList, mlir::sdy::MeshAttr mesh);
156-
157149
// Returns true if the module has at least one GSPMD attribute or op, like an
158150
// `mhlo.sharding` attribute or `Sharding` custom call.
159151
// TODO(b/420837831): delete this once we don't fall back to GSPMD.

0 commit comments

Comments
 (0)