Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions shardy/dialect/mpmd/transforms/common/merge_fragments.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,16 +437,6 @@ class MergeInferredFragmentsPass
current = fragment;
break;
}
if (auto transfer = dyn_cast<TransferOp>(current);
transfer &&
transfer.getType().getMeshName() == producer_op.getMeshName()) {
// If we transfer back into the same mesh, then abort, because the
// next node could have used the transferred value.
return mergeFailure(
"The closest fragment in the same mesh is a transfer. We do not "
"merge sideways for simplicity, to avoid having to check "
"dependencies.");
}
current = current->getNextNode();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,65 @@ func.func @interleaved(%arg0: !m1_4x8, %arg1: !m1_4x8)

func.return %2, %1 : !m1_4x8, !m1_4x8
}

// Test: two fragments on mesh "m1" are separated by a transfer from "m2" into
// "m1". Neither fragment consumes the transfer result, so the FileCheck
// expectations below require the transfer to come first, followed by a single
// two-result "m1" fragment (origin ["f"]) holding both the add and the
// multiply, with the function returning that fragment's two results.
// CHECK-LABEL: func @merge_across_transfer
func.func @merge_across_transfer(%arg0: !m1_4x8, %arg1: !m2_4x8)
-> (!m1_4x8, !m1_4x8) attributes {topology=#topo} {
// CHECK-NEXT: %[[TRANS:.*]] = mpmd.transfer %arg1 : (!mpmd.mesh_tensor<"m2", tensor<4x8xf32>>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
// CHECK-NEXT: %[[FRAG:.*]]:2 = mpmd.fragment<mesh="m1", origin=["f"]>
// CHECK-NEXT: stablehlo.add
// CHECK-NEXT: stablehlo.multiply
// CHECK-NEXT: mpmd.return
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FRAG]]#0, %[[FRAG]]#1

// First "m1" fragment; its origin list is empty.
%0 = mpmd.fragment<mesh="m1", origin=[]> (%arg0)
(%arg2: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg2 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!m1_4x8) -> !m1_4x8

// Transfer into "m1" that sits between the two fragments. Its result %1 is
// not used by any other op in this function.
%1 = mpmd.transfer %arg1 : (!m2_4x8) -> !m1_4x8

// Second "m1" fragment (origin ["f"]); it reads only %arg0, not %0 or %1.
%2 = mpmd.fragment<mesh="m1", origin=["f"]> (%arg0)
(%arg2: tensor<4x8xf32>) {
%4 = stablehlo.multiply %arg2, %arg2 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!m1_4x8) -> !m1_4x8

func.return %0, %2 : !m1_4x8, !m1_4x8
}

// Test: the second "m1" fragment depends on the first one through a chain
// m1 fragment -> transfer to m2 -> m2 fragment -> transfer back to m1.
// The FileCheck expectations below require all five ops to remain, in their
// original order, i.e. the two "m1" fragments are not merged.
// CHECK-LABEL: func @no_merge_dependency
func.func @no_merge_dependency(%arg0: !m1_4x8)
-> (!m1_4x8, !m1_4x8) attributes {topology=#topo} {
// CHECK: mpmd.fragment<mesh="m1", origin=[]>
// CHECK: mpmd.transfer
// CHECK: mpmd.fragment<mesh="m2"
// CHECK: mpmd.transfer
// CHECK: mpmd.fragment<mesh="m1", origin=["f"]>

// First "m1" fragment; its origin list is empty.
%0 = mpmd.fragment<mesh="m1", origin=[]> (%arg0)
(%arg2: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg2 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!m1_4x8) -> !m1_4x8

// Send the first fragment's result to mesh "m2".
%1 = mpmd.transfer %0 : (!m1_4x8) -> !m2_4x8

// "m2" fragment consuming the transferred value.
%2 = mpmd.fragment<mesh="m2", origin=["g"]> (%1)
(%arg2: tensor<4x8xf32>) {
%4 = stablehlo.multiply %arg2, %arg2 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!m2_4x8) -> !m2_4x8

// Transfer the "m2" result back into "m1".
%3 = mpmd.transfer %2 : (!m2_4x8) -> !m1_4x8

// Second "m1" fragment; its only operand is %3, which is derived from %0.
%4 = mpmd.fragment<mesh="m1", origin=["f"]> (%3)
(%arg2: tensor<4x8xf32>) {
%5 = stablehlo.subtract %arg2, %arg2 : tensor<4x8xf32>
mpmd.return %5 : tensor<4x8xf32>
} : (!m1_4x8) -> !m1_4x8

func.return %0, %4 : !m1_4x8, !m1_4x8
}
6 changes: 6 additions & 0 deletions shardy/dialect/mpmd/transforms/export/export_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ void addExportPipeline(OpPassManager& pm, const ExportOptions& options) {
// fragments to compile) and may cause performance regressions. Thus, we merge
// them with other fragments.
pm.addNestedPass<FuncOp>(createMergeInferredFragmentsPass());
{
MergeInferredFragmentsPassOptions mergeInferredOptions;
mergeInferredOptions.mergeSideways = true;
pm.addNestedPass<FuncOp>(
createMergeInferredFragmentsPass(std::move(mergeInferredOptions)));
}

// Mark each fragment with the inputs and outputs which are offloaded to host
// memory.
Expand Down
27 changes: 26 additions & 1 deletion shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mpmd_opt %s -mpmd-export-pipeline 2>&1 | FileCheck %s
// RUN: mpmd_opt %s -mpmd-export-pipeline -split-input-file 2>&1 | FileCheck %s

!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>

Expand All @@ -17,3 +17,28 @@ func.func @main(%arg0: !mesh_1_tensor_4_8_f32 {tf.aliasing_output = 0: i32}, %ar
} : (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) -> (!mesh_1_tensor_4_8_f32)
func.return %0 : !mesh_1_tensor_4_8_f32
}

// -----

!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
!m2_4x8 = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>>

// Test: runs the export pipeline on a function with one "m1" fragment (empty
// origin list) and a transfer from "m2" into "m1" whose result is returned
// twice. The expectations below require one "m1" fragment_call with origin=[]
// and forbid any further "m1" fragment_call after it.
// NOTE(review): presumably the pipeline introduces an extra inferred "m1"
// fragment for this input, which must then be merged into the first one
// across the transfer -- confirm against the pass implementation.
// CHECK-LABEL: func.func @test_sideways_merge
func.func @test_sideways_merge(%arg0: !mesh_1_tensor_4_8_f32, %arg1: !m2_4x8)
-> (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) attributes {
"topology"=#mpmd.topology<
<"m1": <["x"=2]>>,
<"m2": <["x"=2]>>
>} {
// CHECK: mpmd.fragment_call<mesh="m1", origin=[]> @[[CALLEE_M1:.*]]
// CHECK-NOT: mpmd.fragment_call<mesh="m1"

// The only fragment written in the input; it runs on "m1".
%0 = mpmd.fragment<mesh="m1", origin=[]> (%arg0) (%arg2: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg2 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor_4_8_f32) -> !mesh_1_tensor_4_8_f32

// Transfer from "m2" into "m1"; it does not depend on %0.
%1 = mpmd.transfer %arg1 : (!m2_4x8) -> !mesh_1_tensor_4_8_f32

// The transfer result is returned in two result positions.
func.return %0, %1, %1 : !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32
}
Loading