
Commit 1393d12

Tai78641 authored and Jerry-Ge committed
[mlir][tosa] EqualizeRanks for ops with SameOperandsAndResultRank
This patch refactors the tf/tfl to TOSA lowering to prepare for adding the trait SameOperandsAndResultRank to TOSA element-wise operators (which will then require these operators to be built with operands of the same rank).

1. Refactored to use getTosaConstTensorSingleF32, getTosaConstTensorSingleI32, and getTosaConstTensorScalarInt to construct tosa constant tensors with a single value and a specified rank.
2. Changed the lowering of the following tf operators to go through the CreateReplaceOpAndInfer and CreateOpAndInfer builder functions: BitwiseOr, BitwiseXor, BitwiseAnd, LogicalAnd, LogicalOr, Pow.
3. Refactored to use CreateMulOpAndInfer to construct tosa Mul operations; this calls EqualizeRanks on the multiply inputs to insert reshapes as needed so the mul operands have the same rank, and defaults the shift value to 0 if unspecified.
4. Refactored callers of CreateOpAndInfer/CreateReplaceOpAndInfer to pass Value arguments for the inputs of tosa element-wise operators.
5. In CreateOpAndInfer, call tosa::CreateOpAndInferShape, which in turn checks for the operator trait SameOperandsAndResultRank and calls EqualizeRanks to insert reshapes as needed so operands have the same rank.
6. Changed the lowering of the following tfl operations to go through the CreateReplaceOpAndInfer and CreateOpAndInfer builder functions: LogicalAnd, LogicalOr, Pow.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I598de6fe7e0312513ec73c718774268fa08692c5
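For illustration, the rank equalization described in items 3 and 5 can be pictured on the test_add case from the new test file: when an element-wise op such as tosa.add is built from operands of different ranks, EqualizeRanks is expected to reshape the lower-rank operand by prepending size-1 dimensions before the op is created. The IR below is a minimal sketch under that assumption; the value names and the exact reshape are illustrative, not taken from the patch.

// Operands of rank 3 and rank 4 (as in the new test_add test):
//   %lhs : tensor<13x21x1xf32>,  %rhs : tensor<1x13x21x3xf32>
// Hypothetical IR after EqualizeRanks inserts a reshape so both operands are rank 4:
%lhs4 = tosa.reshape %lhs {new_shape = array<i64: 1, 13, 21, 1>} : (tensor<13x21x1xf32>) -> tensor<1x13x21x1xf32>
%sum = tosa.add %lhs4, %rhs : (tensor<1x13x21x1xf32>, tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xf32>

Similarly, the constant helpers in item 1 produce a single-value constant at a caller-specified rank; a rank-3 f32 constant would look like the following (value and shape illustrative):

// Hypothetical output of a single-value constant helper at rank 3:
%cst = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>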
1 parent 56f701f commit 1393d12

11 files changed: +916, -512 lines

tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir

Lines changed: 3 additions & 3 deletions
@@ -918,9 +918,9 @@ func.func @test_right_shift(%arg0: tensor<4x4xi32>, %arg1: tensor<1x1xi32>) -> t
 // -----

 // CHECK-LABEL: @test_one_hot
-// CHECK-SAME: %[[ARG0_0:.*]]: tensor<4x4xi32>, %[[ARG1_0:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>
-// CHECK: %[[CST1:.*]] = tosa.const_shape {value = dense<[16, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
-// CHECK: %[[CST2:.*]] = tosa.const_shape {value = dense<[16, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK-SAME: %[[ARG0_0:.*]]: tensor<4x4xi32>, %[[ARG1_0:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>
+// CHECK-DAG: %[[CST1:.*]] = tosa.const_shape {value = dense<[16, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK-DAG: %[[CST2:.*]] = tosa.const_shape {value = dense<[16, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
 // CHECK: %[[RESHAPE_0:.*]] = tosa.reshape %[[ARG1_0]] {new_shape = array<i64: 1, 1, 1>}
 // CHECK: %[[TILE:.*]] = tosa.tile %[[RESHAPE_0]], %[[CST1]]
 // CHECK: %[[RESHAPE_1:.*]] = tosa.reshape %[[ARG2]] {new_shape = array<i64: 1, 1, 1>}
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
+// RUN: tf-opt --split-input-file --tf-to-tosa-pipeline --verify-each %s | FileCheck %s
+
+// Test tf legalization that produce TOSA ResultsBroadcastableShape operators with unequal ranks
+
+// -----
+
+// CHECK-LABEL: test_add
+func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1x13x21x3xf32>) -> tensor<*xf32> {
+// CHECK: tosa.add
+%2 = "tf.Add"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1x13x21x3xf32>) -> tensor<*xf32>
+func.return %2 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add
+func.func @test_addn(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1x13x21x3xf32>, %arg2: tensor<21x3xf32>, %arg3: tensor<3xf32>) -> tensor<*xf32> {
+// CHECK: tosa.add
+// CHECK: tosa.add
+// CHECK: tosa.add
+%2 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3) : (tensor<13x21x1xf32>, tensor<1x13x21x3xf32>, tensor<21x3xf32>, tensor<3xf32>) -> tensor<*xf32>
+func.return %2 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_bitwise_and
+func.func @test_bitwise_and(%arg0: tensor<8x13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<8x13x21x3xi32> {
+// CHECK: tosa.bitwise_and
+%2 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<8x13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<8x13x21x3xi32>
+func.return %2 : tensor<8x13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub
+func.func @test_sub(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1x13x21x3xf32>) -> tensor<*xf32> {
+// CHECK: tosa.sub
+%2 = "tf.Sub"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1x13x21x3xf32>) -> tensor<*xf32>
+func.return %2 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_bitwise_or
+func.func @test_bitwise_or(%arg0: tensor<8x13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<8x13x21x3xi32> {
+// CHECK: tosa.bitwise_or
+%2 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<8x13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<8x13x21x3xi32>
+func.return %2 : tensor<8x13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_bitwise_xor
+func.func @test_bitwise_xor(%arg0: tensor<8x13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<8x13x21x3xi32> {
+// CHECK: tosa.bitwise_xor
+%2 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<8x13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<8x13x21x3xi32>
+func.return %2 : tensor<8x13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_logical_and
+func.func @test_logical_and(%arg0: tensor<8x13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<8x13x21x3xi1> {
+// CHECK: tosa.logical_and
+%2 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<8x13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<8x13x21x3xi1>
+func.return %2 : tensor<8x13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_logical_or
+func.func @test_logical_or(%arg0: tensor<8x13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<8x13x21x3xi1> {
+// CHECK: tosa.logical_or
+%2 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<8x13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<8x13x21x3xi1>
+func.return %2 : tensor<8x13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_floor_div
+func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> {
+// CHECK: tosa.int_div
+%2 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32>
+func.return %2 : tensor<1x13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_real_div
+// CHECK: tosa.int_div
+func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> {
+%2 = "tf.RealDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32>
+func.return %2 : tensor<1x13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_left_shift
+func.func @test_left_shift(%arg0: tensor<4x4xi32>, %arg1: tensor<1x1x1xi32>) -> tensor<1x4x4xi32> {
+// CHECK: tosa.logical_left_shift
+%0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4x4xi32>, tensor<1x1x1xi32>) -> tensor<1x4x4xi32>
+func.return %0 : tensor<1x4x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_right_shift
+func.func @test_right_shift(%arg0: tensor<4x4xi32>, %arg1: tensor<1x1x1xi32>) -> tensor<1x4x4xi32> {
+// CHECK: tosa.arithmetic_right_shift
+%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x4xi32>, tensor<1x1x1xi32>) -> tensor<1x4x4xi32>
+func.return %0 : tensor<1x4x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_max
+func.func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x13x21x1xf32>) -> tensor<1x13x21x3xf32> {
+// CHECK: tosa.maximum
+%2 = "tf.Maximum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x13x21x1xf32>) -> tensor<1x13x21x3xf32>
+func.return %2 : tensor<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_min
+func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x13x21x1xf32>) -> tensor<1x13x21x3xf32> {
+// CHECK: tosa.minimum
+%2 = "tf.Minimum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x13x21x1xf32>) -> tensor<1x13x21x3xf32>
+func.return %2 : tensor<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_power
+func.func @test_power(%arg0: tensor<8x13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<8x13x21x3xi32> {
+// CHECK: tosa.pow
+%2 = "tf.Pow"(%arg0, %arg1) : (tensor<8x13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<8x13x21x3xi32>
+func.return %2 : tensor<8x13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal
+// CHECK: tosa.equal
+func.func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x13x1x3xf32>) -> tensor<1x13x21x3xi1> {
+%2 = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<13x21x3xf32>, tensor<1x13x1x3xf32>) -> tensor<1x13x21x3xi1>
+func.return %2 : tensor<1x13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal
+// CHECK: tosa.greater_equal
+func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xi1> {
+%2 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xi1>
+func.return %2 : tensor<1x13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater
+// CHECK: tosa.greater
+func.func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xi1> {
+%2 = "tf.Greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xi1>
+func.return %2 : tensor<1x13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_less
+// CHECK: tosa.greater_equal
+// CHECK: tosa.logical_not
+func.func @test_less(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xi1> {
+%2 = "tf.Less"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xi1>
+func.return %2 : tensor<1x13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: test_select
+// CHECK: tosa.select
+func.func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x13x21x3xf32>, %arg2: tensor<1xi1>) -> tensor<1x13x21x3xf32> {
+%2 = "tf.SelectV2"(%arg2, %arg0, %arg1) : (tensor<1xi1>, tensor<13x21x3xf32>, tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xf32>
+func.return %2 : tensor<1x13x21x3xf32>
+}

tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir

Lines changed: 10 additions & 12 deletions
@@ -1692,11 +1692,11 @@ func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32
 // -----

 // CHECK-LABEL: test_space_to_batch_dyn
-// CHECK-DAG: %[[C0:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}>
-// CHECK-DAG: %[[C1:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 2], [0, 0], [0, 0]]> : tensor<4x2xi32>}>
 // CHECK-DAG: %[[C2:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}>
-// CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[C1]], %[[C0]] : (tensor<?x241x1x80xf32>, tensor<4x2xi32>, tensor<f32>) -> tensor<?x243x1x80xf32>
-// CHECK-DAG: %[[R0:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: -1, 81, 3, 1, 1, 80>}
+// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 2, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}>
+// CHECK-DAG: %[[VAL_6:.*]] = tosa.pad %arg0, %[[VAL_4]], %[[VAL_5]] : (tensor<?x241x1x80xf32>, !tosa.shape<8>, tensor<f32>) -> tensor<?x243x1x80xf32>
+// CHECK-DAG: %[[R0:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: -1, 81, 3, 1, 1, 80>}
 // CHECK-DAG: %[[T:.+]] = tosa.transpose %[[R0]], %[[C2]]
 // CHECK-DAG: %[[R1:.+]] = tosa.reshape %[[T]] {new_shape = array<i64: -1, 81, 1, 80>}
 // CHECK: return %[[R1]] : tensor<?x81x1x80xf32>
@@ -2728,28 +2728,26 @@ func.func @test_rfft2d_crop_input(%arg0: tensor<13x21x3xf32>) -> tensor<13x2x2xc
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32>
 // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 11, 0, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
-// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[13, 32, 5, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x32x8xf32>
 // CHECK: %[[VAL_4:.*]], %[[VAL_5:.*]] = tosa.rfft2d %[[VAL_3]] : (tensor<13x32x8xf32>) -> (tensor<13x32x5xf32>, tensor<13x32x5xf32>)
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 13, 32, 5, 1>} : (tensor<13x32x5xf32>) -> tensor<13x32x5x1xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 13, 32, 5, 1>} : (tensor<13x32x5xf32>) -> tensor<13x32x5x1xf32>
 // CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_6]], %[[VAL_7]] {axis = 3 : i32} : (tensor<13x32x5x1xf32>, tensor<13x32x5x1xf32>) -> tensor<13x32x5x2xf32>
 // CHECK: return %[[VAL_8]] : tensor<13x32x5x2xf32>
 func.func @test_rfft2d_pad_input(%arg0: tensor<13x21x3xf32>) -> (tensor<13x32x5xcomplex<f32>>) {
-%0 = "tfl.pseudo_const"() {value = dense<[[0, 0], [0, 11], [0, 5]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
-%1 = "tfl.pad"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x32x8xf32>
-%2 = "tfl.pseudo_const"() {value = dense<[32, 8]> : tensor<2xi32>} : () -> tensor<2xi32>
-%3 = "tfl.rfft2d"(%1, %2) : (tensor<13x32x8xf32>, tensor<2xi32>) -> tensor<13x32x5xcomplex<f32>>
-return %3 : tensor<13x32x5xcomplex<f32>>
-}
+%0 = "tfl.pseudo_const"() {value = dense<[[0, 0], [0, 11], [0, 5]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+%1 = "tfl.pad"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x32x8xf32>
+%2 = "tfl.pseudo_const"() {value = dense<[32, 8]> : tensor<2xi32>} : () -> tensor<2xi32>
+%3 = "tfl.rfft2d"(%1, %2) : (tensor<13x32x8xf32>, tensor<2xi32>) -> tensor<13x32x5xcomplex<f32>>
+return %3 : tensor<13x32x5xcomplex<f32>>
+}

 // -----

 // CHECK-LABEL: test_rfft2d_crop_height_pad_width
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32>
 // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 13]> : tensor<6xindex>} : () -> !tosa.shape<6>
-// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[13, 2, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x16xf32>
 // CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_3]] {size = array<i64: 13, 2, 16>, start = array<i64: 0, 0, 0>} : (tensor<13x21x16xf32>) -> tensor<13x2x16xf32>
 // CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = tosa.rfft2d %[[VAL_4]] : (tensor<13x2x16xf32>) -> (tensor<13x2x9xf32>, tensor<13x2x9xf32>