@@ -119,15 +119,19 @@ func.func @constant_to_named_computation_with_only_constant_ops(%arg0: tensor<8x
119119
120120// CHECK-LABEL: func @constant_multiple_users_within_named_computation_with_no_arguments_and_with_only_constant_ops
121121func.func @constant_multiple_users_within_named_computation_with_no_arguments_and_with_only_constant_ops () -> (tensor <8 x16 xf32 >, tensor <8 x16 xf32 >) {
122- // CHECK-NEXT: %[[NC :.*]] = sdy.named_computation<"foo">()
123- // CHECK-NEXT: %[[CONST :.*]] = sdy.constant dense<1.000000e+00>
124- // CHECK-NEXT: %[[NEGATE :.*]] = stablehlo.negate %[[CONST ]]
125- // CHECK-NEXT: sdy.return %[[NEGATE ]]
122+ // CHECK-NEXT: %[[NC0 :.*]] = sdy.named_computation<"foo">()
123+ // CHECK-NEXT: %[[CONST0 :.*]] = sdy.constant dense<1.000000e+00>
124+ // CHECK-NEXT: %[[NEGATE0 :.*]] = stablehlo.negate %[[CONST0 ]]
125+ // CHECK-NEXT: sdy.return %[[NEGATE0 ]]
126126 // CHECK-NEXT: } : () -> tensor<8x16xf32>
127- // CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC]]
128- // CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC]]
127+ // CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">()
128+ // CHECK-NEXT: %[[CONST1:.*]] = sdy.constant dense<1.000000e+00>
129+ // CHECK-NEXT: %[[NEGATE1:.*]] = stablehlo.negate %[[CONST1]]
130+ // CHECK-NEXT: sdy.return %[[NEGATE1]]
131+ // CHECK-NEXT: } : () -> tensor<8x16xf32>
132+ // CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC0]]
133+ // CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC1]]
129134 // CHECK-NEXT: return %[[ABS_0]], %[[ABS_1]]
130- // TODO(enver): The named computation should be splitted.
131135 %0 = sdy.named_computation <" foo" >() () {
132136 %1 = stablehlo.constant dense <1.000000e+00 > : tensor <8 x16 xf32 >
133137 %2 = stablehlo.negate %1 : tensor <8 x16 xf32 >
@@ -140,16 +144,21 @@ func.func @constant_multiple_users_within_named_computation_with_no_arguments_an
140144
141145// CHECK-LABEL: func @constant_to_named_computation_with_one_argument_and_with_only_constant_ops
142146func.func @constant_to_named_computation_with_one_argument_and_with_only_constant_ops () -> (tensor <8 x16 xf32 >, tensor <8 x16 xf32 >) {
143- // CHECK-NEXT: %[[CONST:.*]] = sdy.constant dense<1.000000e+00>
144- // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%[[CONST]]) (%arg0: tensor<8x16xf32>) {
145- // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
146- // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[NEGATE]]
147- // CHECK-NEXT: sdy.return %[[ADD]]
147+ // CHECK-NEXT: %[[CONST0:.*]] = sdy.constant dense<1.000000e+00>
148+ // CHECK-NEXT: %[[CONST1:.*]] = sdy.constant dense<1.000000e+00>
149+ // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[CONST0]]) (%arg0: tensor<8x16xf32>) {
150+ // CHECK-NEXT: %[[NEGATE0:.*]] = stablehlo.negate %arg0
151+ // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg0, %[[NEGATE0]]
152+ // CHECK-NEXT: sdy.return %[[ADD0]]
148153 // CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
149- // CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC]]
150- // CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC]]
154+ // CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[CONST1]]) (%arg0: tensor<8x16xf32>) {
155+ // CHECK-NEXT: %[[NEGATE1:.*]] = stablehlo.negate %arg0
156+ // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %arg0, %[[NEGATE1]]
157+ // CHECK-NEXT: sdy.return %[[ADD1]]
158+ // CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
159+ // CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC0]]
160+ // CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC1]]
151161 // CHECK-NEXT: return %[[ABS_0]], %[[ABS_1]]
152- // TODO(enver): The named computation should be splitted.
153162 %0 = stablehlo.constant dense <1.000000e+00 > : tensor <8 x16 xf32 >
154163 %1 = sdy.named_computation <" foo" >(%0 ) (%arg0: tensor <8 x16 xf32 >) {
155164 %2 = stablehlo.negate %arg0 : tensor <8 x16 xf32 >
@@ -165,15 +174,20 @@ func.func @constant_to_named_computation_with_one_argument_and_with_only_constan
165174func.func @constant_multiple_users_one_to_named_computation_with_one_argument_and_with_only_constant_ops () -> (tensor <8 x16 xf32 >, tensor <8 x16 xf32 >, tensor <8 x16 xf32 >) {
166175 // CHECK-NEXT: %[[CONST_0:.*]] = sdy.constant dense<1.000000e+00>
167176 // CHECK-NEXT: %[[CONST_1:.*]] = sdy.constant dense<1.000000e+00>
168- // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%[[CONST_0]]) (%arg0: tensor<8x16xf32>) {
169- // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
170- // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[NEGATE]]
171- // CHECK-NEXT: sdy.return %[[ADD]]
177+ // CHECK-NEXT: %[[CONST_2:.*]] = sdy.constant dense<1.000000e+00>
178+ // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[CONST_1]]) (%arg0: tensor<8x16xf32>) {
179+ // CHECK-NEXT: %[[NEGATE0:.*]] = stablehlo.negate %arg0
180+ // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg0, %[[NEGATE0]]
181+ // CHECK-NEXT: sdy.return %[[ADD0]]
172182 // CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
173- // CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC]]
174- // CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC]]
175- // CHECK-NEXT: return %[[CONST_1]], %[[ABS_0]], %[[ABS_1]]
176- // TODO(enver): The named computation should be splitted.
183+ // CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[CONST_2]]) (%arg0: tensor<8x16xf32>) {
184+ // CHECK-NEXT: %[[NEGATE1:.*]] = stablehlo.negate %arg0
185+ // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %arg0, %[[NEGATE1]]
186+ // CHECK-NEXT: sdy.return %[[ADD1]]
187+ // CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
188+ // CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC0]]
189+ // CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC1]]
190+ // CHECK-NEXT: return %[[CONST_0]], %[[ABS_0]], %[[ABS_1]]
177191 %0 = stablehlo.constant dense <1.000000e+00 > : tensor <8 x16 xf32 >
178192 %1 = sdy.named_computation <" foo" >(%0 ) (%arg0: tensor <8 x16 xf32 >) {
179193 %2 = stablehlo.negate %arg0 : tensor <8 x16 xf32 >
@@ -687,17 +701,25 @@ func.func @constant_both_to_named_computation_and_inside_named_computation_and_n
687701func.func @constant_both_to_named_computation_and_inside_named_computation_and_named_computation_is_constant () -> (tensor <8 x16 xf32 >, tensor <8 x16 xf32 >, tensor <8 x16 xf32 >) {
688702 // CHECK-NEXT: %[[CONST0:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
689703 // CHECK-NEXT: %[[CONST1:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
690- // CHECK-NEXT: %[[NC :.*]] = sdy.named_computation<"foo">(%[[CONST0]]) (%arg0 : tensor<8x16xf32>) {
691- // CHECK-NEXT: %[[CONST2 :.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
704+ // CHECK-NEXT: %[[CONST2 :.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
705+ // CHECK-NEXT: %[[NC0 :.*]] = sdy.named_computation<"foo">(%[[CONST1]]) (%arg0 : tensor<8x16xf32>) {
692706 // CHECK-NEXT: %[[CONST3:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
693- // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg0, %[[CONST2]] : tensor<8x16xf32>
694- // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %arg0, %[[CONST3]] : tensor<8x16xf32>
695- // CHECK-NEXT: %[[MULTIPLY:.*]] = stablehlo.multiply %[[ADD0]], %[[ADD1]] : tensor<8x16xf32>
696- // CHECK-NEXT: sdy.return %[[MULTIPLY]] : tensor<8x16xf32>
707+ // CHECK-NEXT: %[[CONST4:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
708+ // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg0, %[[CONST3]] : tensor<8x16xf32>
709+ // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %arg0, %[[CONST4]] : tensor<8x16xf32>
710+ // CHECK-NEXT: %[[MULTIPLY0:.*]] = stablehlo.multiply %[[ADD0]], %[[ADD1]] : tensor<8x16xf32>
711+ // CHECK-NEXT: sdy.return %[[MULTIPLY0]] : tensor<8x16xf32>
697712 // CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
698- // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[NC]] : tensor<8x16xf32>
699- // CHECK-NEXT: return %[[CONST1]], %[[NC]], %[[NEGATE]] : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>
700- // TODO(enver): The named computation should be splitted as well.
713+ // CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[CONST2]]) (%arg0: tensor<8x16xf32>) {
714+ // CHECK-NEXT: %[[CONST5:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
715+ // CHECK-NEXT: %[[CONST6:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
716+ // CHECK-NEXT: %[[ADD2:.*]] = stablehlo.add %arg0, %[[CONST5]] : tensor<8x16xf32>
717+ // CHECK-NEXT: %[[ADD3:.*]] = stablehlo.add %arg0, %[[CONST6]] : tensor<8x16xf32>
718+ // CHECK-NEXT: %[[MULTIPLY1:.*]] = stablehlo.multiply %[[ADD2]], %[[ADD3]] : tensor<8x16xf32>
719+ // CHECK-NEXT: sdy.return %[[MULTIPLY1]] : tensor<8x16xf32>
720+ // CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
721+ // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[NC1]] : tensor<8x16xf32>
722+ // CHECK-NEXT: return %[[CONST0]], %[[NC0]], %[[NEGATE]] : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>
701723 %0 = stablehlo.constant dense <1.000000e+00 > : tensor <8 x16 xf32 >
702724 %1 = sdy.named_computation <" foo" >(%0 ) (%arg0: tensor <8 x16 xf32 >) {
703725 %2 = stablehlo.constant dense <1.000000e+00 > : tensor <8 x16 xf32 >
0 commit comments