Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ class GenerateSdyMeshesFromTopologyPass
sharding.getUnreducedAxes());
}
StringRef mesh_name;
auto rename_axes = [&mesh_name](ArrayRef<sdy::AxisRefAttr> axes) {
SmallVector<sdy::AxisRefAttr> new_axes;
auto check_axes = [&mesh_name](ArrayRef<sdy::AxisRefAttr> axes) {
for (sdy::AxisRefAttr axis : axes) {
auto [prefix, axis_name] = axis.getName().split(kMeshAxisSeparator);
SDY_CHECK(!axis_name.empty())
Expand All @@ -123,23 +122,18 @@ class GenerateSdyMeshesFromTopologyPass
<< prefix.str();
}
mesh_name = prefix;
new_axes.push_back(sdy::AxisRefAttr::get(axis.getContext(), axis_name,
axis.getSubAxisInfo()));
}
return new_axes;
};
SmallVector<sdy::DimensionShardingAttr> dim_shardings;
for (auto dim_sharding : sharding.getDimShardings()) {
dim_shardings.push_back(sdy::DimensionShardingAttr::get(
module_op.getContext(), rename_axes(dim_sharding.getAxes()),
dim_sharding.getIsClosed(), dim_sharding.getPriority()));
check_axes(dim_sharding.getAxes());
}
check_axes(sharding.getReplicatedAxes());
check_axes(sharding.getUnreducedAxes());
SDY_CHECK(!llvm::is_contained(old_meshes, mesh_name))
<< "Invalid mesh name: " << mesh_name.str();
return sdy::TensorShardingAttr::get(
sharding.getContext(), mesh_name, dim_shardings,
rename_axes(sharding.getReplicatedAxes()),
rename_axes(sharding.getUnreducedAxes()));
sharding.getContext(), mesh_name, sharding.getDimShardings(),
sharding.getReplicatedAxes(), sharding.getUnreducedAxes());
});

for (StringRef mesh_name : old_meshes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

// CHECK-LABEL: module @multiple_input_meshes
module @multiple_input_meshes {
// CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
// CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2, "tpu_y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["cpu_z"=8]>
// CHECK-DAG: sdy.mesh @empty_mesh = <[]>
// CHECK-DAG: sdy.mesh @maximal_mesh = <[], device_ids=[0]>
// CHECK-NOT: sdy.mesh @mesh
Expand All @@ -13,16 +13,16 @@ module @multiple_input_meshes {
sdy.mesh @empty_mesh = <[]>
sdy.mesh @maximal_mesh = <[], device_ids=[0]>

// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x", "y"}]>}
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"z":(1)2}]>}
// CHECK: %arg2: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x", "y"}]>}
// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"tpu_x", "tpu_y"}]>}
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"cpu_z":(1)2}]>}
// CHECK: %arg2: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"tpu_x", "tpu_y"}]>}
func.func @main(
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x", "tpu_y"}]>},
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z":(1)2}]>},
%arg2: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x", "tpu_y"}]>})
-> (tensor<16xf32>) attributes {
// CHECK: topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>
topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
// CHECK: topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>
topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>} {
%0 = mpmd.named_computation<"stage1"> (%arg0, %arg2) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
mpmd.return %2 : tensor<16xf32>
Expand All @@ -39,20 +39,20 @@ module @multiple_input_meshes {

// CHECK-LABEL: module @empty_mesh
module @empty_mesh {
// CHECK-DAG: sdy.mesh @tpu = <["x"=2]>
// CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2]>
// CHECK-DAG: sdy.mesh @empty_mesh = <[]>
// CHECK-NOT: sdy.mesh @mesh
sdy.mesh @mesh = <["tpu_x"=2]>
sdy.mesh @empty_mesh = <[]>

// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x"}]>}
// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"tpu_x"}]>}
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@empty_mesh, [{}]>}
func.func @main(
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x"}]>},
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@empty_mesh, [{}]>})
-> (tensor<16xf32>) attributes {
// CHECK: topology = #mpmd.topology<<"tpu" : <["x"=2]>>>
topology = #mpmd.topology<<"tpu" : <["x"=2]>>>} {
// CHECK: topology = #mpmd.topology<<"tpu" : <["tpu_x"=2]>>>
topology = #mpmd.topology<<"tpu" : <["tpu_x"=2]>>>} {
%0 = mpmd.named_computation<"stage1"> (%arg0, %arg1) (%arg2: tensor<16xf32>, %arg3: tensor<16xf32>) {
%2 = stablehlo.add %arg3, %arg2 : tensor<16xf32>
mpmd.return %2 : tensor<16xf32>
Expand All @@ -65,20 +65,20 @@ module @empty_mesh {

// CHECK-LABEL: module @maximal_mesh
module @maximal_mesh {
// CHECK-DAG: sdy.mesh @tpu = <["x"=2]>
// CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2]>
// CHECK-DAG: sdy.mesh @maximal_mesh = <[], device_ids=[0]>
// CHECK-NOT: sdy.mesh @mesh
sdy.mesh @mesh = <["tpu_x"=2]>
sdy.mesh @maximal_mesh = <[], device_ids=[0]>

// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x"}]>}
// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"tpu_x"}]>}
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh, []>}
func.func @main(
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x"}]>},
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh, []>})
-> (tensor<16xf32>) attributes {
// CHECK: topology = #mpmd.topology<<"tpu" : <["x"=2]>>>
topology = #mpmd.topology<<"tpu" : <["x"=2]>>>} {
// CHECK: topology = #mpmd.topology<<"tpu" : <["tpu_x"=2]>>>
topology = #mpmd.topology<<"tpu" : <["tpu_x"=2]>>>} {
%0 = mpmd.named_computation<"stage1"> (%arg0, %arg1) (%arg2: tensor<16xf32>, %arg3: tensor<16xf32>) {
%2 = stablehlo.add %arg3, %arg2 : tensor<16xf32>
mpmd.return %2 : tensor<16xf32>
Expand All @@ -91,7 +91,7 @@ module @maximal_mesh {

// CHECK-LABEL: module @fully_replicated_tensor
module @fully_replicated_tensor {
// CHECK-DAG: sdy.mesh @tpu = <["x"=2]>
// CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2]>
// CHECK-DAG: sdy.mesh @empty_mesh = <[]>
// CHECK-NOT: sdy.mesh @mesh
sdy.mesh @mesh = <["tpu_x"=2]>
Expand All @@ -102,8 +102,8 @@ module @fully_replicated_tensor {
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>},
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}]>})
-> (tensor<16xf32>) attributes {
// CHECK: topology = #mpmd.topology<<"tpu" : <["x"=2]>>, <"empty_mesh" : <[]>>>
topology = #mpmd.topology<<"tpu" : <["x"=2]>>>} {
// CHECK: topology = #mpmd.topology<<"tpu" : <["tpu_x"=2]>>, <"empty_mesh" : <[]>>>
topology = #mpmd.topology<<"tpu" : <["tpu_x"=2]>>>} {
%0 = mpmd.named_computation<"stage1"> (%arg0, %arg1) (%arg2: tensor<16xf32>, %arg3: tensor<16xf32>) {
%2 = stablehlo.add %arg3, %arg2 : tensor<16xf32>
mpmd.return %2 : tensor<16xf32>
Expand All @@ -116,21 +116,21 @@ module @fully_replicated_tensor {

// CHECK-LABEL: module @replicated_axes
module @replicated_axes {
// CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
// CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2, "tpu_y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["cpu_z"=8]>
// CHECK-NOT: sdy.mesh @mesh
// CHECK-NOT: sdy.mesh @mesh_0
sdy.mesh @mesh = <["tpu_x"=8, "tpu_y"=8]>
sdy.mesh @mesh_0 = <["cpu_z"=8]>

// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x"}], replicated={"y"}>}
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"z":(1)2}]>}
// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"tpu_x"}], replicated={"tpu_y"}>}
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"cpu_z":(1)2}]>}
func.func @main(
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x"}], replicated={"tpu_y"}>},
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z":(1)2}]>})
-> (tensor<16xf32>) attributes {
// CHECK: topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>
topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
// CHECK: topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>
topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>} {
%0 = mpmd.named_computation<"stage1"> (%arg0, %arg0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
mpmd.return %2 : tensor<16xf32>
Expand All @@ -147,21 +147,21 @@ module @replicated_axes {

// CHECK-LABEL: module @unreduced_axes
module @unreduced_axes {
// CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
// CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2, "tpu_y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["cpu_z"=8]>
// CHECK-NOT: sdy.mesh @mesh
// CHECK-NOT: sdy.mesh @mesh_0
sdy.mesh @mesh = <["tpu_x"=8, "tpu_y"=8]>
sdy.mesh @mesh_0 = <["cpu_z"=8]>

// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x"}], unreduced={"y"}>}
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"z":(1)2}]>}
// CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"tpu_x"}], unreduced={"tpu_y"}>}
// CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"cpu_z":(1)2}]>}
func.func @main(
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x"}], unreduced={"tpu_y"}>},
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z":(1)2}]>})
-> (tensor<16xf32>) attributes {
// CHECK: topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>
topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
// CHECK: topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>
topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>} {
%0 = mpmd.named_computation<"stage1"> (%arg0, %arg0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
mpmd.return %2 : tensor<16xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: mpmd_opt %s -mpmd-import-pipeline='name-to-mesh-assignment=f1@tpu,f2@cpu enable-heterogeneous-meshes' -split-input-file 2>&1 | FileCheck %s

module @multiple_input_meshes {
// CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
// CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2, "tpu_y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["cpu_z"=8]>
// CHECK-DAG: sdy.mesh @empty_mesh = <[]>
// CHECK-DAG: sdy.mesh @maximal_mesh = <[], device_ids=[0]>
// CHECK-NOT: sdy.mesh @mesh
Expand All @@ -12,15 +12,15 @@ module @multiple_input_meshes {
sdy.mesh @empty_mesh = <[]>
sdy.mesh @maximal_mesh = <[], device_ids=[0]>

// CHECK: %arg0: !mpmd.mesh_tensor<"tpu", tensor<16xf32>> {sdy.sharding = #sdy.sharding<@tpu, [{"x", "y"}]>}
// CHECK: %arg1: !mpmd.mesh_tensor<"cpu", tensor<16xf32>> {sdy.sharding = #sdy.sharding<@cpu, [{"z":(1)2}]>}
// CHECK: %arg2: !mpmd.mesh_tensor<"tpu", tensor<16xf32>> {sdy.sharding = #sdy.sharding<@tpu, [{"x", "y"}]>}
// CHECK: %arg0: !mpmd.mesh_tensor<"tpu", tensor<16xf32>> {sdy.sharding = #sdy.sharding<@tpu, [{"tpu_x", "tpu_y"}]>}
// CHECK: %arg1: !mpmd.mesh_tensor<"cpu", tensor<16xf32>> {sdy.sharding = #sdy.sharding<@cpu, [{"cpu_z":(1)2}]>}
// CHECK: %arg2: !mpmd.mesh_tensor<"tpu", tensor<16xf32>> {sdy.sharding = #sdy.sharding<@tpu, [{"tpu_x", "tpu_y"}]>}
func.func @main(
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x", "tpu_y"}]>},
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z":(1)2}]>},
%arg2: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x", "tpu_y"}]>})
-> (tensor<16xf32>) attributes {
topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>} {
%0 = mpmd.named_computation<"f1"> (%arg0, %arg2) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
mpmd.return %2 : tensor<16xf32>
Expand Down
12 changes: 6 additions & 6 deletions shardy/dialect/mpmd/transforms/test/e2e_pipeline.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: mpmd_opt %s -mpmd-import-pipeline='name-to-mesh-assignment=f1@tpu,f2@cpu enable-heterogeneous-meshes' -mpmd-optimize-pipeline -mpmd-sharding-propagation-pipeline -mpmd-export-pipeline 2>&1 | FileCheck %s

module @multiple_input_meshes {
// CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
// CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2, "tpu_y"=4]>
// CHECK-DAG: sdy.mesh @cpu = <["cpu_z"=8]>
// CHECK-DAG: sdy.mesh @empty_mesh = <[]>
// CHECK-DAG: sdy.mesh @maximal_mesh = <[], device_ids=[0]>
// CHECK-NOT: sdy.mesh @mesh
Expand All @@ -12,15 +12,15 @@ module @multiple_input_meshes {
sdy.mesh @empty_mesh = <[]>
sdy.mesh @maximal_mesh = <[], device_ids=[0]>

// CHECK: %arg0: !mpmd.mesh_tensor<"tpu", tensor<16xf32>, sharding=<@tpu, [{"x", "y"}]>>
// CHECK: %arg1: !mpmd.mesh_tensor<"cpu", tensor<16xf32>, sharding=<@cpu, [{"z"}]>>
// CHECK: %arg2: !mpmd.mesh_tensor<"tpu", tensor<16xf32>, sharding=<@tpu, [{"x", "y"}]>>
// CHECK: %arg0: !mpmd.mesh_tensor<"tpu", tensor<16xf32>, sharding=<@tpu, [{"tpu_x", "tpu_y"}]>>
// CHECK: %arg1: !mpmd.mesh_tensor<"cpu", tensor<16xf32>, sharding=<@cpu, [{"cpu_z"}]>>
// CHECK: %arg2: !mpmd.mesh_tensor<"tpu", tensor<16xf32>, sharding=<@tpu, [{"tpu_x", "tpu_y"}]>>
func.func @main(
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x", "tpu_y"}]>},
%arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z"}]>},
%arg2: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x", "tpu_y"}]>})
-> (tensor<16xf32>) attributes {
topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>} {
// CHECK: %[[FRAGMENT_CALL1:.*]] = mpmd.fragment_call<mesh="tpu", origin=["f1"]> @p0_f1_fwd.multiple_input_meshes(%arg0, %arg2)
%0 = mpmd.named_computation<"f1"> (%arg0, %arg2) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
%2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
Expand Down
Loading