
Commit ffb0b44

[GPU] Add coalescing to reduction tiling
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent 85dd208 commit ffb0b44

5 files changed

Lines changed: 271 additions & 58 deletions


compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir

Lines changed: 134 additions & 17 deletions
@@ -112,28 +112,42 @@ func.func @matmul_transpose_b(%5: tensor<64x64xf32>, %6: tensor<64x1280xf16>, %7
 
 // -----
 
-#config = #iree_gpu.lowering_config<{reduction = [0, 8]}>
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-#map2 = affine_map<(d0, d1) -> (d0)>
-func.func @reduction(%3: tensor<128x384xf32>) -> tensor<128xf32> {
+#config = #iree_gpu.lowering_config<{reduction = [0, 8, 4]}>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+func.func @reduction(%arg0: tensor<128x384x256xf32>) -> tensor<128xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
   %cst = arith.constant 0.000000e+00 : f32
   %empty = tensor.empty() : tensor<128xf32>
-  %4 = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
-  %5 = linalg.generic {
-    indexing_maps = [#map1, #map2],
-    iterator_types = ["parallel", "reduction"]
-  } ins(%3 : tensor<128x384xf32>) outs(%4 : tensor<128xf32>) attrs = {lowering_config = #config} {
-  ^bb0(%in: f32, %out: f32):
-    %7 = arith.addf %in, %out : f32
-    linalg.yield %7 : f32
-  } -> tensor<128xf32>
-  return %5 : tensor<128xf32>
+  %init = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
+
+  // Parent scf.for loop that will be coalesced with reduction tiling loops.
+  %result = scf.for %iv = %c0 to %c3 step %c1 iter_args(%arg1 = %init) -> (tensor<128xf32>) {
+    %slice = tensor.extract_slice %arg0[0, 0, 0] [128, 384, 256] [1, 1, 1] : tensor<128x384x256xf32> to tensor<128x384x256xf32>
+    %reduced = linalg.generic {
+      indexing_maps = [#map1, #map2],
+      iterator_types = ["parallel", "reduction", "reduction"]
+    } ins(%slice : tensor<128x384x256xf32>) outs(%arg1 : tensor<128xf32>) attrs = {lowering_config = #config} {
+    ^bb0(%in: f32, %out: f32):
+      %add = arith.addf %in, %out : f32
+      linalg.yield %add : f32
+    } -> tensor<128xf32>
+    scf.yield %reduced : tensor<128xf32>
+  }
+  return %result : tensor<128xf32>
 }
 
 // CHECK-LABEL: func.func @reduction
-// CHECK: %[[FILL:.+]] = linalg.fill {{.*}} tensor<128xf32>
-// CHECK: scf.for %{{.*}} = %c0 to %c384 step %c8 iter_args(%{{.*}} = %[[FILL]])
-// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8xf32>)
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x384x256xf32>
+// CHECK-DAG: %[[C9216:.+]] = arith.constant 9216 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[INIT:.+]] = linalg.fill {{.*}} tensor<128xf32>
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C9216]] step %[[C1]] iter_args(%[[ARG:.+]] = %[[INIT]])
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8x4xf32>) outs(%[[ARG]] : tensor<128xf32>)
 // CHECK: scf.yield
 
 // Verify that no tiling happens in the thread case.
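The single loop checked above replaces three nested loops: the chained parent (trip count 3) and the two reduction tiling loops (384/8 = 48 and 256/4 = 64 iterations), for 3 * 48 * 64 = 9216 total iterations. A rough sketch of the coalesced form, with hypothetical SSA names; the pass may materialize the index recovery with affine.delinearize_index or with div/rem arithmetic:

  %result = scf.for %iv = %c0 to %c9216 step %c1 iter_args(%acc = %init) -> (tensor<128xf32>) {
    // Recover the three original induction variables from the linear IV.
    %ivs:3 = affine.delinearize_index %iv into (3, 48, 64) : index, index, index
    // ... extract the tensor<128x8x4xf32> slice at offsets (%ivs#1 * 8, %ivs#2 * 4), reduce into %acc, yield ...
  }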
@@ -142,6 +156,109 @@ func.func @reduction(%3: tensor<128x384xf32>) -> tensor<128xf32> {
 
 // -----
 
+// Test coalescing when parent scf.for has iter_args but NOT chained with reduction.
+#config2 = #iree_gpu.lowering_config<{reduction = [0, 8, 4]}>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d0)>
+#map5 = affine_map<(d0) -> (d0)>
+func.func @reduction_nochain_iter_args(%arg0: tensor<128x384x256xf32>) -> tensor<128xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<128xf32>
+  %ew_init = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
+
+  // Parent scf.for loop with iter_args but NOT chained with reduction.
+  %result = scf.for %iv = %c0 to %c3 step %c1 iter_args(%ew = %ew_init) -> (tensor<128xf32>) {
+    %empty2 = tensor.empty() : tensor<128xf32>
+    %init = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<128xf32>) -> tensor<128xf32>
+    %slice = tensor.extract_slice %arg0[0, 0, 0] [128, 384, 256] [1, 1, 1] : tensor<128x384x256xf32> to tensor<128x384x256xf32>
+    %reduced = linalg.generic {
+      indexing_maps = [#map3, #map4],
+      iterator_types = ["parallel", "reduction", "reduction"]
+    } ins(%slice : tensor<128x384x256xf32>) outs(%init : tensor<128xf32>) attrs = {lowering_config = #config2} {
+    ^bb0(%in: f32, %out: f32):
+      %add = arith.addf %in, %out : f32
+      linalg.yield %add : f32
+    } -> tensor<128xf32>
+
+    // elementwise that uses the parent scf.for iter arg.
+    %empty3 = tensor.empty() : tensor<128xf32>
+    %elementwise = linalg.generic {
+      indexing_maps = [#map5, #map5, #map5],
+      iterator_types = ["parallel"]
+    } ins(%ew, %reduced : tensor<128xf32>, tensor<128xf32>) outs(%empty3 : tensor<128xf32>) {
+    ^bb0(%e: f32, %r: f32, %out: f32):
+      %new = arith.addf %e, %r : f32
+      linalg.yield %new : f32
+    } -> tensor<128xf32>
+
+    scf.yield %elementwise : tensor<128xf32>
+  }
+  return %result : tensor<128xf32>
+}
+
+// CHECK-LABEL: func.func @reduction_nochain_iter_args
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x384x256xf32>
+// CHECK-DAG: %[[C3072:.+]] = arith.constant 3072 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[INIT:.+]] = linalg.fill {{.*}} tensor<128xf32>
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[EW_ARG:.+]] = %[[INIT]])
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C3072]] step %[[C1]] iter_args(%[[RED_ARG:.+]] = %[[INIT]])
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8x4xf32>) outs(%[[RED_ARG]] : tensor<128xf32>)
+// CHECK: linalg.generic {{.*}} ins(%[[EW_ARG]], %{{.*}} : tensor<128xf32>, tensor<128xf32>)
+// CHECK: scf.yield
+
+// THREAD-LABEL: func.func @reduction_nochain_iter_args
+// THREAD-NOT: scf.forall
+
+// -----
+
+// Test that coalescing is skipped when loops have dynamic trip counts.
+#config_dyn = #iree_gpu.lowering_config<{reduction = [0, 8, 4]}>
+#map_dyn1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map_dyn2 = affine_map<(d0, d1, d2) -> (d0)>
+func.func @reduction_dynamic_trip_count(%arg0: tensor<128x384x256xf32>, %dyn_ub: index) -> tensor<128xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<128xf32>
+  %init = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
+
+  // Parent scf.for loop with dynamic upper bound.
+  // This should NOT be coalesced with reduction tiling loops.
+  %result = scf.for %iv = %c0 to %dyn_ub step %c1 iter_args(%arg1 = %init) -> (tensor<128xf32>) {
+    %slice = tensor.extract_slice %arg0[0, 0, 0] [128, 384, 256] [1, 1, 1] : tensor<128x384x256xf32> to tensor<128x384x256xf32>
+    %reduced = linalg.generic {
+      indexing_maps = [#map_dyn1, #map_dyn2],
+      iterator_types = ["parallel", "reduction", "reduction"]
+    } ins(%slice : tensor<128x384x256xf32>) outs(%arg1 : tensor<128xf32>) attrs = {lowering_config = #config_dyn} {
+    ^bb0(%in: f32, %out: f32):
+      %add = arith.addf %in, %out : f32
+      linalg.yield %add : f32
+    } -> tensor<128xf32>
+    scf.yield %reduced : tensor<128xf32>
+  }
+  return %result : tensor<128xf32>
+}
+
+// CHECK-LABEL: func.func @reduction_dynamic_trip_count
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x384x256xf32>
+// CHECK-SAME: %[[DYN_UB:[A-Za-z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[INIT:.+]] = linalg.fill {{.*}} tensor<128xf32>
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[DYN_UB]] step %[[C1]] iter_args(%[[ARG1:.+]] = %[[INIT]])
+// CHECK: scf.for %{{.*}} = %[[C0]] to %{{.*}} step %{{.*}} iter_args(%[[ARG2:.+]] = %[[ARG1]])
+// CHECK: scf.for %{{.*}} = %[[C0]] to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.+]] = %[[ARG2]])
+// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8x4xf32>) outs(%[[ARG3]] : tensor<128xf32>)
+
+// -----
+
 #config = #iree_gpu.lowering_config<{reduction = [0, 0, 8]}>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @matmul_fuse(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32> {
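In @reduction_nochain_iter_args the reduction accumulator is re-initialized by a fill inside the loop body rather than flowing through the parent's iter_args, so the parent (trip count 3) stays separate and only the two tiling loops fold together, into 48 * 64 = 3072 iterations. A minimal sketch of the expected nesting, with hypothetical names:

  scf.for %i = %c0 to %c3 step %c1 iter_args(%ew = %ew_init) -> (tensor<128xf32>) {        // parent kept
    %red = scf.for %j = %c0 to %c3072 step %c1 iter_args(%acc = %init) -> (tensor<128xf32>) {  // coalesced tiling loops
      // ... reduce a tensor<128x8x4xf32> slice into %acc ...
    }
    // ... elementwise combine of %ew and %red ...
  }

In @reduction_dynamic_trip_count no coalescing happens at all: the parent's upper bound is a function argument, and folding it in would hide the static bounds that later range analysis (e.g. in applyPaddingLevel) relies on.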

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tile_and_convert_conv_to_matmul.mlir

Lines changed: 20 additions & 16 deletions
@@ -19,10 +19,11 @@ func.func @conv_nhwc_generic(%a: tensor<1x3x66x8xf32>, %b: tensor<32x3x3x8xf32>,
 }
 
 // CHECK-LABEL: func.func @conv_nhwc_generic
-// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK: linalg.generic
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+// CHECK-DAG: %[[C9:.+]] = arith.constant 9 : index
+// CHECK: scf.for %{{.*}} = %c0 to %[[C9]] step %c1
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
 
 // -----
 
@@ -35,10 +36,11 @@ func.func @conv_nhwc_named_dilated(%a: tensor<1x5x68x8xf32>, %b: tensor<32x3x3x8
 }
 
 // CHECK-LABEL: func.func @conv_nhwc_named_dilated
-// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK: linalg.generic
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+// CHECK-DAG: %[[C9:.+]] = arith.constant 9 : index
+// CHECK: scf.for %{{.*}} = %c0 to %[[C9]] step %c1
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
 
 // -----
 
@@ -51,10 +53,11 @@ func.func @conv_nchw_named(%arg0: tensor<2x16x130x130xf32>, %arg1: tensor<32x16x
 }
 
 // CHECK-LABEL: func.func @conv_nchw_named
-// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK: linalg.generic
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)>
+// CHECK-DAG: %[[C9:.+]] = arith.constant 9 : index
+// CHECK: scf.for %{{.*}} = %c0 to %[[C9]] step %c1
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)>
 
 // -----
 
@@ -77,7 +80,8 @@ func.func @conv_chwn_generic(%a: tensor<16x24x16x16xf32>, %b: tensor<16x24x16x16
 }
 
 // CHECK-LABEL: func.func @conv_chwn_generic
-// CHECK: scf.for %{{.*}} = %c0 to %c24 step %c1
-// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c1
-// CHECK: linalg.generic
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2, d3)>
+// CHECK-DAG: %[[C384:.+]] = arith.constant 384 : index
+// CHECK: scf.for %{{.*}} = %c0 to %[[C384]] step %c1
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2, d3)>
compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp

Lines changed: 91 additions & 0 deletions
@@ -13,10 +13,13 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Dominance.h"
 
 #include <cassert>
@@ -482,6 +485,94 @@ LogicalResult applyTileAndFuseToEachRoot(
                                          tileAndFuseOptions, tiledResults->loops);
       }
     }
+
+    // Coalesce scf.for loops created during reduction tiling.
+    // This is done at the very end, after all other transformations,
+    // to avoid invalidating dominance info or affecting fusion logic.
+    if (tilingLevel == IREE::GPU::TilingLevel::Reduction &&
+        !tiledResults->loops.empty()) {
+      SmallVector<scf::ForOp> forLoops;
+
+      // Check if tiling happened inside an existing scf.for loop.
+      // If so, include that parent loop in the coalescing.
+      Operation *parentOp =
+          tiledResults->loops.front().getOperation()->getParentOp();
+      scf::ForOp parentForOp = dyn_cast<scf::ForOp>(parentOp);
+
+      // Collect all the tiled loops first.
+      for (LoopLikeOpInterface loop : tiledResults->loops) {
+        if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+          forLoops.push_back(forOp);
+        }
+      }
+
+      // Only include the parent if it forms a proper iter_args chain with the
+      // tiled loops. This follows the same validation as upstream's
+      // coalescePerfectlyNestedSCFForLoops.
+      if (parentForOp && !forLoops.empty()) {
+        // Check if parent and first child form a valid iter_args chain:
+        // 1. Must have the same number of iter_args.
+        // 2. Parent's iter_args must match child's init_args.
+        // 3. Parent's terminator operands must match child's results.
+        scf::ForOp firstChild = forLoops.front();
+        bool formsChain = true;
+
+        if (parentForOp.getNumRegionIterArgs() !=
+            firstChild.getNumRegionIterArgs()) {
+          formsChain = false;
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Skipping parent loop coalescing: different number of "
+                        "iter_args (parent: "
+                     << parentForOp.getNumRegionIterArgs() << ", child: "
+                     << firstChild.getNumRegionIterArgs() << ")\n");
+        }
+
+        if (formsChain && !llvm::equal(parentForOp.getRegionIterArgs(),
+                                       firstChild.getInitArgs())) {
+          formsChain = false;
+          LLVM_DEBUG(llvm::dbgs() << "Skipping parent loop coalescing: parent "
+                                     "iter_args don't match child init_args\n");
+        }
+
+        if (formsChain) {
+          auto parentTerminator = parentForOp.getBody()->getTerminator();
+          if (!llvm::equal(parentTerminator->getOperands(),
+                           firstChild.getResults())) {
+            formsChain = false;
+            LLVM_DEBUG(llvm::dbgs()
+                       << "Skipping parent loop coalescing: parent yield "
+                          "doesn't match child results\n");
+          }
+        }
+
+        if (formsChain) {
+          forLoops.insert(forLoops.begin(), parentForOp);
+        }
+      }
+
+      // If loops have dynamic trip counts and we coalesce them, it can
+      // cause range analysis to not find static bounds. This was mainly
+      // noticed as a problem in applyPaddingLevel, so to prevent a
+      // regression we don't coalesce such loops.
+      bool hasDynamicTripCount = false;
+      for (scf::ForOp forOp : forLoops) {
+        if (!getConstantIntValue(forOp.getLowerBound()) ||
+            !getConstantIntValue(forOp.getUpperBound()) ||
+            !getConstantIntValue(forOp.getStep())) {
+          hasDynamicTripCount = true;
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Skipping coalescing: loop has dynamic trip count\n");
+          break;
+        }
+      }
+
+      if (forLoops.size() > 1 && !hasDynamicTripCount) {
+        if (failed(coalesceLoops(rewriter, forLoops))) {
+          // Coalescing failure is not critical, just log and continue.
+          LLVM_DEBUG(llvm::dbgs() << "Failed to coalesce reduction loops\n");
+        }
+      }
+    }
   }
   return success();
 }
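The iter_args chain that this validation accepts looks as follows in the IR: the parent's region iter_args feed the child's inits, and the parent yields exactly the child's results. A sketch with hypothetical names:

  %r = scf.for %i = %c0 to %c3 step %c1 iter_args(%outer = %init) -> (tensor<128xf32>) {
    %inner = scf.for %j = %c0 to %c48 step %c1 iter_args(%acc = %outer) -> (tensor<128xf32>) {
      // ... partial reduction into %acc ...
      scf.yield %partial : tensor<128xf32>
    }
    scf.yield %inner : tensor<128xf32>  // nothing may intervene between the loops
  }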

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_direct_conv_tile_and_fuse.mlir

Lines changed: 12 additions & 15 deletions
@@ -62,21 +62,18 @@ hal.executable private @main {
 // CHECK-DAG: memref.alloc() : memref<1x1x32x68xf16, #gpu.address_space<workgroup>>
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK-DAG: %[[C36:.+]] = arith.constant 36 : index
+// CHECK-DAG: %[[C81:.+]] = arith.constant 81 : index
 // CHECK: scf.forall ({{.*}}) in (16, 48, 9) {
-// CHECK: scf.for {{.+}} = %[[C0]] to %[[C3]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
-// CHECK: scf.for {{.+}} = %[[C0]] to %[[C3]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
-// CHECK: scf.for {{.+}} = %[[C0]] to %[[C36]] step %[[C4]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
-// CHECK: gpu.barrier memfence [#gpu.address_space<workgroup>]
-// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read {{.+}} : {{.*}}vector<4xf16>
-// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
-// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.+}} : {{.*}}vector<8xf16>
-// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
-// CHECK: gpu.barrier memfence [#gpu.address_space<workgroup>]
-// CHECK-DAG: vector.transfer_read {{.*}} vector<4x4xf16>
-// CHECK-DAG: vector.transfer_read {{.*}} vector<4x4xf16>
-// CHECK-COUNT-4: amdgpu.mfma 16x16x16
+// CHECK: scf.for {{.+}} = %[[C0]] to %[[C81]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
+// CHECK-NOT: scf.for
+// CHECK: gpu.barrier memfence [#gpu.address_space<workgroup>]
+// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read {{.+}} : {{.*}}vector<4xf16>
+// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
+// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.+}} : {{.*}}vector<8xf16>
+// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
+// CHECK: gpu.barrier memfence [#gpu.address_space<workgroup>]
+// CHECK-DAG: vector.transfer_read {{.*}} vector<4x4xf16>
+// CHECK-DAG: vector.transfer_read {{.*}} vector<4x4xf16>
+// CHECK-COUNT-4: amdgpu.mfma 16x16x16
 // CHECK: vector.transfer_write %{{.*}}, %[[BUF2]]
 // CHECK: } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
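In this end-to-end pipeline test the three reduction loops of the convolution (3, 3, and 36/4 = 9 iterations) become the single %[[C81]] loop, since 3 * 3 * 9 = 81. A minimal sketch of the checked loop, with a hypothetical accumulator name:

  scf.for %iv = %c0 to %c81 step %c1 iter_args(%acc = %cst) -> (vector<1x1x1x1x4x1xf32>) {
    // ... shared-memory copies, barrier, reads, 4x amdgpu.mfma, barrier ...
  }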
