[Flang][MLIR][OpenMP] Improve use_device_* handling #137198
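@llvm/pr-subscribers-mlir-openmp @llvm/pr-subscribers-flang-fir-hlfir

Author: Sergio Afonso (skatrak)

Changes

This patch updates MLIR op verifiers for operations taking arguments that must always be defined by an `omp.map.info` operation to check this requirement.

It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the custom MLIR printer and parser for these clauses, to support initializing the map type to `OMP_MAP_RETURN_PARAM` and to represent it in MLIR as `return_param`. This internal mapping flag is what is eventually used for variables passed via these clauses into the target region when translating to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in the current representation.

Full diff: https://github.com/llvm/llvm-project/pull/137198.diff

6 Files Affected: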
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index cce6dc32bad4b..d03020707ef99 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1340,8 +1340,7 @@ bool ClauseProcessor::processUseDeviceAddr(
const parser::CharBlock &source) {
mlir::Location location = converter.genLocation(source);
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
parentMemberIndices, result.useDeviceAddrVars,
useDeviceSyms);
@@ -1362,8 +1361,7 @@ bool ClauseProcessor::processUseDevicePtr(
const parser::CharBlock &source) {
mlir::Location location = converter.genLocation(source);
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
parentMemberIndices, result.useDevicePtrVars,
useDeviceSyms);
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 3f4cfb8c11a9d..bdeed8f6d213e 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -398,7 +398,7 @@ mlir::Value createParentSymAndGenIntermediateMaps(
interimBounds, treatIndexAsSection);
}
- // Remove all map TO, FROM and TOFROM bits, from the intermediate
+ // Remove all map TO, FROM and RETURN_PARAM bits, from the intermediate
// allocatable maps, we simply wish to alloc or release them. It may be
// safer to just pass OMP_MAP_NONE as the map type, but we may still
// need some of the other map types the mapped member utilises, so for
@@ -406,6 +406,8 @@ mlir::Value createParentSymAndGenIntermediateMaps(
llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits;
interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+ interimMapType &=
+ ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
// Create a map for the intermediate member and insert it and it's
// indices into the parentMemberIndices list to track it.
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8019ecf7f6a05..b13921f822b4d 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -423,14 +423,15 @@ func.func @_QPopenmp_target_data_region() {
func.func @_QPomp_target_data_empty() {
%0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
- omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
+ %1 = omp.map.info var_ptr(%0 : !fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>) map_clauses(return_param) capture(ByRef) -> !fir.ref<!fir.array<1024xi32>> {name = ""}
+ omp.target_data use_device_addr(%1 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
omp.terminator
}
return
}
// CHECK-LABEL: llvm.func @_QPomp_target_data_empty
-// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
+// CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}} : !llvm.ptr) {
// CHECK: }
// -----
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 4815e6564fc7e..f04aacc63fc2b 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -544,7 +544,7 @@ subroutine omp_target_device_addr
!CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
!CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
- !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
+ !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(return_param) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
!$omp target data map(tofrom: a) use_device_addr(a)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index dd701da507fc6..a81f9a63a8ebb 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1520,6 +1520,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
if (mapTypeMod == "delete")
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+ if (mapTypeMod == "return_param")
+ mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+
return success();
};
@@ -1582,6 +1585,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
emitAllocRelease = false;
mapTypeStrs.push_back("delete");
}
+ if (mapTypeToBitFlag(
+ mapTypeBits,
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
+ emitAllocRelease = false;
+ mapTypeStrs.push_back("return_param");
+ }
if (emitAllocRelease)
mapTypeStrs.push_back("exit_release_or_enter_alloc");
@@ -1776,6 +1785,17 @@ static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
// MapInfoOp
//===----------------------------------------------------------------------===//
+static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
+ StringRef clauseName,
+ OperandRange vars) {
+ for (Value var : vars)
+ if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
+ return op->emitOpError()
+ << "'" << clauseName
+ << "' arguments must be defined by 'omp.map.info' ops";
+ return success();
+}
+
LogicalResult MapInfoOp::verify() {
if (getMapperId() &&
!SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
@@ -1783,6 +1803,9 @@ LogicalResult MapInfoOp::verify() {
return emitError("invalid mapper id");
}
+ if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
+ return failure();
+
return success();
}
@@ -1804,6 +1827,15 @@ LogicalResult TargetDataOp::verify() {
"At least one of map, use_device_ptr_vars, or "
"use_device_addr_vars operand must be present");
}
+
+ if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
+ getUseDevicePtrVars())))
+ return failure();
+
+ if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
+ getUseDeviceAddrVars())))
+ return failure();
+
return verifyMapClause(*this, getMapVars());
}
@@ -1888,16 +1920,15 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult TargetOp::verify() {
- LogicalResult verifyDependVars =
- verifyDependVarList(*this, getDependKinds(), getDependVars());
-
- if (failed(verifyDependVars))
- return verifyDependVars;
+ if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
+ return failure();
- LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
+ if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
+ getHasDeviceAddrVars())))
+ return failure();
- if (failed(verifyMapVars))
- return verifyMapVars;
+ if (failed(verifyMapClause(*this, getMapVars())))
+ return failure();
return verifyPrivateVarsMapping(*this);
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d5e2bfa5d3949..3cecc2188aabd 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -802,10 +802,14 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref<i
%mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref<?xi32>){}
- // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
- // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref<i32>)
+ // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
+ // CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
+ // CHECK: %[[DEV_PTR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
+ // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[DEV_ADDR]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[DEV_PTR]] -> %{{.*}} : memref<i32>)
%mapv2 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
- omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addr -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptr -> %arg1 : memref<i32>) {
+ %device_addrv1 = omp.map.info var_ptr(%device_addr : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
+ %device_ptrv1 = omp.map.info var_ptr(%device_ptr : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
+ omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addrv1 -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptrv1 -> %arg1 : memref<i32>) {
omp.terminator
}
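The verifier helper added to `OpenMPDialect.cpp` above can be illustrated without MLIR. The following is a minimal standalone sketch, not the patch's code: `Op`, `Value`, and the `std::optional<std::string>` return type are hypothetical stand-ins for `mlir::Operation`, `mlir::Value`, and `LogicalResult`, and the real diagnostic goes through `emitOpError()`. A null defining op models a block argument, which is why the patch uses `isa_and_present` and why the old tests that passed function arguments directly had to be updated.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Operation: only the op name matters here.
struct Op {
  std::string name;
};

// Hypothetical stand-in for mlir::Value. A null definingOp models a block
// argument, which has no defining operation.
struct Value {
  const Op *definingOp;
};

// Returns an error message if any value in `vars` is not produced by an
// 'omp.map.info' op, and std::nullopt if all of them are.
std::optional<std::string>
verifyMapInfoDefinedArgs(const std::string &clauseName,
                         const std::vector<Value> &vars) {
  for (const Value &var : vars)
    if (!var.definingOp || var.definingOp->name != "omp.map.info")
      return "'" + clauseName +
             "' arguments must be defined by 'omp.map.info' ops";
  return std::nullopt;
}
```

In this model, a `use_device_addr` operand defined by `fir.alloca`, or one that is a plain block argument, is rejected, while an operand defined by `omp.map.info` passes.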
Force-pushed from de00b77 to beb8430.
Small nit added. LGTM otherwise :)
This patch updates MLIR op verifiers for operations taking arguments that must always be defined by an `omp.map.info` operation to check this requirement. It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the custom MLIR printer and parser for these clauses, to support initializing the map type to `OMP_MAP_RETURN_PARAM` and to represent it in MLIR as `return_param`. This internal mapping flag is what is eventually used for variables passed via these clauses into the target region when translating to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in the current representation.
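The flag handling described above can be sketched in standalone C++. This is a simplified illustration, not the code from the patch: `MapFlag` is a hypothetical stand-in for `llvm::omp::OpenMPOffloadMappingFlags`, its numeric values are assumptions made for the sketch, and `mapClauseNames` and `interimMapType` are illustrative names. The real printer also handles modifiers such as `always`, `close`, and `present`, which are omitted here.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for llvm::omp::OpenMPOffloadMappingFlags. The numeric
// values are assumptions for this sketch, not taken from the LLVM headers.
enum MapFlag : std::uint64_t {
  MAP_NONE = 0x0,
  MAP_TO = 0x1,
  MAP_FROM = 0x2,
  MAP_DELETE = 0x8,
  MAP_RETURN_PARAM = 0x40,
};

// True when `flag` is set in `bits`.
bool hasFlag(std::uint64_t bits, MapFlag flag) {
  return (bits & static_cast<std::uint64_t>(flag)) != 0;
}

// Simplified model of the printer: translate map type bits into the names
// that appear inside map_clauses(...). If no map type is set, fall back to
// exit_release_or_enter_alloc.
std::vector<std::string> mapClauseNames(std::uint64_t bits) {
  std::vector<std::string> names;
  bool emitAllocRelease = true;
  const bool to = hasFlag(bits, MAP_TO);
  const bool from = hasFlag(bits, MAP_FROM);
  if (to || from) {
    emitAllocRelease = false;
    if (to && from)
      names.push_back("tofrom");
    else if (to)
      names.push_back("to");
    else
      names.push_back("from");
  }
  if (hasFlag(bits, MAP_DELETE)) {
    emitAllocRelease = false;
    names.push_back("delete");
  }
  if (hasFlag(bits, MAP_RETURN_PARAM)) {
    emitAllocRelease = false;
    names.push_back("return_param");
  }
  if (emitAllocRelease)
    names.push_back("exit_release_or_enter_alloc");
  return names;
}

// Model of the intermediate-map masking in Utils.cpp: clear TO, FROM and
// RETURN_PARAM, and keep every other bit of the member's map type.
std::uint64_t interimMapType(std::uint64_t bits) {
  const std::uint64_t cleared = MAP_TO | MAP_FROM | MAP_RETURN_PARAM;
  return bits & ~cleared;
}
```

Under these assumptions, a `use_device_addr` operand whose map type is only `MAP_RETURN_PARAM` prints as `map_clauses(return_param)`, matching the updated tests, and an intermediate allocatable map derived from it keeps none of the TO, FROM, or RETURN_PARAM bits.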
Force-pushed from beb8430 to 59bc1c9.