Skip to content

Commit 0c5bbdd

Browse files
petebucopybara-github
authored andcommitted
[mpmd] Support replicated and unreduced axes in GenerateSdyMeshesFromTopologyPass
PiperOrigin-RevId: 810029732
1 parent 17f66cc commit 0c5bbdd

2 files changed

Lines changed: 73 additions & 10 deletions

File tree

shardy/dialect/mpmd/transforms/import/generate_sdy_meshes_from_topology_pass.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,9 @@ class GenerateSdyMeshesFromTopologyPass
111111
sharding.getUnreducedAxes());
112112
}
113113
StringRef mesh_name;
114-
SmallVector<sdy::DimensionShardingAttr> dim_shardings;
115-
for (auto dim_sharding : sharding.getDimShardings()) {
116-
SmallVector<sdy::AxisRefAttr> axes;
117-
for (sdy::AxisRefAttr axis : dim_sharding.getAxes()) {
114+
auto rename_axes = [&mesh_name](ArrayRef<sdy::AxisRefAttr> axes) {
115+
SmallVector<sdy::AxisRefAttr> new_axes;
116+
for (sdy::AxisRefAttr axis : axes) {
118117
auto [prefix, axis_name] = axis.getName().split(kMeshAxisSeparator);
119118
SDY_CHECK(!axis_name.empty())
120119
<< "Axis name does not contain '" << kMeshAxisSeparator << "'";
@@ -124,19 +123,23 @@ class GenerateSdyMeshesFromTopologyPass
124123
<< prefix.str();
125124
}
126125
mesh_name = prefix;
127-
axes.push_back(sdy::AxisRefAttr::get(
128-
module_op.getContext(), axis_name, axis.getSubAxisInfo()));
126+
new_axes.push_back(sdy::AxisRefAttr::get(axis.getContext(), axis_name,
127+
axis.getSubAxisInfo()));
129128
}
129+
return new_axes;
130+
};
131+
SmallVector<sdy::DimensionShardingAttr> dim_shardings;
132+
for (auto dim_sharding : sharding.getDimShardings()) {
130133
dim_shardings.push_back(sdy::DimensionShardingAttr::get(
131-
module_op.getContext(), axes, dim_sharding.getIsClosed(),
132-
dim_sharding.getPriority()));
134+
module_op.getContext(), rename_axes(dim_sharding.getAxes()),
135+
dim_sharding.getIsClosed(), dim_sharding.getPriority()));
133136
}
134137
SDY_CHECK(!llvm::is_contained(old_meshes, mesh_name))
135138
<< "Invalid mesh name: " << mesh_name.str();
136-
// TODO(b/440336690): Add support for replicated axes and unreduced axes.
137139
return sdy::TensorShardingAttr::get(
138140
sharding.getContext(), mesh_name, dim_shardings,
139-
sharding.getReplicatedAxes(), sharding.getUnreducedAxes());
141+
rename_axes(sharding.getReplicatedAxes()),
142+
rename_axes(sharding.getUnreducedAxes()));
140143
});
141144

142145
for (StringRef mesh_name : old_meshes) {

shardy/dialect/mpmd/transforms/import/test/generate_sdy_meshes_from_topology.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,63 @@ module @fully_replicated_tensor {
108108
return %0 : tensor<16xf32>
109109
}
110110
}
111+
112+
// -----
113+
114+
// CHECK-LABEL: module @replicated_axes
115+
module @replicated_axes {
116+
// CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
117+
// CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
118+
// CHECK-NOT: sdy.mesh @mesh
119+
// CHECK-NOT: sdy.mesh @mesh_0
120+
sdy.mesh @mesh = <["tpu_x"=8, "tpu_y"=8]>
121+
sdy.mesh @mesh_0 = <["cpu_z"=8]>
122+
123+
// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x"}], replicated={"y"}>}
124+
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"z":(1)2}]>}
125+
func.func @main(
126+
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x"}], replicated={"tpu_y"}>},
127+
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z":(1)2}]>})
128+
-> (tensor<16xf32>) attributes {
129+
topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
130+
%0 = mpmd.named_computation<"stage1"> (%arg0, %arg0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
131+
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
132+
mpmd.return %2 : tensor<16xf32>
133+
} : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
134+
%1 = mpmd.named_computation<"stage2"> (%arg1, %0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
135+
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
136+
mpmd.return %2 : tensor<16xf32>
137+
} : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
138+
return %1 : tensor<16xf32>
139+
}
140+
}
141+
142+
// -----
143+
144+
// CHECK-LABEL: module @unreduced_axes
145+
module @unreduced_axes {
146+
// CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
147+
// CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
148+
// CHECK-NOT: sdy.mesh @mesh
149+
// CHECK-NOT: sdy.mesh @mesh_0
150+
sdy.mesh @mesh = <["tpu_x"=8, "tpu_y"=8]>
151+
sdy.mesh @mesh_0 = <["cpu_z"=8]>
152+
153+
// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x"}], unreduced={"y"}>}
154+
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"z":(1)2}]>}
155+
func.func @main(
156+
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x"}], unreduced={"tpu_y"}>},
157+
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z":(1)2}]>})
158+
-> (tensor<16xf32>) attributes {
159+
topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
160+
%0 = mpmd.named_computation<"stage1"> (%arg0, %arg0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
161+
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
162+
mpmd.return %2 : tensor<16xf32>
163+
} : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
164+
%1 = mpmd.named_computation<"stage2"> (%arg1, %0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
165+
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
166+
mpmd.return %2 : tensor<16xf32>
167+
} : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
168+
return %1 : tensor<16xf32>
169+
}
170+
}

0 commit comments

Comments
 (0)