Skip to content

Commit 33da12a

Browse files
authored
[acc] Lower acc if with multi-block host fallback via scf.execute_region (llvm#188350)
Handle multi-block host fallback regions by wrapping them in `scf.execute_region` instead of rejecting them with `not yet implemented: region with multiple blocks`.
1 parent 7aaec28 commit 33da12a

File tree

2 files changed

+58
-15
lines changed

2 files changed

+58
-15
lines changed

mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
#include "mlir/Dialect/Func/IR/FuncOps.h"
6060
#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
6161
#include "mlir/Dialect/OpenACC/OpenACC.h"
62+
#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
6263
#include "mlir/Dialect/SCF/IR/SCF.h"
6364
#include "mlir/IR/Builders.h"
6465
#include "mlir/IR/IRMapping.h"
@@ -215,22 +216,28 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
215216
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
216217

217218
// Host execution path (false branch)
218-
if (!computeConstructOp.getRegion().hasOneBlock()) {
219-
accSupport->emitNYI(computeConstructOp.getLoc(),
220-
"region with multiple blocks");
221-
return;
222-
}
223-
224-
// Don't need to clone original ops, just take them and legalize for host
225-
ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
226-
227-
// Swap acc yield for scf yield
228-
Block &elseBlock = ifOp.getElseRegion().front();
229-
elseBlock.getTerminator()->erase();
230-
rewriter.setInsertionPointToEnd(&elseBlock);
231-
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
219+
Region &hostRegion = computeConstructOp.getRegion();
220+
if (hostRegion.hasOneBlock()) {
221+
// Don't need to clone original ops, just take them and legalize for host.
222+
ifOp.getElseRegion().takeBody(hostRegion);
223+
224+
// Swap acc yield for scf yield.
225+
Block &elseBlock = ifOp.getElseRegion().front();
226+
elseBlock.getTerminator()->erase();
227+
rewriter.setInsertionPointToEnd(&elseBlock);
228+
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
232229

233-
convertHostRegion(computeConstructOp, ifOp.getElseRegion());
230+
convertHostRegion(computeConstructOp, ifOp.getElseRegion());
231+
} else {
232+
// scf.if regions must stay single-block. Wrap the original multi-block ACC
233+
// body in scf.execute_region so it can be hosted in the else branch.
234+
Block &elseBlock = ifOp.getElseRegion().front();
235+
rewriter.setInsertionPoint(elseBlock.getTerminator());
236+
IRMapping hostMapping;
237+
auto hostExecuteRegion = wrapMultiBlockRegionWithSCFExecuteRegion(
238+
hostRegion, hostMapping, computeConstructOp.getLoc(), rewriter);
239+
convertHostRegion(computeConstructOp, hostExecuteRegion.getRegion());
240+
}
234241

235242
// The original op is now empty and can be erased
236243
eraseOps.push_back(computeConstructOp);

mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,42 @@ func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
3737

3838
// -----
3939

40+
// Test acc.parallel if lowering when host fallback region has multiple blocks.
41+
// CHECK-LABEL: func.func @test_parallel_if_multiblock
42+
func.func @test_parallel_if_multiblock(%cond: i1, %n: i32) {
43+
%c0_i32 = arith.constant 0 : i32
44+
%c1_i32 = arith.constant 1 : i32
45+
%counter = memref.alloca() : memref<i32>
46+
memref.store %n, %counter[] : memref<i32>
47+
48+
// CHECK-NOT: acc.parallel if
49+
// CHECK: scf.if %{{.*}} {
50+
// CHECK: acc.parallel {
51+
// CHECK: } else {
52+
// CHECK: scf.execute_region {
53+
// CHECK: ^bb
54+
// CHECK: cf.cond_br
55+
// CHECK: scf.yield
56+
// CHECK: }
57+
// CHECK: }
58+
acc.parallel if(%cond) {
59+
cf.br ^bb1
60+
^bb1:
61+
%v = memref.load %counter[] : memref<i32>
62+
%pred = arith.cmpi sgt, %v, %c0_i32 : i32
63+
cf.cond_br %pred, ^bb2, ^bb3
64+
^bb2:
65+
%next = arith.subi %v, %c1_i32 : i32
66+
memref.store %next, %counter[] : memref<i32>
67+
cf.br ^bb1
68+
^bb3:
69+
acc.yield
70+
}
71+
return
72+
}
73+
74+
// -----
75+
4076
// Test acc.kernels with if condition
4177
// CHECK-LABEL: func.func @test_kernels_if
4278
func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {

0 commit comments

Comments
 (0)