Skip to content

[Flang][MLIR][OpenMP] Improve use_device_* handling #137198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2025

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Apr 24, 2025

This patch updates MLIR op verifiers for operations taking arguments that must always be defined by an omp.map.info operation to check this requirement.

It also modifies Flang lowering for use_device_{addr, ptr}, as well as the custom MLIR printer and parser for these clauses, to support initializing it to OMP_MAP_RETURN_PARAM and represent this in the MLIR representation as return_param. This internal mapping flag is what eventually is used for variables passed via these clauses into the target region when translating to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in the current representation.

@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2025

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: Sergio Afonso (skatrak)

Changes

This patch updates MLIR op verifiers for operations taking arguments that must always be defined by an omp.map.info operation to check this requirement.

It also modifies Flang lowering for use_device_{addr, ptr}, as well as the custom MLIR printer and parser for these clauses, to support initializing it to OMP_MAP_RETURN_PARAM and represent this in the MLIR representation as return_param. This internal mapping flag is what eventually is used for variables passed via these clauses into the target region when translating to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in the current representation.


Full diff: https://github.com/llvm/llvm-project/pull/137198.diff

6 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+2-4)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+3-1)
  • (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+3-2)
  • (modified) flang/test/Lower/OpenMP/target.f90 (+1-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+39-8)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+7-3)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index cce6dc32bad4b..d03020707ef99 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1340,8 +1340,7 @@ bool ClauseProcessor::processUseDeviceAddr(
           const parser::CharBlock &source) {
         mlir::Location location = converter.genLocation(source);
         llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.useDeviceAddrVars,
                           useDeviceSyms);
@@ -1362,8 +1361,7 @@ bool ClauseProcessor::processUseDevicePtr(
           const parser::CharBlock &source) {
         mlir::Location location = converter.genLocation(source);
         llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.useDevicePtrVars,
                           useDeviceSyms);
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 3f4cfb8c11a9d..bdeed8f6d213e 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -398,7 +398,7 @@ mlir::Value createParentSymAndGenIntermediateMaps(
               interimBounds, treatIndexAsSection);
         }
 
-        // Remove all map TO, FROM and TOFROM bits, from the intermediate
+        // Remove all map TO, FROM and RETURN_PARAM bits, from the intermediate
         // allocatable maps, we simply wish to alloc or release them. It may be
         // safer to just pass OMP_MAP_NONE as the map type, but we may still
         // need some of the other map types the mapped member utilises, so for
@@ -406,6 +406,8 @@ mlir::Value createParentSymAndGenIntermediateMaps(
         llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits;
         interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
         interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+        interimMapType &=
+            ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
 
         // Create a map for the intermediate member and insert it and it's
         // indices into the parentMemberIndices list to track it.
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8019ecf7f6a05..b13921f822b4d 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -423,14 +423,15 @@ func.func @_QPopenmp_target_data_region() {
 
 func.func @_QPomp_target_data_empty() {
   %0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
-  omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
+  %1 = omp.map.info var_ptr(%0 : !fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>) map_clauses(return_param) capture(ByRef) -> !fir.ref<!fir.array<1024xi32>> {name = ""}
+  omp.target_data use_device_addr(%1 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
     omp.terminator
   }
   return
 }
 
 // CHECK-LABEL:   llvm.func @_QPomp_target_data_empty
-// CHECK: omp.target_data   use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
+// CHECK: omp.target_data   use_device_addr(%{{.*}} -> %{{.*}} : !llvm.ptr) {
 // CHECK: }
 
 // -----
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 4815e6564fc7e..f04aacc63fc2b 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -544,7 +544,7 @@ subroutine omp_target_device_addr
    !CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
    !CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
-   !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
+   !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(return_param) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
    !CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
    !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
    !$omp target data map(tofrom: a) use_device_addr(a)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index dd701da507fc6..a81f9a63a8ebb 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1520,6 +1520,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (mapTypeMod == "delete")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
 
+    if (mapTypeMod == "return_param")
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+
     return success();
   };
 
@@ -1582,6 +1585,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
     emitAllocRelease = false;
     mapTypeStrs.push_back("delete");
   }
+  if (mapTypeToBitFlag(
+          mapTypeBits,
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
+    emitAllocRelease = false;
+    mapTypeStrs.push_back("return_param");
+  }
   if (emitAllocRelease)
     mapTypeStrs.push_back("exit_release_or_enter_alloc");
 
@@ -1776,6 +1785,17 @@ static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
 // MapInfoOp
 //===----------------------------------------------------------------------===//
 
+static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
+                                              StringRef clauseName,
+                                              OperandRange vars) {
+  for (Value var : vars)
+    if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
+      return op->emitOpError()
+             << "'" << clauseName
+             << "' arguments must be defined by 'omp.map.info' ops";
+  return success();
+}
+
 LogicalResult MapInfoOp::verify() {
   if (getMapperId() &&
       !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
@@ -1783,6 +1803,9 @@ LogicalResult MapInfoOp::verify() {
     return emitError("invalid mapper id");
   }
 
+  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
+    return failure();
+
   return success();
 }
 
@@ -1804,6 +1827,15 @@ LogicalResult TargetDataOp::verify() {
                        "At least one of map, use_device_ptr_vars, or "
                        "use_device_addr_vars operand must be present");
   }
+
+  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
+                                      getUseDevicePtrVars())))
+    return failure();
+
+  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
+                                      getUseDeviceAddrVars())))
+    return failure();
+
   return verifyMapClause(*this, getMapVars());
 }
 
@@ -1888,16 +1920,15 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult TargetOp::verify() {
-  LogicalResult verifyDependVars =
-      verifyDependVarList(*this, getDependKinds(), getDependVars());
-
-  if (failed(verifyDependVars))
-    return verifyDependVars;
+  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
+    return failure();
 
-  LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
+  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
+                                      getHasDeviceAddrVars())))
+    return failure();
 
-  if (failed(verifyMapVars))
-    return verifyMapVars;
+  if (failed(verifyMapClause(*this, getMapVars())))
+    return failure();
 
   return verifyPrivateVarsMapping(*this);
 }
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d5e2bfa5d3949..3cecc2188aabd 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -802,10 +802,14 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref<i
     %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
     omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref<?xi32>){}
 
-    // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
-    // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref<i32>)
+    // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
+    // CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>)   map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
+    // CHECK: %[[DEV_PTR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<i32>, tensor<i32>)   map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
+    // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[DEV_ADDR]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[DEV_PTR]] -> %{{.*}} : memref<i32>)
     %mapv2 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
-    omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addr -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptr -> %arg1 : memref<i32>) {
+    %device_addrv1 = omp.map.info var_ptr(%device_addr : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
+    %device_ptrv1 = omp.map.info var_ptr(%device_ptr : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
+    omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addrv1 -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptrv1 -> %arg1 : memref<i32>) {
       omp.terminator
     }
 

Copy link
Member

@TIFitis TIFitis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit added. LGTM otherwise :)

This patch updates MLIR op verifiers for operations taking arguments that must
always be defined by an `omp.map.info` operation to check this requirement.

It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the
custom MLIR printer and parser for these clauses, to support initializing it to
`OMP_MAP_RETURN_PARAM` and represent this in the MLIR representation as
`return_param`. This internal mapping flag is what eventually is used for
variables passed via these clauses into the target region when translating to
LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in
the current representation.
@skatrak skatrak force-pushed the users/skatrak/map-rework-01-use-device branch from beb8430 to 59bc1c9 Compare May 15, 2025 10:23
@skatrak skatrak merged commit 30b0946 into main May 15, 2025
11 checks passed
@skatrak skatrak deleted the users/skatrak/map-rework-01-use-device branch May 15, 2025 11:28
TIFitis pushed a commit to TIFitis/llvm-project that referenced this pull request May 19, 2025
This patch updates MLIR op verifiers for operations taking arguments
that must always be defined by an `omp.map.info` operation to check this
requirement.

It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as
the custom MLIR printer and parser for these clauses, to support
initializing it to `OMP_MAP_RETURN_PARAM` and represent this in the MLIR
representation as `return_param`. This internal mapping flag is what
eventually is used for variables passed via these clauses into the
target region when translating to LLVM IR, so making it explicit in
Flang and MLIR removes an inconsistency in the current representation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants