@@ -112,28 +112,42 @@ func.func @matmul_transpose_b(%5: tensor<64x64xf32>, %6: tensor<64x1280xf16>, %7
112112
113113// -----
114114
115- #config = #iree_gpu.lowering_config <{reduction = [0 , 8 ]}>
116- #map1 = affine_map <(d0 , d1 ) -> (d0 , d1 )>
117- #map2 = affine_map <(d0 , d1 ) -> (d0 )>
118- func.func @reduction (%3: tensor <128 x384 xf32 >) -> tensor <128 xf32 > {
115+ #config = #iree_gpu.lowering_config <{reduction = [0 , 8 , 4 ]}>
116+ #map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
117+ #map2 = affine_map <(d0 , d1 , d2 ) -> (d0 )>
118+ func.func @reduction (%arg0: tensor <128 x384 x256 xf32 >) -> tensor <128 xf32 > {
119+ %c0 = arith.constant 0 : index
120+ %c1 = arith.constant 1 : index
121+ %c3 = arith.constant 3 : index
119122 %cst = arith.constant 0.000000e+00 : f32
120123 %empty = tensor.empty () : tensor <128 xf32 >
121- %4 = linalg.fill ins (%cst : f32 ) outs (%empty : tensor <128 xf32 >) -> tensor <128 xf32 >
122- %5 = linalg.generic {
123- indexing_maps = [#map1 , #map2 ],
124- iterator_types = [" parallel" , " reduction" ]
125- } ins (%3 : tensor <128 x384 xf32 >) outs (%4 : tensor <128 xf32 >) attrs = {lowering_config = #config } {
126- ^bb0 (%in: f32 , %out: f32 ):
127- %7 = arith.addf %in , %out : f32
128- linalg.yield %7 : f32
129- } -> tensor <128 xf32 >
130- return %5 : tensor <128 xf32 >
124+ %init = linalg.fill ins (%cst : f32 ) outs (%empty : tensor <128 xf32 >) -> tensor <128 xf32 >
125+
126+ // Parent scf.for loop that will be coalesced with reduction tiling loops.
127+ %result = scf.for %iv = %c0 to %c3 step %c1 iter_args (%arg1 = %init ) -> (tensor <128 xf32 >) {
128+ %slice = tensor.extract_slice %arg0 [0 , 0 , 0 ] [128 , 384 , 256 ] [1 , 1 , 1 ] : tensor <128 x384 x256 xf32 > to tensor <128 x384 x256 xf32 >
129+ %reduced = linalg.generic {
130+ indexing_maps = [#map1 , #map2 ],
131+ iterator_types = [" parallel" , " reduction" , " reduction" ]
132+ } ins (%slice : tensor <128 x384 x256 xf32 >) outs (%arg1 : tensor <128 xf32 >) attrs = {lowering_config = #config } {
133+ ^bb0 (%in: f32 , %out: f32 ):
134+ %add = arith.addf %in , %out : f32
135+ linalg.yield %add : f32
136+ } -> tensor <128 xf32 >
137+ scf.yield %reduced : tensor <128 xf32 >
138+ }
139+ return %result : tensor <128 xf32 >
131140}
132141
133142// CHECK-LABEL: func.func @reduction
134- // CHECK: %[[FILL:.+]] = linalg.fill {{.*}} tensor<128xf32>
135- // CHECK: scf.for %{{.*}} = %c0 to %c384 step %c8 iter_args(%{{.*}} = %[[FILL]])
136- // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8xf32>)
143+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x384x256xf32>
144+ // CHECK-DAG: %[[C9216:.+]] = arith.constant 9216 : index
145+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
146+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
147+ // CHECK: %[[INIT:.+]] = linalg.fill {{.*}} tensor<128xf32>
148+ // CHECK: scf.for %{{.*}} = %[[C0]] to %[[C9216]] step %[[C1]] iter_args(%[[ARG:.+]] = %[[INIT]])
149+ // CHECK-NOT: scf.for
150+ // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8x4xf32>) outs(%[[ARG]] : tensor<128xf32>)
137151// CHECK: scf.yield
138152
139153// Verify that no tiling happens in the thread case.
@@ -142,6 +156,68 @@ func.func @reduction(%3: tensor<128x384xf32>) -> tensor<128xf32> {
142156
143157// -----
144158
159+ // Test coalescing when parent scf.for has iter_args but NOT chained with reduction.
160+ #config2 = #iree_gpu.lowering_config <{reduction = [0 , 8 , 4 ]}>
161+ #map3 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
162+ #map4 = affine_map <(d0 , d1 , d2 ) -> (d0 )>
163+ #map5 = affine_map <(d0 ) -> (d0 )>
164+ func.func @reduction_nochain_iter_args (%arg0: tensor <128 x384 x256 xf32 >) -> tensor <128 xf32 > {
165+ %c0 = arith.constant 0 : index
166+ %c1 = arith.constant 1 : index
167+ %c3 = arith.constant 3 : index
168+ %cst = arith.constant 0.000000e+00 : f32
169+ %empty = tensor.empty () : tensor <128 xf32 >
170+ %ew_init = linalg.fill ins (%cst : f32 ) outs (%empty : tensor <128 xf32 >) -> tensor <128 xf32 >
171+
172+ // Parent scf.for loop with iter_args but NOT chained with reduction.
173+ %result = scf.for %iv = %c0 to %c3 step %c1 iter_args (%ew = %ew_init ) -> (tensor <128 xf32 >) {
174+ %empty2 = tensor.empty () : tensor <128 xf32 >
175+ %init = linalg.fill ins (%cst : f32 ) outs (%empty2 : tensor <128 xf32 >) -> tensor <128 xf32 >
176+ %slice = tensor.extract_slice %arg0 [0 , 0 , 0 ] [128 , 384 , 256 ] [1 , 1 , 1 ] : tensor <128 x384 x256 xf32 > to tensor <128 x384 x256 xf32 >
177+ %reduced = linalg.generic {
178+ indexing_maps = [#map3 , #map4 ],
179+ iterator_types = [" parallel" , " reduction" , " reduction" ]
180+ } ins (%slice : tensor <128 x384 x256 xf32 >) outs (%init : tensor <128 xf32 >) attrs = {lowering_config = #config2 } {
181+ ^bb0 (%in: f32 , %out: f32 ):
182+ %add = arith.addf %in , %out : f32
183+ linalg.yield %add : f32
184+ } -> tensor <128 xf32 >
185+
186+ // elementwise that uses the parent scf.for iter arg.
187+ %empty3 = tensor.empty () : tensor <128 xf32 >
188+ %elementwise = linalg.generic {
189+ indexing_maps = [#map5 , #map5 , #map5 ],
190+ iterator_types = [" parallel" ]
191+ } ins (%ew , %reduced : tensor <128 xf32 >, tensor <128 xf32 >) outs (%empty3 : tensor <128 xf32 >) {
192+ ^bb0 (%e: f32 , %r: f32 , %out: f32 ):
193+ %new = arith.addf %e , %r : f32
194+ linalg.yield %new : f32
195+ } -> tensor <128 xf32 >
196+
197+ scf.yield %elementwise : tensor <128 xf32 >
198+ }
199+ return %result : tensor <128 xf32 >
200+ }
201+
202+ // CHECK-LABEL: func.func @reduction_nochain_iter_args
203+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x384x256xf32>
204+ // CHECK-DAG: %[[C3072:.+]] = arith.constant 3072 : index
205+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
206+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
207+ // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
208+ // CHECK: %[[INIT:.+]] = linalg.fill {{.*}} tensor<128xf32>
209+ // CHECK: scf.for %{{.*}} = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[EW_ARG:.+]] = %[[INIT]])
210+ // CHECK: scf.for %{{.*}} = %[[C0]] to %[[C3072]] step %[[C1]] iter_args(%[[RED_ARG:.+]] = %[[INIT]])
211+ // CHECK-NOT: scf.for
212+ // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8x4xf32>) outs(%[[RED_ARG]] : tensor<128xf32>)
213+ // CHECK: linalg.generic {{.*}} ins(%[[EW_ARG]], %{{.*}} : tensor<128xf32>, tensor<128xf32>)
214+ // CHECK: scf.yield
215+
216+ // THREAD-LABEL: func.func @reduction_nochain_iter_args
217+ // THREAD-NOT: scf.forall
218+
219+ // -----
220+
145221#config = #iree_gpu.lowering_config <{reduction = [0 , 0 , 8 ]}>
146222#map = affine_map <(d0 , d1 ) -> (d0 , d1 )>
147223func.func @matmul_fuse (%3: tensor <64 x64 xf32 >, %4: tensor <64 x64 xf32 >, %5: tensor <64 x64 xf32 >) -> tensor <64 x64 xf32 > {
0 commit comments