99// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
1010// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
1111//
// Test input: %A is transposed with permutation [1, 0, 2] into a fresh
// tensor, and that result is the only input of an `exp` elementwise op that
// writes into %B. The FileCheck expectations above require the elementwise
// op to consume %A (tensor<16x8x32xf32>) directly.
// Lines marked `-` / `+` are the old and new revision of the same line.
12- func.func @unary_transpose (%A : tensor <16 x8 x32 xf32 >, %B: tensor <8 x16 x32 xf32 >) -> tensor <8 x16 x32 xf32 > {
12+ func.func @unary_transpose (%A: tensor <16 x8 x32 xf32 >, %B: tensor <8 x16 x32 xf32 >) -> tensor <8 x16 x32 xf32 > {
1313 %empty = tensor.empty () : tensor <8 x16 x32 xf32 >
14- %transposed_A = linalg.transpose ins (%A : tensor <16 x8 x32 xf32 >) outs (%empty : tensor <8 x16 x32 xf32 >) permutation = [1 , 0 , 2 ]
14+ %transposed_A = linalg.transpose ins (%A : tensor <16 x8 x32 xf32 >) outs (%empty : tensor <8 x16 x32 xf32 >) permutation = [1 , 0 , 2 ]
1515 %result = linalg.elementwise kind =#linalg.elementwise_kind <exp >
16- ins (%transposed_A : tensor <8 x16 x32 xf32 >) outs (%B: tensor <8 x16 x32 xf32 >) -> tensor <8 x16 x32 xf32 >
16+ ins (%transposed_A : tensor <8 x16 x32 xf32 >) outs (%B : tensor <8 x16 x32 xf32 >) -> tensor <8 x16 x32 xf32 >
1717 return %result : tensor <8 x16 x32 xf32 >
1818}
1919
@@ -28,16 +28,220 @@ func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->
2828// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
2929// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
3030//
// Test input: %B is transposed with permutation [1, 0] and added to %A by an
// elementwise `add` writing into %C; all tensors are fully dynamic. The
// FileCheck expectations above require the elementwise op to take %A and %B
// directly.
// NOTE(review): the transpose init is built as tensor.empty(%dim1, %dim0)
// from the dims of %A, while the add consumes it alongside %A itself; the
// operand order looks swapped (only consistent for square shapes) - confirm.
31- func.func @binary_transposed (%A : tensor <?x?xf32 >, %B: tensor <?x?xf32 >, %C: tensor <?x?xf32 >) -> tensor <?x?xf32 > {
31+ func.func @binary_transposed (%A: tensor <?x?xf32 >, %B: tensor <?x?xf32 >, %C: tensor <?x?xf32 >) -> tensor <?x?xf32 > {
3232 %c0 = arith.constant 0 : index
3333 %c1 = arith.constant 1 : index
3434 %dim0 = tensor.dim %A , %c0 : tensor <?x?xf32 >
3535 %dim1 = tensor.dim %A , %c1 : tensor <?x?xf32 >
3636
3737 %empty = tensor.empty (%dim1 , %dim0 ) : tensor <?x?xf32 >
38- %transposed_B = linalg.transpose ins (%B : tensor <?x?xf32 >) outs (%empty : tensor <?x?xf32 >) permutation = [1 , 0 ]
38+ %transposed_B = linalg.transpose ins (%B : tensor <?x?xf32 >) outs (%empty : tensor <?x?xf32 >) permutation = [1 , 0 ]
3939 %result = linalg.elementwise kind =#linalg.elementwise_kind <add >
40- ins (%A , %transposed_B : tensor <?x?xf32 >, tensor <?x?xf32 >)
41- outs (%C: tensor <?x?xf32 >) -> tensor <?x?xf32 >
40+ ins (%A , %transposed_B : tensor <?x?xf32 >, tensor <?x?xf32 >)
41+ outs (%C : tensor <?x?xf32 >) -> tensor <?x?xf32 >
4242 return %result : tensor <?x?xf32 >
4343}
44+
45+ // -----
46+
47+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
48+ // CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
49+ //
50+ // CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
51+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
52+ // CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
53+ // CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
54+ // CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
55+ //
// Test input: %A (8x32) is broadcast along the new dimension 1 to 8x16x32
// and the result is the only input of an `exp` elementwise op writing into
// %B. The FileCheck expectations above require the broadcast to disappear
// and the elementwise op to read %A through the map
// (d0, d1, d2) -> (d0, d2), with an identity map on the output.
56+ func.func @unary_broadcasted (%A: tensor <8 x32 xf32 >, %B: tensor <8 x16 x32 xf32 >) -> tensor <8 x16 x32 xf32 > {
57+ %empty = tensor.empty () : tensor <8 x16 x32 xf32 >
58+ %broadcasted_A = linalg.broadcast ins (%A : tensor <8 x32 xf32 >) outs (%empty : tensor <8 x16 x32 xf32 >) dimensions = [1 ]
59+ %result = linalg.elementwise kind =#linalg.elementwise_kind <exp >
60+ ins (%broadcasted_A : tensor <8 x16 x32 xf32 >) outs (%B : tensor <8 x16 x32 xf32 >) -> tensor <8 x16 x32 xf32 >
61+ return %result : tensor <8 x16 x32 xf32 >
62+ }
63+
64+ // -----
65+
66+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
67+ // CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
68+ //
69+ // CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
70+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
71+ // CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
72+ // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
73+ // CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
74+ //
// Test input: %B (rank 1) is broadcast along the new dimension 1 and added
// to %A by an elementwise `add` writing into %C; all tensors are dynamic.
// The FileCheck expectations above require the broadcast to disappear and
// the elementwise op to read %B through the map (d0, d1) -> (d0).
75+ func.func @binary_broadcasted (%A: tensor <?x?xf32 >, %B: tensor <?xf32 >, %C: tensor <?x?xf32 >) -> tensor <?x?xf32 > {
76+ %c0 = arith.constant 0 : index
77+ %c1 = arith.constant 1 : index
78+ %dim0 = tensor.dim %A , %c0 : tensor <?x?xf32 >
79+ %dim1 = tensor.dim %A , %c1 : tensor <?x?xf32 >
80+
// The broadcast result is added to %A under identity maps, so its init must
// have the shape of %A: (%dim0, %dim1), not the swapped order.
81+ %empty = tensor.empty (%dim0 , %dim1 ) : tensor <?x?xf32 >
82+ %broadcasted_B = linalg.broadcast ins (%B : tensor <?xf32 >) outs (%empty : tensor <?x?xf32 >) dimensions = [1 ]
83+ %result = linalg.elementwise kind =#linalg.elementwise_kind <add >
84+ ins (%A , %broadcasted_B : tensor <?x?xf32 >, tensor <?x?xf32 >)
85+ outs (%C : tensor <?x?xf32 >) -> tensor <?x?xf32 >
86+ return %result : tensor <?x?xf32 >
87+ }
88+
89+ // -----
90+
91+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
92+ // CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
93+ //
94+ // CHECK: func.func @fold_broadcast_after_transpose_fold(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32> {
95+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
96+ // CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
97+ // CHECK-SAME: ins(%[[A]] : tensor<16xf32>) outs(%[[B]] : tensor<16x32xf32>) -> tensor<16x32xf32>
98+ // CHECK-NEXT: return %[[RES]] : tensor<16x32xf32>
99+ //
// Test input: the elementwise `exp` op already reads its input through a
// transpose map (d0, d1) -> (d1, d0), and that input is produced by a
// broadcast of %A (16) along the new dimension 0 to 32x16. The FileCheck
// expectations above require a single composed input map (d0, d1) -> (d0)
// applied to %A directly, with no broadcast left.
100+ #identity = affine_map <(d0 , d1 ) -> (d0 , d1 )>
101+ #transpose = affine_map <(d0 , d1 ) -> (d1 , d0 )>
102+
103+ func.func @fold_broadcast_after_transpose_fold (%A: tensor <16 xf32 >, %B: tensor <16 x32 xf32 >) -> tensor <16 x32 xf32 > {
104+ %empty_b = tensor.empty () : tensor <32 x16 xf32 >
105+
106+ %broadcasted_A = linalg.broadcast ins (%A : tensor <16 xf32 >) outs (%empty_b : tensor <32 x16 xf32 >) dimensions = [0 ]
107+
108+ %result = linalg.elementwise kind =#linalg.elementwise_kind <exp >
109+ indexing_maps = [#transpose , #identity ]
110+ ins (%broadcasted_A : tensor <32 x16 xf32 >) outs (%B : tensor <16 x32 xf32 >) -> tensor <16 x32 xf32 >
111+ return %result : tensor <16 x32 xf32 >
112+ }
113+
114+ // -----
115+
116+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
117+ // CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
118+ //
119+ // CHECK: func.func @fold_transpose_after_broadcast_fold(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
120+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
121+ // CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
122+ // CHECK-SAME: ins(%[[A]] : tensor<32x16xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
123+ // CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
124+ //
// Test input: the elementwise `exp` op already reads its input through a
// broadcast map (d0, d1, d2) -> (d1, d2), and that input is produced by a
// transpose of %A (32x16) with permutation [1, 0]. The FileCheck
// expectations above require a single composed input map
// (d0, d1, d2) -> (d2, d1) applied to %A directly, with no transpose left.
125+ #identity = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
126+ #broadcast = affine_map <(d0 , d1 , d2 ) -> (d1 , d2 )>
127+
128+ func.func @fold_transpose_after_broadcast_fold (%A: tensor <32 x16 xf32 >, %B: tensor <8 x16 x32 xf32 >) -> tensor <8 x16 x32 xf32 > {
129+ %empty_t = tensor.empty () : tensor <16 x32 xf32 >
130+ %transposed_A = linalg.transpose ins (%A : tensor <32 x16 xf32 >) outs (%empty_t : tensor <16 x32 xf32 >) permutation = [1 , 0 ]
131+
132+ %result = linalg.elementwise kind =#linalg.elementwise_kind <exp >
133+ indexing_maps = [#broadcast , #identity ]
134+ ins (%transposed_A : tensor <16 x32 xf32 >) outs (%B : tensor <8 x16 x32 xf32 >) -> tensor <8 x16 x32 xf32 >
135+ return %result : tensor <8 x16 x32 xf32 >
136+ }
137+
138+ // -----
139+
140+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
141+ // CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
142+ //
143+ // CHECK: func.func @fold_broadcast_after_transpose_fold_binary(%[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
144+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
145+ // CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
146+ // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
147+ // CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
148+ //
// Binary, dynamically shaped variant: the first input of the elementwise
// `add` is read through a transpose map (d0, d1) -> (d1, d0) and is produced
// by a broadcast of %A (rank 1) along the new dimension 0; the second input
// %B and the output %C use identity maps. The FileCheck expectations above
// require %A to be consumed directly through (d0, d1) -> (d0).
149+ #identity = affine_map <(d0 , d1 ) -> (d0 , d1 )>
150+ #transpose = affine_map <(d0 , d1 ) -> (d1 , d0 )>
151+
152+ func.func @fold_broadcast_after_transpose_fold_binary (%A: tensor <?xf32 >, %B: tensor <?x?xf32 >, %C: tensor <?x?xf32 >) -> tensor <?x?xf32 > {
153+ %c0 = arith.constant 0 : index
154+ %c1 = arith.constant 1 : index
155+ %dim0 = tensor.dim %B , %c0 : tensor <?x?xf32 >
156+ %dim1 = tensor.dim %B , %c1 : tensor <?x?xf32 >
157+
// The broadcast result is read transposed relative to %B, hence the
// (%dim1, %dim0) operand order for its init tensor.
158+ %empty_b = tensor.empty (%dim1 , %dim0 ) : tensor <?x?xf32 >
159+ %broadcasted_A = linalg.broadcast ins (%A : tensor <?xf32 >) outs (%empty_b : tensor <?x?xf32 >) dimensions = [0 ]
160+
161+ %result = linalg.elementwise kind =#linalg.elementwise_kind <add >
162+ indexing_maps = [#transpose , #identity , #identity ]
163+ ins (%broadcasted_A , %B : tensor <?x?xf32 >, tensor <?x?xf32 >) outs (%C : tensor <?x?xf32 >) -> tensor <?x?xf32 >
164+
165+ return %result : tensor <?x?xf32 >
166+ }
167+
168+ // -----
169+
170+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
171+ // CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
172+ //
173+ // CHECK: func.func @fold_transpose_after_broadcast_fold_binary(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?x?xf32>, %[[C:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
174+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
175+ // CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
176+ // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%[[C]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
177+ // CHECK-NEXT: return %[[RES]] : tensor<?x?x?xf32>
178+ //
// Binary, dynamically shaped variant: the first input of the elementwise
// `add` is read through a broadcast map (d0, d1, d2) -> (d1, d2) and is
// produced by a transpose of %A with permutation [1, 0]; the second input %B
// and the output %C use identity maps. The FileCheck expectations above
// require %A to be consumed directly through (d0, d1, d2) -> (d2, d1).
179+ #identity = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
180+ #broadcast = affine_map <(d0 , d1 , d2 ) -> (d1 , d2 )>
181+
182+ func.func @fold_transpose_after_broadcast_fold_binary (%A: tensor <?x?xf32 >, %B: tensor <?x?x?xf32 >, %C: tensor <?x?x?xf32 >) -> tensor <?x?x?xf32 > {
183+ %c0 = arith.constant 0 : index
184+ %c1 = arith.constant 1 : index
185+ %c2 = arith.constant 2 : index
186+ %dim0 = tensor.dim %B , %c0 : tensor <?x?x?xf32 >
187+ %dim1 = tensor.dim %B , %c1 : tensor <?x?x?xf32 >
188+ %dim2 = tensor.dim %B , %c2 : tensor <?x?x?xf32 >
189+
// The transposed tensor is indexed by (d1, d2) in the elementwise op, so its
// init is sized from dimensions 1 and 2 of %B.
190+ %empty_t = tensor.empty (%dim1 , %dim2 ) : tensor <?x?xf32 >
191+ %transposed_A = linalg.transpose ins (%A : tensor <?x?xf32 >) outs (%empty_t : tensor <?x?xf32 >) permutation = [1 , 0 ]
192+
193+ %result = linalg.elementwise kind =#linalg.elementwise_kind <add >
194+ indexing_maps = [#broadcast , #identity , #identity ]
195+ ins (%transposed_A , %B : tensor <?x?xf32 >, tensor <?x?x?xf32 >) outs (%C : tensor <?x?x?xf32 >) -> tensor <?x?x?xf32 >
196+ return %result : tensor <?x?x?xf32 >
197+ }
198+
199+ // -----
200+
201+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0) -> (d0)>
202+ // CHECK-DAG: #[[DIAGONAL:.+]] = affine_map<(d0) -> (d0, d0)>
203+ //
204+ // CHECK: func.func @fold_failed_diagonal_map(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16xf32>) -> tensor<16xf32> {
205+ // CHECK-NEXT: %[[EMPTY:.+]] = tensor.empty() : tensor<16x16xf32>
206+ // CHECK-NEXT: %[[BROADCASTED_B:.+]] = linalg.broadcast ins(%[[B]] : tensor<16xf32>) outs(%[[EMPTY]] : tensor<16x16xf32>) dimensions = [0]
207+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
208+ // CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[DIAGONAL]], #[[IDENTITY]]]
209+ // CHECK-SAME: ins(%[[A]], %[[BROADCASTED_B]] : tensor<16xf32>, tensor<16x16xf32>) outs(%[[C]] : tensor<16xf32>) -> tensor<16xf32>
210+ // CHECK-NEXT: return %[[RES]] : tensor<16xf32>
211+ //
// Negative case: the elementwise `add` reads its second input through the
// map (d0) -> (d0, d0), and that input is produced by a broadcast of %B
// along the new dimension 0. The FileCheck expectations above require the
// tensor.empty, the broadcast and the elementwise op (with its original
// maps) to all remain in the output, i.e. no folding takes place.
212+ #identity = affine_map <(d0 ) -> (d0 )>
213+ #diagonal = affine_map <(d0 ) -> (d0 , d0 )>
214+
215+ func.func @fold_failed_diagonal_map (%A: tensor <16 xf32 >, %B: tensor <16 xf32 >, %C: tensor <16 xf32 >) -> tensor <16 xf32 > {
216+ %empty = tensor.empty () : tensor <16 x16 xf32 >
217+ %broadcasted_B = linalg.broadcast ins (%B : tensor <16 xf32 >) outs (%empty : tensor <16 x16 xf32 >) dimensions = [0 ]
218+ %result = linalg.elementwise kind =#linalg.elementwise_kind <add >
219+ indexing_maps = [#identity , #diagonal , #identity ]
220+ ins (%A , %broadcasted_B : tensor <16 xf32 >, tensor <16 x16 xf32 >) outs (%C : tensor <16 xf32 >) -> tensor <16 xf32 >
221+ return %result : tensor <16 xf32 >
222+ }
223+
224+ // -----
225+
226+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0) -> (d0)>
227+ // CHECK-DAG: #[[CONSTANT:.+]] = affine_map<(d0) -> (0, d0)>
228+ //
229+ // CHECK: func.func @fold_failed_constant_map(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>, %[[C:.+]]: tensor<16xf32>) -> tensor<16xf32> {
230+ // CHECK-NEXT: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16xf32>
231+ // CHECK-NEXT: %[[TRANSPOSED_B:.+]] = linalg.transpose ins(%[[B]] : tensor<16x32xf32>) outs(%[[EMPTY]] : tensor<32x16xf32>) permutation = [1, 0]
232+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
233+ // CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[CONSTANT]], #[[IDENTITY]]]
234+ // CHECK-SAME: ins(%[[A]], %[[TRANSPOSED_B]] : tensor<16xf32>, tensor<32x16xf32>) outs(%[[C]] : tensor<16xf32>) -> tensor<16xf32>
235+ // CHECK-NEXT: return %[[RES]] : tensor<16xf32>
236+ //
// Negative case: the elementwise `add` reads its second input through the
// map (d0) -> (0, d0), which has a constant result, and that input is
// produced by a transpose of %B with permutation [1, 0]. The FileCheck
// expectations above require the tensor.empty, the transpose and the
// elementwise op (with its original maps) to all remain in the output,
// i.e. no folding takes place.
237+ #identity = affine_map <(d0 ) -> (d0 )>
238+ #constant = affine_map <(d0 ) -> (0 , d0 )>
239+
240+ func.func @fold_failed_constant_map (%A: tensor <16 xf32 >, %B: tensor <16 x32 xf32 >, %C: tensor <16 xf32 >) -> tensor <16 xf32 > {
241+ %empty = tensor.empty () : tensor <32 x16 xf32 >
242+ %transposed_B = linalg.transpose ins (%B : tensor <16 x32 xf32 >) outs (%empty : tensor <32 x16 xf32 >) permutation = [1 , 0 ]
243+ %result = linalg.elementwise kind =#linalg.elementwise_kind <add >
244+ indexing_maps = [#identity , #constant , #identity ]
245+ ins (%A , %transposed_B : tensor <16 xf32 >, tensor <32 x16 xf32 >) outs (%C : tensor <16 xf32 >) -> tensor <16 xf32 >
246+ return %result : tensor <16 xf32 >
247+ }
0 commit comments