Skip to content

Commit ffffc63

Browse files
address reviewer comments
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent ffb0b44 commit ffffc63

2 files changed

Lines changed: 3 additions & 106 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir

Lines changed: 0 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -156,68 +156,6 @@ func.func @reduction(%arg0: tensor<128x384x256xf32>) -> tensor<128xf32> {
156 156

157 157
// -----
158 158

159-
// Test coalescing when parent scf.for has iter_args but NOT chained with reduction.
160-
#config2 = #iree_gpu.lowering_config<{reduction = [0, 8, 4]}>
161-
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
162-
#map4 = affine_map<(d0, d1, d2) -> (d0)>
163-
#map5 = affine_map<(d0) -> (d0)>
164-
func.func @reduction_nochain_iter_args(%arg0: tensor<128x384x256xf32>) -> tensor<128xf32> {
165-
%c0 = arith.constant 0 : index
166-
%c1 = arith.constant 1 : index
167-
%c3 = arith.constant 3 : index
168-
%cst = arith.constant 0.000000e+00 : f32
169-
%empty = tensor.empty() : tensor<128xf32>
170-
%ew_init = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
171-
172-
// Parent scf.for loop with iter_args but NOT chained with reduction.
173-
%result = scf.for %iv = %c0 to %c3 step %c1 iter_args(%ew = %ew_init) -> (tensor<128xf32>) {
174-
%empty2 = tensor.empty() : tensor<128xf32>
175-
%init = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<128xf32>) -> tensor<128xf32>
176-
%slice = tensor.extract_slice %arg0[0, 0, 0] [128, 384, 256] [1, 1, 1] : tensor<128x384x256xf32> to tensor<128x384x256xf32>
177-
%reduced = linalg.generic {
178-
indexing_maps = [#map3, #map4],
179-
iterator_types = ["parallel", "reduction", "reduction"]
180-
} ins(%slice : tensor<128x384x256xf32>) outs(%init : tensor<128xf32>) attrs = {lowering_config = #config2} {
181-
^bb0(%in: f32, %out: f32):
182-
%add = arith.addf %in, %out : f32
183-
linalg.yield %add : f32
184-
} -> tensor<128xf32>
185-
186-
// elementwise that uses the parent scf.for iter arg.
187-
%empty3 = tensor.empty() : tensor<128xf32>
188-
%elementwise = linalg.generic {
189-
indexing_maps = [#map5, #map5, #map5],
190-
iterator_types = ["parallel"]
191-
} ins(%ew, %reduced : tensor<128xf32>, tensor<128xf32>) outs(%empty3 : tensor<128xf32>) {
192-
^bb0(%e: f32, %r: f32, %out: f32):
193-
%new = arith.addf %e, %r : f32
194-
linalg.yield %new : f32
195-
} -> tensor<128xf32>
196-
197-
scf.yield %elementwise : tensor<128xf32>
198-
}
199-
return %result : tensor<128xf32>
200-
}
201-
202-
// CHECK-LABEL: func.func @reduction_nochain_iter_args
203-
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x384x256xf32>
204-
// CHECK-DAG: %[[C3072:.+]] = arith.constant 3072 : index
205-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
206-
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
207-
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
208-
// CHECK: %[[INIT:.+]] = linalg.fill {{.*}} tensor<128xf32>
209-
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[EW_ARG:.+]] = %[[INIT]])
210-
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C3072]] step %[[C1]] iter_args(%[[RED_ARG:.+]] = %[[INIT]])
211-
// CHECK-NOT: scf.for
212-
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x8x4xf32>) outs(%[[RED_ARG]] : tensor<128xf32>)
213-
// CHECK: linalg.generic {{.*}} ins(%[[EW_ARG]], %{{.*}} : tensor<128xf32>, tensor<128xf32>)
214-
// CHECK: scf.yield
215-
216-
// THREAD-LABEL: func.func @reduction_nochain_iter_args
217-
// THREAD-NOT: scf.forall
218-
219-
// -----
220-
221 159
// Test that coalescing is skipped when loops have dynamic trip counts.
222 160
#config_dyn = #iree_gpu.lowering_config<{reduction = [0, 8, 4]}>
223 161
#map_dyn1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp

Lines changed: 3 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -498,6 +498,9 @@ LogicalResult applyTileAndFuseToEachRoot(
498 498
Operation *parentOp =
499 499
tiledResults->loops.front().getOperation()->getParentOp();
500 500
scf::ForOp parentForOp = dyn_cast<scf::ForOp>(parentOp);
501+
if (parentForOp) {
502+
forLoops.push_back(parentForOp);
503+
}
501 504

502 505
// Collect all the tiled loops first.
503 506
for (LoopLikeOpInterface loop : tiledResults->loops) {
@@ -506,50 +509,6 @@ LogicalResult applyTileAndFuseToEachRoot(
506 509
}
507 510
}
508 511

509-
// Only include parent if it forms a proper iter_args chain with the
510-
// tiled loops. This follows the same validation as upstream's
511-
// coalescePerfectlyNestedSCFForLoops.
512-
if (parentForOp && !forLoops.empty()) {
513-
// Check if parent and first child form a valid iter_args chain:
514-
// 1. Must have the same number of iter_args.
515-
// 2. Parent's iter_args must match child's init_args.
516-
// 3. Parent's terminator operands must match child's results.
517-
scf::ForOp firstChild = forLoops.front();
518-
bool formsChain = true;
519-
520-
if (parentForOp.getNumRegionIterArgs() !=
521-
firstChild.getNumRegionIterArgs()) {
522-
formsChain = false;
523-
LLVM_DEBUG(llvm::dbgs()
524-
<< "Skipping parent loop coalescing: different number of "
525-
"iter_args (parent: "
526-
<< parentForOp.getNumRegionIterArgs() << ", child: "
527-
<< firstChild.getNumRegionIterArgs() << ")\n");
528-
}
529-
530-
if (formsChain && !llvm::equal(parentForOp.getRegionIterArgs(),
531-
firstChild.getInitArgs())) {
532-
formsChain = false;
533-
LLVM_DEBUG(llvm::dbgs() << "Skipping parent loop coalescing: parent "
534-
"iter_args don't match child init_args\n");
535-
}
536-
537-
if (formsChain) {
538-
auto parentTerminator = parentForOp.getBody()->getTerminator();
539-
if (!llvm::equal(parentTerminator->getOperands(),
540-
firstChild.getResults())) {
541-
formsChain = false;
542-
LLVM_DEBUG(llvm::dbgs()
543-
<< "Skipping parent loop coalescing: parent yield "
544-
"doesn't match child results\n");
545-
}
546-
}
547-
548-
if (formsChain) {
549-
forLoops.insert(forLoops.begin(), parentForOp);
550-
}
551-
}
552-
553 512
// If loops have dynamic trip counts and we coalesce them, it can
554 513
// cause range analysis to not find static bounds. This was mainly
555 514
// noticed as a problem in applyPaddingLevel, to prevent a regression

0 commit comments

Comments
 (0)