|
| 1 | +// RUN: sdy_opt %s -sdy-propagate-sharding-from-func-to-call -split-input-file | FileCheck %s |
| 2 | + |
| 3 | +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> |
| 4 | + |
| 5 | +// CHECK-LABEL: func @propagate_func_to_call |
| 6 | +func.func @propagate_func_to_call(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { |
| 7 | + // CHECK-NEXT: %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> |
| 8 | + %0 = call @foo(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32> |
| 9 | + return %0 : tensor<8x2xi32> |
| 10 | +} |
| 11 | + |
| 12 | +func.func private @foo(%arg0: tensor<8x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { |
| 13 | + return %arg0 : tensor<8x2xi32> |
| 14 | +} |
| 15 | + |
| 16 | +// ----- |
| 17 | + |
| 18 | +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> |
| 19 | + |
| 20 | +// CHECK-LABEL: func @do_not_overwrite_call_sharding |
| 21 | +func.func @do_not_overwrite_call_sharding(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { |
| 22 | + // CHECK-NEXT: %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> |
| 23 | + %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> |
| 24 | + return %0 : tensor<8x2xi32> |
| 25 | +} |
| 26 | + |
| 27 | +func.func private @foo(%arg0: tensor<8x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { |
| 28 | + return %arg0 : tensor<8x2xi32> |
| 29 | +} |
| 30 | + |
| 31 | +// ----- |
| 32 | + |
| 33 | +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> |
| 34 | + |
| 35 | +// CHECK-LABEL: func @both_call_and_func_has_empty_result_shardings |
| 36 | +func.func @both_call_and_func_has_empty_result_shardings(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { |
| 37 | + // CHECK-NEXT: %0 = call @foo(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32> |
| 38 | + %0 = call @foo(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32> |
| 39 | + return %0 : tensor<8x2xi32> |
| 40 | +} |
| 41 | + |
| 42 | +func.func private @foo(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { |
| 43 | + return %arg0 : tensor<8x2xi32> |
| 44 | +} |
| 45 | + |
| 46 | +// ----- |
| 47 | + |
| 48 | +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> |
| 49 | + |
| 50 | +// CHECK-LABEL: func @multiple_results |
| 51 | +func.func @multiple_results(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) { |
| 52 | + // CHECK-NEXT: %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) |
| 53 | + %0:2 = call @foo(%arg0, %arg1) : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) |
| 54 | + return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32> |
| 55 | +} |
| 56 | + |
| 57 | +func.func private @foo(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) { |
| 58 | + return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32> |
| 59 | +} |
| 60 | + |
| 61 | +// ----- |
| 62 | + |
| 63 | +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> |
| 64 | + |
| 65 | +// CHECK-LABEL: func @keep_empty_call_sharding |
| 66 | +func.func @keep_empty_call_sharding(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { |
| 67 | + // CHECK-NEXT: %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> |
| 68 | + %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> |
| 69 | + return %0 : tensor<8x2xi32> |
| 70 | +} |
| 71 | + |
| 72 | +func.func private @foo(%arg0: tensor<8x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { |
| 73 | + return %arg0 : tensor<8x2xi32> |
| 74 | +} |
| 75 | + |
| 76 | +// ----- |
| 77 | + |
| 78 | +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> |
| 79 | + |
| 80 | +// CHECK-LABEL: func @multiple_results_one_same_one_is_empty |
| 81 | +func.func @multiple_results_one_same_one_is_empty(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) { |
| 82 | + // CHECK-NEXT: %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) |
| 83 | + %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) |
| 84 | + return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32> |
| 85 | +} |
| 86 | + |
| 87 | +func.func private @foo(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) { |
| 88 | + return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32> |
| 89 | +} |
| 90 | + |
| 91 | +// ----- |
| 92 | + |
| 93 | +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> |
| 94 | + |
| 95 | +// CHECK-LABEL: func @multiple_results_one_different_one_is_empty |
| 96 | +func.func @multiple_results_one_different_one_is_empty(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) { |
| 97 | + // CHECK-NEXT: %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) |
| 98 | + %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) |
| 99 | + return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32> |
| 100 | +} |
| 101 | + |
| 102 | +func.func private @foo(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) { |
| 103 | + return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32> |
| 104 | +} |
| 105 | + |
| 106 | +// ----- |
| 107 | + |
| 108 | +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> |
| 109 | + |
| 110 | +// CHECK-LABEL: func @multiple_results_call_no_sharding_func_has_sharding_on_one_no_sharding_on_the_other |
| 111 | +func.func @multiple_results_call_no_sharding_func_has_sharding_on_one_no_sharding_on_the_other(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) { |
| 112 | + // CHECK-NEXT: %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{?}, {?}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) |
| 113 | + %0:2 = call @foo(%arg0, %arg1) : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) |
| 114 | + return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32> |
| 115 | +} |
| 116 | + |
| 117 | +func.func private @foo(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, tensor<4x2xi32>) { |
| 118 | + return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32> |
| 119 | +} |
0 commit comments