Skip to content

Commit 85efe1a

Browse files
[GPU] Add coalescing to reduction tiling
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent a657d73 commit 85efe1a

4 files changed

Lines changed: 194 additions & 42 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir

Lines changed: 93 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,28 +112,42 @@ func.func @matmul_transpose_b(%5: tensor<64x64xf32>, %6: tensor<64x1280xf16>, %7
112112

113113
// -----
114114

115-
#config = #iree_gpu.lowering_config<{reduction = [0, 8]}>
116-
#map1 = affine_map<(d0, d1) -> (d0, d1)>
117-
#map2 = affine_map<(d0, d1) -> (d0)>
118-
func.func @reduction(%3: tensor<128x384xf32>) -> tensor<128xf32> {
115+
#config = #iree_gpu.lowering_config<{reduction = [0, 8, 4]}>
116+
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
117+
#map2 = affine_map<(d0, d1, d2) -> (d0)>
118+
func.func @reduction(%arg0: tensor<128x384x256xf32>) -> tensor<128xf32> {
119+
%c0 = arith.constant 0 : index
120+
%c1 = arith.constant 1 : index
121+
%c3 = arith.constant 3 : index
119122
%cst = arith.constant 0.000000e+00 : f32
120123
%empty = tensor.empty() : tensor<128xf32>
121-
%4 = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
122-
%5 = linalg.generic {
123-
indexing_maps = [#map1, #map2],
124-
iterator_types = ["parallel", "reduction"]
125-
} ins(%3 : tensor<128x384xf32>) outs(%4 : tensor<128xf32>) attrs = {lowering_config = #config} {
126-
^bb0(%in: f32, %out: f32):
127-
%7 = arith.addf %in, %out : f32
128-
linalg.yield %7 : f32
129-
} -> tensor<128xf32>
130-
return %5 : tensor<128xf32>
124+
%init = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
125+
126+
// Parent scf.for loop that will be coalesced with reduction tiling loops.
127+
%result = scf.for %iv = %c0 to %c3 step %c1 iter_args(%arg1 = %init) -> (tensor<128xf32>) {
128+
%slice = tensor.extract_slice %arg0[0, 0, 0] [128, 384, 256] [1, 1, 1] : tensor<128x384x256xf32> to tensor<128x384x256xf32>
129+
%reduced = linalg.generic {
130+
indexing_maps = [#map1, #map2],
131+
iterator_types = ["parallel", "reduction", "reduction"]
132+
} ins(%slice : tensor<128x384x256xf32>) outs(%arg1 : tensor<128xf32>) attrs = {lowering_config = #config} {
133+
^bb0(%in: f32, %out: f32):
134+
%add = arith.addf %in, %out : f32
135+
linalg.yield %add : f32
136+
} -> tensor<128xf32>
137+
scf.yield %reduced : tensor<128xf32>
138+
}
139+
return %result : tensor<128xf32>
131140
}
132141

133142
// CHECK-LABEL: func.func @reduction
134-
// CHECK: %[[FILL:.+]] = linalg.fill {{.*}} tensor<128xf32>
135-
// CHECK: scf.for %{{.*}} = %c0 to %c384 step %c8 iter_args(%{{.*}} = %[[FILL]])
136-
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8xf32>)
143+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x384x256xf32>
144+
// CHECK-DAG: %[[C9216:.+]] = arith.constant 9216 : index
145+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
146+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
147+
// CHECK: %[[INIT:.+]] = linalg.fill {{.*}} tensor<128xf32>
148+
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C9216]] step %[[C1]] iter_args(%[[ARG:.+]] = %[[INIT]])
149+
// CHECK-NOT: scf.for
150+
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8x4xf32>) outs(%[[ARG]] : tensor<128xf32>)
137151
// CHECK: scf.yield
138152

139153
// Verify that no tiling happens in the thread case.
@@ -142,6 +156,68 @@ func.func @reduction(%3: tensor<128x384xf32>) -> tensor<128xf32> {
142156

143157
// -----
144158

159+
// Test coalescing when parent scf.for has iter_args but NOT chained with reduction.
160+
#config2 = #iree_gpu.lowering_config<{reduction = [0, 8, 4]}>
161+
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
162+
#map4 = affine_map<(d0, d1, d2) -> (d0)>
163+
#map5 = affine_map<(d0) -> (d0)>
164+
func.func @reduction_nochain_iter_args(%arg0: tensor<128x384x256xf32>) -> tensor<128xf32> {
165+
%c0 = arith.constant 0 : index
166+
%c1 = arith.constant 1 : index
167+
%c3 = arith.constant 3 : index
168+
%cst = arith.constant 0.000000e+00 : f32
169+
%empty = tensor.empty() : tensor<128xf32>
170+
%ew_init = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
171+
172+
// Parent scf.for loop with iter_args but NOT chained with reduction.
173+
%result = scf.for %iv = %c0 to %c3 step %c1 iter_args(%ew = %ew_init) -> (tensor<128xf32>) {
174+
%empty2 = tensor.empty() : tensor<128xf32>
175+
%init = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<128xf32>) -> tensor<128xf32>
176+
%slice = tensor.extract_slice %arg0[0, 0, 0] [128, 384, 256] [1, 1, 1] : tensor<128x384x256xf32> to tensor<128x384x256xf32>
177+
%reduced = linalg.generic {
178+
indexing_maps = [#map3, #map4],
179+
iterator_types = ["parallel", "reduction", "reduction"]
180+
} ins(%slice : tensor<128x384x256xf32>) outs(%init : tensor<128xf32>) attrs = {lowering_config = #config2} {
181+
^bb0(%in: f32, %out: f32):
182+
%add = arith.addf %in, %out : f32
183+
linalg.yield %add : f32
184+
} -> tensor<128xf32>
185+
186+
// elementwise that uses the parent scf.for iter arg.
187+
%empty3 = tensor.empty() : tensor<128xf32>
188+
%elementwise = linalg.generic {
189+
indexing_maps = [#map5, #map5, #map5],
190+
iterator_types = ["parallel"]
191+
} ins(%ew, %reduced : tensor<128xf32>, tensor<128xf32>) outs(%empty3 : tensor<128xf32>) {
192+
^bb0(%e: f32, %r: f32, %out: f32):
193+
%new = arith.addf %e, %r : f32
194+
linalg.yield %new : f32
195+
} -> tensor<128xf32>
196+
197+
scf.yield %elementwise : tensor<128xf32>
198+
}
199+
return %result : tensor<128xf32>
200+
}
201+
202+
// CHECK-LABEL: func.func @reduction_nochain_iter_args
203+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x384x256xf32>
204+
// CHECK-DAG: %[[C3072:.+]] = arith.constant 3072 : index
205+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
206+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
207+
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
208+
// CHECK: %[[INIT:.+]] = linalg.fill {{.*}} tensor<128xf32>
209+
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[EW_ARG:.+]] = %[[INIT]])
210+
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C3072]] step %[[C1]] iter_args(%[[RED_ARG:.+]] = %[[INIT]])
211+
// CHECK-NOT: scf.for
212+
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8x4xf32>) outs(%[[RED_ARG]] : tensor<128xf32>)
213+
// CHECK: linalg.generic {{.*}} ins(%[[EW_ARG]], %{{.*}} : tensor<128xf32>, tensor<128xf32>)
214+
// CHECK: scf.yield
215+
216+
// THREAD-LABEL: func.func @reduction_nochain_iter_args
217+
// THREAD-NOT: scf.forall
218+
219+
// -----
220+
145221
#config = #iree_gpu.lowering_config<{reduction = [0, 0, 8]}>
146222
#map = affine_map<(d0, d1) -> (d0, d1)>
147223
func.func @matmul_fuse(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32> {

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1717
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1818
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19+
#include "mlir/Dialect/SCF/Utils/Utils.h"
1920
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2021
#include "mlir/IR/Dominance.h"
2122

@@ -482,6 +483,80 @@ LogicalResult applyTileAndFuseToEachRoot(
482483
tileAndFuseOptions, tiledResults->loops);
483484
}
484485
}
486+
487+
// Coalesce scf.for loops created during reduction tiling.
488+
// This is done at the very end after all other transformations
489+
// to avoid invalidating dominance info or affecting fusion logic.
490+
if (tilingLevel == IREE::GPU::TilingLevel::Reduction &&
491+
!tiledResults->loops.empty()) {
492+
SmallVector<scf::ForOp> forLoops;
493+
494+
// Check if tiling happened inside an existing scf.for loop
495+
// If so, include that parent loop in the coalescing.
496+
Operation *parentOp =
497+
tiledResults->loops.front().getOperation()->getParentOp();
498+
scf::ForOp parentForOp = dyn_cast<scf::ForOp>(parentOp);
499+
500+
// Collect all the tiled loops first.
501+
for (LoopLikeOpInterface loop : tiledResults->loops) {
502+
if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
503+
forLoops.push_back(forOp);
504+
}
505+
}
506+
507+
// Only include parent if it forms a proper iter_args chain with the
508+
// tiled loops. This follows the same validation as upstream's
509+
// coalescePerfectlyNestedSCFForLoops.
510+
if (parentForOp && !forLoops.empty()) {
511+
// Check if parent and first child form a valid iter_args chain:
512+
// 1. Must have the same number of iter_args.
513+
// 2. Parent's iter_args must match child's init_args.
514+
// 3. Parent's terminator operands must match child's results.
515+
scf::ForOp firstChild = forLoops.front();
516+
bool formsChain = true;
517+
518+
if (parentForOp.getNumRegionIterArgs() !=
519+
firstChild.getNumRegionIterArgs()) {
520+
formsChain = false;
521+
LLVM_DEBUG(llvm::dbgs()
522+
<< "Skipping parent loop coalescing: different number of "
523+
"iter_args (parent: "
524+
<< parentForOp.getNumRegionIterArgs() << ", child: "
525+
<< firstChild.getNumRegionIterArgs() << ")\n");
526+
}
527+
528+
if (formsChain && !llvm::equal(parentForOp.getRegionIterArgs(),
529+
firstChild.getInitArgs())) {
530+
formsChain = false;
531+
LLVM_DEBUG(llvm::dbgs() << "Skipping parent loop coalescing: parent "
532+
"iter_args don't match child init_args\n");
533+
}
534+
535+
if (formsChain) {
536+
auto parentTerminator = parentForOp.getBody()->getTerminator();
537+
if (!llvm::equal(parentTerminator->getOperands(),
538+
firstChild.getResults())) {
539+
formsChain = false;
540+
LLVM_DEBUG(llvm::dbgs()
541+
<< "Skipping parent loop coalescing: parent yield "
542+
"doesn't match child results\n");
543+
}
544+
}
545+
546+
// If forms a valid chain, insert parent at the beginning.
547+
if (formsChain) {
548+
forLoops.insert(forLoops.begin(), parentForOp);
549+
}
550+
}
551+
552+
// Coalesce if we have multiple loops.
553+
if (forLoops.size() > 1) {
554+
if (failed(coalesceLoops(rewriter, forLoops))) {
555+
// Coalescing failure is not critical, just log and continue.
556+
LLVM_DEBUG(llvm::dbgs() << "Failed to coalesce reduction loops\n");
557+
}
558+
}
559+
}
485560
}
486561
return success();
487562
}

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_direct_conv_tile_and_fuse.mlir

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,18 @@ hal.executable private @main {
6262
// CHECK-DAG: memref.alloc() : memref<1x1x32x68xf16, #gpu.address_space<workgroup>>
6363
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
6464
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
65-
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
66-
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
67-
// CHECK-DAG: %[[C36:.+]] = arith.constant 36 : index
65+
// CHECK-DAG: %[[C81:.+]] = arith.constant 81 : index
6866
// CHECK: scf.forall ({{.*}}) in (16, 48, 9) {
69-
// CHECK: scf.for {{.+}} = %[[C0]] to %[[C3]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
70-
// CHECK: scf.for {{.+}} = %[[C0]] to %[[C3]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
71-
// CHECK: scf.for {{.+}} = %[[C0]] to %[[C36]] step %[[C4]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
72-
// CHECK: gpu.barrier memfence [#gpu.address_space<workgroup>]
73-
// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read {{.+}} : {{.*}}vector<4xf16>
74-
// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
75-
// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.+}} : {{.*}}vector<8xf16>
76-
// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
77-
// CHECK: gpu.barrier memfence [#gpu.address_space<workgroup>]
78-
// CHECK-DAG: vector.transfer_read {{.*}} vector<4x4xf16>
79-
// CHECK-DAG: vector.transfer_read {{.*}} vector<4x4xf16>
80-
// CHECK-COUNT-4: amdgpu.mfma 16x16x16
67+
// CHECK: scf.for {{.+}} = %[[C0]] to %[[C81]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
68+
// CHECK-NOT: scf.for
69+
// CHECK: gpu.barrier memfence [#gpu.address_space<workgroup>]
70+
// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read {{.+}} : {{.*}}vector<4xf16>
71+
// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
72+
// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.+}} : {{.*}}vector<8xf16>
73+
// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
74+
// CHECK: gpu.barrier memfence [#gpu.address_space<workgroup>]
75+
// CHECK-DAG: vector.transfer_read {{.*}} vector<4x4xf16>
76+
// CHECK-DAG: vector.transfer_read {{.*}} vector<4x4xf16>
77+
// CHECK-COUNT-4: amdgpu.mfma 16x16x16
8178
// CHECK: vector.transfer_write %{{.*}}, %[[BUF2]]
8279
// CHECK: } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,16 +1066,20 @@ hal.executable public @main {
10661066
}
10671067

10681068
// CHECK-LABEL: func @elemwise_reduction_elemwise
1069-
// CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1 {{.*}} -> (vector<1xf32>)
1070-
// CHECK: scf.for
1071-
// CHECK: scf.for
1072-
// CHECK: %[[REDUCE:.+]] = vector.multi_reduction
1073-
// CHECK: scf.yield %[[REDUCE]]
1074-
1075-
// CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1
1076-
// CHECK: scf.for
1077-
// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32>
1078-
// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #amdgpu.address_space<fat_raw_buffer>>
1069+
// CHECK-DAG: %[[C144:.+]] = arith.constant 144 : index
1070+
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
1071+
// CHECK-DAG: %[[C9:.+]] = arith.constant 9 : index
1072+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1073+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1074+
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C144]] step %[[C1]] {{.*}} -> (vector<1xf32>)
1075+
// CHECK-NOT: scf.for
1076+
// CHECK: %[[REDUCE:.+]] = vector.multi_reduction
1077+
// CHECK: scf.yield %[[REDUCE]]
1078+
1079+
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C16]] step %[[C1]]
1080+
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C9]] step %[[C1]]
1081+
// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32>
1082+
// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #amdgpu.address_space<fat_raw_buffer>>
10791083

10801084
// -----
10811085

0 commit comments

Comments
 (0)