Skip to content

Commit 30b0946

Browse files
authored
[Flang][MLIR][OpenMP] Improve use_device_* handling (#137198)
This patch updates MLIR op verifiers for operations taking arguments that must always be defined by an `omp.map.info` operation to check this requirement. It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the custom MLIR printer and parser for these clauses, to support initializing it to `OMP_MAP_RETURN_PARAM` and represent this in the MLIR representation as `return_param`. This internal mapping flag is what eventually is used for variables passed via these clauses into the target region when translating to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in the current representation.
1 parent 42ee758 commit 30b0946

File tree

6 files changed

+57
-21
lines changed

6 files changed

+57
-21
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -1407,8 +1407,7 @@ bool ClauseProcessor::processUseDeviceAddr(
14071407
const parser::CharBlock &source) {
14081408
mlir::Location location = converter.genLocation(source);
14091409
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1410-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1411-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1410+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
14121411
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
14131412
parentMemberIndices, result.useDeviceAddrVars,
14141413
useDeviceSyms);
@@ -1429,8 +1428,7 @@ bool ClauseProcessor::processUseDevicePtr(
14291428
const parser::CharBlock &source) {
14301429
mlir::Location location = converter.genLocation(source);
14311430
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1432-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1433-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1431+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
14341432
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
14351433
parentMemberIndices, result.useDevicePtrVars,
14361434
useDeviceSyms);

flang/lib/Lower/OpenMP/Utils.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -398,14 +398,16 @@ mlir::Value createParentSymAndGenIntermediateMaps(
398398
interimBounds, treatIndexAsSection);
399399
}
400400

401-
// Remove all map TO, FROM and TOFROM bits, from the intermediate
402-
// allocatable maps, we simply wish to alloc or release them. It may be
403-
// safer to just pass OMP_MAP_NONE as the map type, but we may still
401+
// Remove all map-type bits (e.g. TO, FROM, etc.) from the intermediate
402+
// allocatable maps, as we simply wish to alloc or release them. It may
403+
// be safer to just pass OMP_MAP_NONE as the map type, but we may still
404404
// need some of the other map types the mapped member utilises, so for
405405
// now it's good to keep an eye on this.
406406
llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits;
407407
interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
408408
interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
409+
interimMapType &=
410+
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
409411

410412
// Create a map for the intermediate member and insert it and it's
411413
// indices into the parentMemberIndices list to track it.

flang/test/Fir/convert-to-llvm-openmp-and-fir.fir

+3-2
Original file line numberDiff line numberDiff line change
@@ -423,14 +423,15 @@ func.func @_QPopenmp_target_data_region() {
423423

424424
func.func @_QPomp_target_data_empty() {
425425
%0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
426-
omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
426+
%1 = omp.map.info var_ptr(%0 : !fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>) map_clauses(return_param) capture(ByRef) -> !fir.ref<!fir.array<1024xi32>> {name = ""}
427+
omp.target_data use_device_addr(%1 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
427428
omp.terminator
428429
}
429430
return
430431
}
431432

432433
// CHECK-LABEL: llvm.func @_QPomp_target_data_empty
433-
// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
434+
// CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}} : !llvm.ptr) {
434435
// CHECK: }
435436

436437
// -----

flang/test/Lower/OpenMP/target.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ subroutine omp_target_device_addr
544544
!CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
545545
!CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
546546
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
547-
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
547+
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(return_param) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
548548
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
549549
!CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
550550
!$omp target data map(tofrom: a) use_device_addr(a)

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

+39-8
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
15211521
if (mapTypeMod == "delete")
15221522
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
15231523

1524+
if (mapTypeMod == "return_param")
1525+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1526+
15241527
return success();
15251528
};
15261529

@@ -1583,6 +1586,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
15831586
emitAllocRelease = false;
15841587
mapTypeStrs.push_back("delete");
15851588
}
1589+
if (mapTypeToBitFlag(
1590+
mapTypeBits,
1591+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1592+
emitAllocRelease = false;
1593+
mapTypeStrs.push_back("return_param");
1594+
}
15861595
if (emitAllocRelease)
15871596
mapTypeStrs.push_back("exit_release_or_enter_alloc");
15881597

@@ -1777,13 +1786,27 @@ static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
17771786
// MapInfoOp
17781787
//===----------------------------------------------------------------------===//
17791788

1789+
static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
1790+
StringRef clauseName,
1791+
OperandRange vars) {
1792+
for (Value var : vars)
1793+
if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1794+
return op->emitOpError()
1795+
<< "'" << clauseName
1796+
<< "' arguments must be defined by 'omp.map.info' ops";
1797+
return success();
1798+
}
1799+
17801800
LogicalResult MapInfoOp::verify() {
17811801
if (getMapperId() &&
17821802
!SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
17831803
*this, getMapperIdAttr())) {
17841804
return emitError("invalid mapper id");
17851805
}
17861806

1807+
if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
1808+
return failure();
1809+
17871810
return success();
17881811
}
17891812

@@ -1805,6 +1828,15 @@ LogicalResult TargetDataOp::verify() {
18051828
"At least one of map, use_device_ptr_vars, or "
18061829
"use_device_addr_vars operand must be present");
18071830
}
1831+
1832+
if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
1833+
getUseDevicePtrVars())))
1834+
return failure();
1835+
1836+
if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
1837+
getUseDeviceAddrVars())))
1838+
return failure();
1839+
18081840
return verifyMapClause(*this, getMapVars());
18091841
}
18101842

@@ -1889,16 +1921,15 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
18891921
}
18901922

18911923
LogicalResult TargetOp::verify() {
1892-
LogicalResult verifyDependVars =
1893-
verifyDependVarList(*this, getDependKinds(), getDependVars());
1894-
1895-
if (failed(verifyDependVars))
1896-
return verifyDependVars;
1924+
if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
1925+
return failure();
18971926

1898-
LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
1927+
if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
1928+
getHasDeviceAddrVars())))
1929+
return failure();
18991930

1900-
if (failed(verifyMapVars))
1901-
return verifyMapVars;
1931+
if (failed(verifyMapClause(*this, getMapVars())))
1932+
return failure();
19021933

19031934
return verifyPrivateVarsMapping(*this);
19041935
}

mlir/test/Dialect/OpenMP/ops.mlir

+7-3
Original file line numberDiff line numberDiff line change
@@ -802,10 +802,14 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref<i
802802
%mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
803803
omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref<?xi32>){}
804804

805-
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
806-
// CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref<i32>)
805+
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
806+
// CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
807+
// CHECK: %[[DEV_PTR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
808+
// CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[DEV_ADDR]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[DEV_PTR]] -> %{{.*}} : memref<i32>)
807809
%mapv2 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
808-
omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addr -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptr -> %arg1 : memref<i32>) {
810+
%device_addrv1 = omp.map.info var_ptr(%device_addr : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
811+
%device_ptrv1 = omp.map.info var_ptr(%device_ptr : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
812+
omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addrv1 -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptrv1 -> %arg1 : memref<i32>) {
809813
omp.terminator
810814
}
811815

0 commit comments

Comments
 (0)