@@ -108,3 +108,63 @@ module @fully_replicated_tensor {
108108 return %0 : tensor <16 xf32 >
109109 }
110110}

// -----

// Verifies that the global @mesh/@mesh_0 meshes are replaced by per-fragment
// meshes (@tpu, @cpu) and that axis references in shardings — including the
// `replicated={...}` axis set and the sub-axis form `"cpu_z":(1)2` — are
// rewritten to the corresponding topology axis names.
// NOTE(review): @mesh declares tpu_x=8/tpu_y=8 while the topology's tpu mesh
// is x=2/y=4 — presumably the pass maps axes by name, not size; confirm.
// CHECK-LABEL: module @replicated_axes
module @replicated_axes {
  // CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
  // CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
  // CHECK-NOT: sdy.mesh @mesh
  // CHECK-NOT: sdy.mesh @mesh_0
  sdy.mesh @mesh = <["tpu_x"=8, "tpu_y"=8]>
  sdy.mesh @mesh_0 = <["cpu_z"=8]>

  // CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x"}], replicated={"y"}>}
  // CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"z":(1)2}]>}
  func.func @main(
      %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x"}], replicated={"tpu_y"}>},
      %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z":(1)2}]>})
      -> (tensor<16xf32>) attributes {
      topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
    %0 = mpmd.named_computation<"stage1"> (%arg0, %arg0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
      %2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
      mpmd.return %2 : tensor<16xf32>
    } : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
    %1 = mpmd.named_computation<"stage2"> (%arg1, %0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
      %2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
      mpmd.return %2 : tensor<16xf32>
    } : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
    return %1 : tensor<16xf32>
  }
}

// -----

// Same as @replicated_axes above but exercising the `unreduced={...}` axis
// set: the pass must rewrite unreduced axis names (tpu_y -> y) alongside the
// dimension shardings and the sub-axis reference "cpu_z":(1)2 -> "z":(1)2.
// NOTE(review): @mesh axis sizes (8, 8) differ from the topology's tpu axes
// (2, 4) — presumably mapping is by name only; confirm against the pass.
// CHECK-LABEL: module @unreduced_axes
module @unreduced_axes {
  // CHECK-DAG: sdy.mesh @tpu = <["x"=2, "y"=4]>
  // CHECK-DAG: sdy.mesh @cpu = <["z"=8]>
  // CHECK-NOT: sdy.mesh @mesh
  // CHECK-NOT: sdy.mesh @mesh_0
  sdy.mesh @mesh = <["tpu_x"=8, "tpu_y"=8]>
  sdy.mesh @mesh_0 = <["cpu_z"=8]>

  // CHECK: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@tpu, [{"x"}], unreduced={"y"}>}
  // CHECK: %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@cpu, [{"z":(1)2}]>}
  func.func @main(
      %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"tpu_x"}], unreduced={"tpu_y"}>},
      %arg1: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"cpu_z":(1)2}]>})
      -> (tensor<16xf32>) attributes {
      topology = #mpmd.topology<<"tpu" : <["x"=2, "y"=4]>>, <"cpu" : <["z"=8]>>>} {
    %0 = mpmd.named_computation<"stage1"> (%arg0, %arg0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
      %2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
      mpmd.return %2 : tensor<16xf32>
    } : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
    %1 = mpmd.named_computation<"stage2"> (%arg1, %0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) {
      %2 = stablehlo.add %arg4, %arg3 : tensor<16xf32>
      mpmd.return %2 : tensor<16xf32>
    } : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
    return %1 : tensor<16xf32>
  }
}