Skip to content

Commit c6c8443

Browse files
committed
[Flang][OpenMP] Minimize host ops remaining in device compilation
This patch updates the function filtering OpenMP pass intended to remove host functions from the MLIR module created by Flang lowering when targeting an OpenMP target device. Host functions holding target regions must be kept, so that the target regions within them can be translated for the device. The issue is that non-target operations inside these functions cannot be discarded because some of them hold information that is also relevant during target device codegen. Specifically, mapping information resides outside of `omp.target` regions. This patch updates the previous behavior where all host operations were preserved to then ignore all of those that are not actually needed by target device codegen. This, in practice, means only keeping target regions and mapping information needed by the device. Arguments for some of these remaining operations are replaced by placeholder allocations and `fir.undefined`, since they are only actually defined inside of the target regions themselves. As a result, this set of changes makes it possible to later simplify target device codegen, as it is no longer necessary to handle host operations differently to avoid issues.
1 parent b70cbfa commit c6c8443

File tree

7 files changed

+738
-32
lines changed

7 files changed

+738
-32
lines changed

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
4646
"for the target device.";
4747
let dependentDialects = [
4848
"mlir::func::FuncDialect",
49-
"fir::FIROpsDialect"
49+
"fir::FIROpsDialect",
50+
"mlir::omp::OpenMPDialect"
5051
];
5152
}
5253

flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313

1414
#include "flang/Optimizer/Dialect/FIRDialect.h"
1515
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
16+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1617
#include "flang/Optimizer/OpenMP/Passes.h"
1718

1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
1920
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2021
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
2122
#include "mlir/IR/BuiltinOps.h"
23+
#include "llvm/ADT/SetVector.h"
2224
#include "llvm/ADT/SmallVector.h"
2325

2426
namespace flangomp {
@@ -94,12 +96,298 @@ class FunctionFilteringPass
9496
funcOp.erase();
9597
return WalkResult::skip();
9698
}
99+
100+
if (failed(rewriteHostRegion(funcOp.getRegion()))) {
101+
funcOp.emitOpError() << "could not be rewritten for target device";
102+
return WalkResult::interrupt();
103+
}
104+
97105
if (declareTargetOp)
98106
declareTargetOp.setDeclareTarget(declareType,
99107
omp::DeclareTargetCaptureClause::to);
100108
}
101109
return WalkResult::advance();
102110
});
103111
}
112+
113+
private:
114+
/// Add the given \c omp.map.info to a sorted set while taking into account
115+
/// its dependencies.
116+
static void collectMapInfos(omp::MapInfoOp mapOp, Region &region,
117+
llvm::SetVector<omp::MapInfoOp> &mapInfos) {
118+
for (Value member : mapOp.getMembers())
119+
collectMapInfos(cast<omp::MapInfoOp>(member.getDefiningOp()), region,
120+
mapInfos);
121+
122+
if (region.isAncestor(mapOp->getParentRegion()))
123+
mapInfos.insert(mapOp);
124+
}
125+
126+
/// Add the given value to a sorted set if it should be replaced by a
127+
/// placeholder when used as a pointer-like argument to an operation
128+
/// participating in the initialization of an \c omp.map.info.
129+
static void markPtrOperandForRewrite(Value value,
130+
llvm::SetVector<Value> &rewriteValues) {
131+
// We don't need to rewrite operands if they are defined by block arguments
132+
// of operations that will still remain after the region is rewritten.
133+
if (isa<BlockArgument>(value) &&
134+
isa<func::FuncOp, omp::TargetDataOp>(
135+
cast<BlockArgument>(value).getOwner()->getParentOp()))
136+
return;
137+
138+
rewriteValues.insert(value);
139+
}
140+
141+
/// Rewrite the given host device region belonging to a function that contains
142+
/// \c omp.target operations, to remove host-only operations that are not used
143+
/// by device codegen.
144+
///
145+
/// It is based on the expected form of the MLIR module as produced by Flang
146+
/// lowering and it performs the following mutations:
147+
/// - Replace all values returned by the function with \c fir.undefined.
148+
/// - Operations taking map-like clauses (e.g. \c omp.target,
149+
/// \c omp.target_data, etc) are moved to the end of the function. If they
150+
/// are nested inside of any other operations, they are hoisted out of
151+
/// them. If the region belongs to \c omp.target_data, these operations
152+
/// are hoisted to its top level, rather than to the parent function.
153+
/// - Only \c omp.map.info operations associated to these target regions are
154+
/// preserved. These are moved above all \c omp.target and sorted to
155+
/// satisfy dependencies among them.
156+
/// - \c bounds arguments are removed from \c omp.map.info operations.
157+
/// - \c var_ptr and \c var_ptr_ptr arguments of \c omp.map.info are
158+
/// handled as follows:
159+
/// - \c var_ptr_ptr is expected to be defined by a \c fir.box_offset
160+
/// operation which is preserved. Otherwise, the pass will fail.
161+
/// - \c var_ptr can be defined by an \c hlfir.declare which is also
162+
/// preserved. If the \c var_ptr or \c hlfir.declare \c memref argument
163+
/// is a \c fir.address_of operation, that operation is also maintained.
164+
/// Otherwise, it is replaced by a placeholder \c fir.alloca and a
165+
/// \c fir.convert or kept unmodified when it is defined by an entry
166+
/// block argument. If it has \c shape or \c typeparams arguments, they
167+
/// are also replaced by applicable constants. \c dummy_scope arguments
168+
/// are discarded.
169+
/// - Every other operation not located inside of an \c omp.target is
170+
/// removed.
171+
LogicalResult rewriteHostRegion(Region &region) {
172+
// Extract parent op information.
173+
auto [funcOp, targetDataOp] = [&region]() {
174+
Operation *parent = region.getParentOp();
175+
return std::make_tuple(dyn_cast<func::FuncOp>(parent),
176+
dyn_cast<omp::TargetDataOp>(parent));
177+
}();
178+
assert((bool)funcOp != (bool)targetDataOp &&
179+
"region must be defined by either func.func or omp.target_data");
180+
181+
// Collect operations that have mapping information associated to them.
182+
llvm::SmallVector<
183+
std::variant<omp::TargetOp, omp::TargetDataOp, omp::TargetEnterDataOp,
184+
omp::TargetExitDataOp, omp::TargetUpdateOp>>
185+
targetOps;
186+
187+
WalkResult result = region.walk<WalkOrder::PreOrder>([&](Operation *op) {
188+
// Skip the inside of omp.target regions, since these contain device code.
189+
if (auto targetOp = dyn_cast<omp::TargetOp>(op)) {
190+
targetOps.push_back(targetOp);
191+
return WalkResult::skip();
192+
}
193+
194+
if (auto targetOp = dyn_cast<omp::TargetDataOp>(op)) {
195+
// Recursively rewrite omp.target_data regions as well.
196+
if (failed(rewriteHostRegion(targetOp.getRegion()))) {
197+
targetOp.emitOpError() << "rewrite for target device failed";
198+
return WalkResult::interrupt();
199+
}
200+
201+
targetOps.push_back(targetOp);
202+
return WalkResult::skip();
203+
}
204+
205+
if (auto targetOp = dyn_cast<omp::TargetEnterDataOp>(op))
206+
targetOps.push_back(targetOp);
207+
if (auto targetOp = dyn_cast<omp::TargetExitDataOp>(op))
208+
targetOps.push_back(targetOp);
209+
if (auto targetOp = dyn_cast<omp::TargetUpdateOp>(op))
210+
targetOps.push_back(targetOp);
211+
212+
return WalkResult::advance();
213+
});
214+
215+
if (result.wasInterrupted())
216+
return failure();
217+
218+
// Make a temporary clone of the parent operation with an empty region,
219+
// and update all references to entry block arguments to those of the new
220+
// region. Users will later either be moved to the new region or deleted
221+
// when the original region is replaced by the new.
222+
OpBuilder builder(&getContext());
223+
builder.setInsertionPointAfter(region.getParentOp());
224+
Operation *newOp = builder.cloneWithoutRegions(*region.getParentOp());
225+
Block &block = newOp->getRegion(0).emplaceBlock();
226+
227+
llvm::SmallVector<Location> locs;
228+
locs.reserve(region.getNumArguments());
229+
llvm::transform(region.getArguments(), std::back_inserter(locs),
230+
[](const BlockArgument &arg) { return arg.getLoc(); });
231+
block.addArguments(region.getArgumentTypes(), locs);
232+
233+
for (auto [oldArg, newArg] :
234+
llvm::zip_equal(region.getArguments(), block.getArguments()))
235+
oldArg.replaceAllUsesWith(newArg);
236+
237+
// Collect omp.map.info ops while satisfying interdependencies. This must be
238+
// updated whenever new map-like clauses are introduced or they are attached
239+
// to other operations.
240+
llvm::SetVector<omp::MapInfoOp> mapInfos;
241+
for (auto targetOp : targetOps) {
242+
std::visit(
243+
[&region, &mapInfos](auto op) {
244+
for (Value mapVar : op.getMapVars())
245+
collectMapInfos(cast<omp::MapInfoOp>(mapVar.getDefiningOp()),
246+
region, mapInfos);
247+
248+
if constexpr (std::is_same_v<decltype(op), omp::TargetOp>) {
249+
for (Value mapVar : op.getHasDeviceAddrVars())
250+
collectMapInfos(cast<omp::MapInfoOp>(mapVar.getDefiningOp()),
251+
region, mapInfos);
252+
} else if constexpr (std::is_same_v<decltype(op),
253+
omp::TargetDataOp>) {
254+
for (Value mapVar : op.getUseDeviceAddrVars())
255+
collectMapInfos(cast<omp::MapInfoOp>(mapVar.getDefiningOp()),
256+
region, mapInfos);
257+
for (Value mapVar : op.getUseDevicePtrVars())
258+
collectMapInfos(cast<omp::MapInfoOp>(mapVar.getDefiningOp()),
259+
region, mapInfos);
260+
}
261+
},
262+
targetOp);
263+
}
264+
265+
// Move omp.map.info ops to the new block and collect dependencies.
266+
llvm::SetVector<hlfir::DeclareOp> declareOps;
267+
llvm::SetVector<fir::BoxOffsetOp> boxOffsets;
268+
llvm::SetVector<Value> rewriteValues;
269+
for (omp::MapInfoOp mapOp : mapInfos) {
270+
// Handle var_ptr: hlfir.declare.
271+
if (auto declareOp = dyn_cast_if_present<hlfir::DeclareOp>(
272+
mapOp.getVarPtr().getDefiningOp())) {
273+
if (region.isAncestor(declareOp->getParentRegion()))
274+
declareOps.insert(declareOp);
275+
} else {
276+
markPtrOperandForRewrite(mapOp.getVarPtr(), rewriteValues);
277+
}
278+
279+
// Handle var_ptr_ptr: fir.box_offset.
280+
if (Value varPtrPtr = mapOp.getVarPtrPtr()) {
281+
if (auto boxOffset = llvm::dyn_cast_if_present<fir::BoxOffsetOp>(
282+
varPtrPtr.getDefiningOp())) {
283+
if (region.isAncestor(boxOffset->getParentRegion()))
284+
boxOffsets.insert(boxOffset);
285+
} else {
286+
return mapOp->emitOpError() << "var_ptr_ptr rewrite only supported "
287+
"if defined by fir.box_offset";
288+
}
289+
}
290+
291+
// Bounds are not used during target device codegen.
292+
mapOp.getBoundsMutable().clear();
293+
mapOp->moveBefore(&block, block.end());
294+
}
295+
296+
// Create a temporary marker to simplify the op moving process below.
297+
builder.setInsertionPointToStart(&block);
298+
auto marker = builder.create<fir::UndefOp>(builder.getUnknownLoc(),
299+
builder.getNoneType());
300+
builder.setInsertionPoint(marker);
301+
302+
// Move dependencies of hlfir.declare ops.
303+
for (hlfir::DeclareOp declareOp : declareOps) {
304+
Value memref = declareOp.getMemref();
305+
306+
// If it's defined by fir.address_of, then we need to keep that op as well
307+
// because it might be pointing to a 'declare target' global.
308+
if (auto addressOf =
309+
dyn_cast_if_present<fir::AddrOfOp>(memref.getDefiningOp()))
310+
addressOf->moveBefore(marker);
311+
else
312+
markPtrOperandForRewrite(memref, rewriteValues);
313+
314+
// Shape and typeparams aren't needed for target device codegen, but
315+
// removing them would break verifiers.
316+
Value zero;
317+
if (declareOp.getShape() || !declareOp.getTypeparams().empty())
318+
zero = builder.create<arith::ConstantOp>(declareOp.getLoc(),
319+
builder.getI64IntegerAttr(0));
320+
321+
if (auto shape = declareOp.getShape()) {
322+
Operation *shapeOp = shape.getDefiningOp();
323+
unsigned numArgs = shapeOp->getNumOperands();
324+
if (isa<fir::ShapeShiftOp>(shapeOp))
325+
numArgs /= 2;
326+
327+
// Since the pre-cg rewrite pass requires the shape to be defined by one
328+
// of fir.shape, fir.shapeshift or fir.shift, we need to create one of
329+
// these.
330+
llvm::SmallVector<Value> extents(numArgs, zero);
331+
auto newShape = builder.create<fir::ShapeOp>(shape.getLoc(), extents);
332+
declareOp.getShapeMutable().assign(newShape);
333+
}
334+
335+
for (OpOperand &typeParam : declareOp.getTypeparamsMutable())
336+
typeParam.assign(zero);
337+
338+
declareOp.getDummyScopeMutable().clear();
339+
}
340+
341+
// We don't actually need the proper local allocations, but rather maintain
342+
// the basic form of map operands. We create 1-bit placeholder allocas
343+
// that we "typecast" to the expected pointer type and replace all uses.
344+
// Using fir.undefined here instead is not possible because these variables
345+
// cannot be constants, as that would trigger different codegen for target
346+
// regions.
347+
for (Value value : rewriteValues) {
348+
Location loc = value.getLoc();
349+
Value placeholder =
350+
builder.create<fir::AllocaOp>(loc, builder.getI1Type());
351+
value.replaceAllUsesWith(
352+
builder.create<fir::ConvertOp>(loc, value.getType(), placeholder));
353+
}
354+
355+
// Move omp.map.info dependencies.
356+
for (hlfir::DeclareOp declareOp : declareOps)
357+
declareOp->moveBefore(marker);
358+
359+
// The box_ref argument of fir.box_offset is expected to be the same value
360+
// that was passed as var_ptr to the corresponding omp.map.info, so we don't
361+
// need to move its defining op here.
362+
for (fir::BoxOffsetOp boxOffset : boxOffsets)
363+
boxOffset->moveBefore(marker);
364+
365+
marker->erase();
366+
367+
// Move mapping information users to the end of the new block.
368+
for (auto targetOp : targetOps)
369+
std::visit([&block](auto op) { op->moveBefore(&block, block.end()); },
370+
targetOp);
371+
372+
// Add terminator to the new block.
373+
builder.setInsertionPointToEnd(&block);
374+
if (funcOp) {
375+
llvm::SmallVector<Value> returnValues;
376+
returnValues.reserve(funcOp.getNumResults());
377+
for (auto type : funcOp.getResultTypes())
378+
returnValues.push_back(
379+
builder.create<fir::UndefOp>(funcOp.getLoc(), type));
380+
381+
builder.create<func::ReturnOp>(funcOp.getLoc(), returnValues);
382+
} else {
383+
builder.create<omp::TerminatorOp>(targetDataOp.getLoc());
384+
}
385+
386+
// Replace old (now missing ops) region with the new one and remove the
387+
// temporary clone.
388+
region.takeBody(newOp->getRegion(0));
389+
newOp->erase();
390+
return success();
391+
}
104392
};
105393
} // namespace

flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
2-
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s
3-
!RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s
4-
!RUN: bbc -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
1+
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefixes=BOTH,HOST
2+
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes=BOTH,DEVICE
3+
!RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefixes=BOTH,HOST
4+
!RUN: bbc -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes=BOTH,DEVICE
55

66
program test_link
77

@@ -20,13 +20,14 @@ program test_link
2020
integer, pointer :: test_ptr2
2121
!$omp declare target link(test_ptr2)
2222

23-
!CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_int"}
23+
!BOTH-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_int"}
2424
!$omp target
2525
test_int = test_int + 1
2626
!$omp end target
2727

2828

29-
!CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.array<3xi32>>, !fir.array<3xi32>) map_clauses(implicit, tofrom) capture(ByRef) bounds({{%.*}}) -> !fir.ref<!fir.array<3xi32>> {name = "test_array_1d"}
29+
!HOST-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.array<3xi32>>, !fir.array<3xi32>) map_clauses(implicit, tofrom) capture(ByRef) bounds({{%.*}}) -> !fir.ref<!fir.array<3xi32>> {name = "test_array_1d"}
30+
!DEVICE-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.array<3xi32>>, !fir.array<3xi32>) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!fir.array<3xi32>> {name = "test_array_1d"}
3031
!$omp target
3132
do i = 1,3
3233
test_array_1d(i) = i * 2
@@ -35,18 +36,18 @@ program test_link
3536

3637
allocate(test_ptr1)
3738
test_ptr1 = 1
38-
!CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(implicit, to) capture(ByRef) members({{%.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "test_ptr1"}
39+
!BOTH-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(implicit, to) capture(ByRef) members({{%.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "test_ptr1"}
3940
!$omp target
4041
test_ptr1 = test_ptr1 + 1
4142
!$omp end target
4243

43-
!CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_target"}
44+
!BOTH-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_target"}
4445
!$omp target
4546
test_target = test_target + 1
4647
!$omp end target
4748

4849

49-
!CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(implicit, to) capture(ByRef) members({{%.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "test_ptr2"}
50+
!BOTH-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(implicit, to) capture(ByRef) members({{%.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "test_ptr2"}
5051
test_ptr2 => test_target
5152
!$omp target
5253
test_ptr2 = test_ptr2 + 1

0 commit comments

Comments
 (0)