Skip to content

[mlir][OpenMP] add attribute for privatization barrier #140089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

tblah
Copy link
Contributor

@tblah tblah commented May 15, 2025

A barrier is needed at the end of initialization/copying of private variables if any of those variables is lastprivate. This ensures that all firstprivate variables receive the original value of the variable before the lastprivate clause overwrites it.

Previously this barrier was added by the flang fontend, but there is not a reliable way to put the barrier in the correct place for delayed privatization, and the OpenMP dialect could some day have other users. It is important that there are safe ways to use the constructs available in the dialect.

lastprivate is currently not modelled in the OpenMP dialect, and so there is no way to reliably determine whether there were lastprivate variables. Therefore the frontend will have to provide this information through this new attribute.

Part of a series of patches to fix
#136357

A barrier is needed at the end of initialization/copying of private
variables if any of those variables is lastprivate. This ensures that
all firstprivate variables receive the original value of the variable
before the lastprivate clause overwrites it.

Previously this barrier was added by the flang fontend, but there is not
a reliable way to put the barrier in the correct place for delayed
privatization, and the OpenMP dialect could some day have other users.
It is important that there are safe ways to use the constructs available
in the dialect.

lastprivate is currently not modelled in the OpenMP dialect, and so
there is no way to reliably determine whether there were lastprivate
variables. Therefore the frontend will have to provide this information
through this new attribute.

Part of a series of patches to fix
#136357
@llvmbot
Copy link
Member

llvmbot commented May 15, 2025

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir-openmp

Author: Tom Eccles (tblah)

Changes

A barrier is needed at the end of initialization/copying of private variables if any of those variables is lastprivate. This ensures that all firstprivate variables receive the original value of the variable before the lastprivate clause overwrites it.

Previously this barrier was added by the flang fontend, but there is not a reliable way to put the barrier in the correct place for delayed privatization, and the OpenMP dialect could some day have other users. It is important that there are safe ways to use the constructs available in the dialect.

lastprivate is currently not modelled in the OpenMP dialect, and so there is no way to reliably determine whether there were lastprivate variables. Therefore the frontend will have to provide this information through this new attribute.

Part of a series of patches to fix
#136357


Patch is 33.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140089.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+4-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+19-18)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+118-78)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+17)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index f8e880ea43b75..16c14ef085d6d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1102,7 +1102,10 @@ class OpenMP_PrivateClauseSkip<
 
   let arguments = (ins
     Variadic<AnyType>:$private_vars,
-    OptionalAttr<SymbolRefArrayAttr>:$private_syms
+    OptionalAttr<SymbolRefArrayAttr>:$private_syms,
+    // Set this attribute if a barrier is needed after initialization and
+    // copying of lastprivate variables.
+    UnitAttr:$private_needs_barrier
   );
 
   // TODO: Add description.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5a79fbf77a268..036c6a6e350a8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -213,8 +213,8 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -258,8 +258,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -317,8 +317,8 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -350,7 +350,7 @@ def SingleOp : OpenMP_Op<"single", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateRegion>($region, $private_vars, type($private_vars),
-        $private_syms) attr-dict
+        $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -505,8 +505,8 @@ def LoopOp : OpenMP_Op<"loop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let builders = [
@@ -557,8 +557,8 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -611,8 +611,8 @@ def SimdOp : OpenMP_Op<"simd", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -690,7 +690,7 @@ def DistributeOp : OpenMP_Op<"distribute", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateRegion>($region, $private_vars, type($private_vars),
-        $private_syms) attr-dict
+        $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -740,7 +740,7 @@ def TaskOp
     custom<InReductionPrivateRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms) attr-dict
+        type($private_vars), $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -816,8 +816,9 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     custom<InReductionPrivateReductionRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
-        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
+        type($private_vars), $private_syms, $private_needs_barrier,
+        $reduction_mod, $reduction_vars, type($reduction_vars),
+        $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let extraClassDeclaration = [{
@@ -1324,7 +1325,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
         $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
         type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
         $map_vars, type($map_vars), $private_vars, type($private_vars),
-        $private_syms, $private_maps) attr-dict
+        $private_syms, $private_needs_barrier, $private_maps) attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 233739e1d6d91..71786e856c6db 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -450,6 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
+        /* private_needs_barrier = */ nullptr,
         /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
         /* reduction_mod = */ nullptr,
         /* reduction_vars = */ llvm::SmallVector<Value>{},
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bf7aaa46db11..57a54f21fe9de 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -581,11 +581,14 @@ struct PrivateParseArgs {
   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   llvm::SmallVectorImpl<Type> &types;
   ArrayAttr &syms;
+  UnitAttr &needsBarrier;
   DenseI64ArrayAttr *mapIndices;
   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                    SmallVectorImpl<Type> &types, ArrayAttr &syms,
+                   UnitAttr &needsBarrier,
                    DenseI64ArrayAttr *mapIndices = nullptr)
-      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
+      : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
+        mapIndices(mapIndices) {}
 };
 
 struct ReductionParseArgs {
@@ -613,6 +616,10 @@ struct AllRegionParseArgs {
 };
 } // namespace
 
+static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
+  return "private_barrier";
+}
+
 static ParseResult parseClauseWithRegionArgs(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
@@ -620,7 +627,8 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
     ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
     DenseBoolArrayAttr *byref = nullptr,
-    ReductionModifierAttr *modifier = nullptr) {
+    ReductionModifierAttr *modifier = nullptr,
+    UnitAttr *needsBarrier = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
   SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
@@ -688,6 +696,12 @@ static ParseResult parseClauseWithRegionArgs(
   if (parser.parseRParen())
     return failure();
 
+  if (needsBarrier) {
+    if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling())
+            .succeeded())
+      *needsBarrier = mlir::UnitAttr::get(parser.getContext());
+  }
+
   auto *argsBegin = regionPrivateArgs.begin();
   MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
                                argsBegin + regionArgOffset + types.size());
@@ -735,7 +749,8 @@ static ParseResult parseBlockArgClause(
 
     if (failed(parseClauseWithRegionArgs(
             parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
-            &privateArgs->syms, privateArgs->mapIndices)))
+            &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
+            /*modifier=*/nullptr, &privateArgs->needsBarrier)))
       return failure();
   }
   return success();
@@ -824,7 +839,7 @@ static ParseResult parseTargetOpRegion(
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    DenseI64ArrayAttr &privateMaps) {
+    UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
@@ -832,7 +847,7 @@ static ParseResult parseTargetOpRegion(
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
-                           &privateMaps);
+                           privateNeedsBarrier, &privateMaps);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -842,11 +857,13 @@ static ParseResult parseInReductionPrivateRegion(
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    UnitAttr &privateNeedsBarrier) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -857,14 +874,15 @@ static ParseResult parseInReductionPrivateReductionRegion(
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    ReductionModifierAttr &reductionMod,
+    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
@@ -873,9 +891,11 @@ static ParseResult parseInReductionPrivateReductionRegion(
 static ParseResult parsePrivateRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    UnitAttr &privateNeedsBarrier) {
   AllRegionParseArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -883,12 +903,13 @@ static ParseResult parsePrivateReductionRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    ReductionModifierAttr &reductionMod,
+    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
@@ -931,10 +952,12 @@ struct PrivatePrintArgs {
   ValueRange vars;
   TypeRange types;
   ArrayAttr syms;
+  UnitAttr needsBarrier;
   DenseI64ArrayAttr mapIndices;
   PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
-                   DenseI64ArrayAttr mapIndices)
-      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
+                   UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
+      : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
+        mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
   ValueRange vars;
@@ -964,7 +987,7 @@ static void printClauseWithRegionArgs(
     ValueRange argsSubrange, ValueRange operands, TypeRange types,
     ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
     DenseBoolArrayAttr byref = nullptr,
-    ReductionModifierAttr modifier = nullptr) {
+    ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
   if (argsSubrange.empty())
     return;
 
@@ -1006,6 +1029,9 @@ static void printClauseWithRegionArgs(
   p << " : ";
   llvm::interleaveComma(types, p);
   p << ") ";
+
+  if (needsBarrier)
+    p << getPrivateNeedsBarrierSpelling() << " ";
 }
 
 static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
@@ -1020,9 +1046,10 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                 StringRef clauseName, ValueRange argsSubrange,
                                 std::optional<PrivatePrintArgs> privateArgs) {
   if (privateArgs)
-    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
-                              privateArgs->vars, privateArgs->types,
-                              privateArgs->syms, privateArgs->mapIndices);
+    printClauseWithRegionArgs(
+        p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
+        privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
+        /*modifier=*/nullptr, privateArgs->needsBarrier);
 }
 
 static void
@@ -1068,23 +1095,23 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 // These parseXyz functions correspond to the custom<Xyz> definitions
 // in the .td file(s).
-static void
-printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region,
-                    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
-                    ValueRange hostEvalVars, TypeRange hostEvalTypes,
-                    ValueRange inReductionVars, TypeRange inReductionTypes,
-                    DenseBoolArrayAttr inReductionByref,
-                    ArrayAttr inReductionSyms, ValueRange mapVars,
-                    TypeRange mapTypes, ValueRange privateVars,
-                    TypeRange privateTypes, ArrayAttr privateSyms,
-                    DenseI64ArrayAttr privateMaps) {
+static void printTargetOpRegion(
+    OpAsmPrinter &p, Operation *op, Region &region,
+    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
+    ValueRange hostEvalVars, TypeRange hostEvalTypes,
+    ValueRange inReductionVars, TypeRange inReductionTypes,
+    DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
+    ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
+    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
+    DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1092,11 +1119,12 @@ static void printInReductionPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms) {
+    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   printBlockArgRegion(p, op, region, args);
 }
@@ -1105,13 +1133,15 @@ static void printInReductionPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
-    ValueRange reductionVars, TypeRange reductionTypes,
-    DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
+    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
+    ReductionModifierAttr reductionMod, ValueRange reductionVars,
+    TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
+    ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, reductionMod);
@@ -1120,21 +1150,24 @@ static void printInReductionPrivateReductionRegion(
 
 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange privateVars, TypeRange privateTypes,
-                               ArrayAttr privateSyms) {
+                               ArrayAttr privateSyms,
+                               UnitAttr privateNeedsBarrier) {
   AllRegionPrintArgs ar...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 15, 2025

@llvm/pr-subscribers-mlir

Author: Tom Eccles (tblah)

Changes

A barrier is needed at the end of initialization/copying of private variables if any of those variables is lastprivate. This ensures that all firstprivate variables receive the original value of the variable before the lastprivate clause overwrites it.

Previously this barrier was added by the flang fontend, but there is not a reliable way to put the barrier in the correct place for delayed privatization, and the OpenMP dialect could some day have other users. It is important that there are safe ways to use the constructs available in the dialect.

lastprivate is currently not modelled in the OpenMP dialect, and so there is no way to reliably determine whether there were lastprivate variables. Therefore the frontend will have to provide this information through this new attribute.

Part of a series of patches to fix
#136357


Patch is 33.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140089.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+4-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+19-18)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+118-78)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+17)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index f8e880ea43b75..16c14ef085d6d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1102,7 +1102,10 @@ class OpenMP_PrivateClauseSkip<
 
   let arguments = (ins
     Variadic<AnyType>:$private_vars,
-    OptionalAttr<SymbolRefArrayAttr>:$private_syms
+    OptionalAttr<SymbolRefArrayAttr>:$private_syms,
+    // Set this attribute if a barrier is needed after initialization and
+    // copying of lastprivate variables.
+    UnitAttr:$private_needs_barrier
   );
 
   // TODO: Add description.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5a79fbf77a268..036c6a6e350a8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -213,8 +213,8 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -258,8 +258,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -317,8 +317,8 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -350,7 +350,7 @@ def SingleOp : OpenMP_Op<"single", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateRegion>($region, $private_vars, type($private_vars),
-        $private_syms) attr-dict
+        $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -505,8 +505,8 @@ def LoopOp : OpenMP_Op<"loop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let builders = [
@@ -557,8 +557,8 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -611,8 +611,8 @@ def SimdOp : OpenMP_Op<"simd", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -690,7 +690,7 @@ def DistributeOp : OpenMP_Op<"distribute", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateRegion>($region, $private_vars, type($private_vars),
-        $private_syms) attr-dict
+        $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -740,7 +740,7 @@ def TaskOp
     custom<InReductionPrivateRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms) attr-dict
+        type($private_vars), $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -816,8 +816,9 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     custom<InReductionPrivateReductionRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
-        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
+        type($private_vars), $private_syms, $private_needs_barrier,
+        $reduction_mod, $reduction_vars, type($reduction_vars),
+        $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let extraClassDeclaration = [{
@@ -1324,7 +1325,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
         $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
         type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
         $map_vars, type($map_vars), $private_vars, type($private_vars),
-        $private_syms, $private_maps) attr-dict
+        $private_syms, $private_needs_barrier, $private_maps) attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 233739e1d6d91..71786e856c6db 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -450,6 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
+        /* private_needs_barrier = */ nullptr,
         /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
         /* reduction_mod = */ nullptr,
         /* reduction_vars = */ llvm::SmallVector<Value>{},
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bf7aaa46db11..57a54f21fe9de 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -581,11 +581,14 @@ struct PrivateParseArgs {
   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   llvm::SmallVectorImpl<Type> &types;
   ArrayAttr &syms;
+  UnitAttr &needsBarrier;
   DenseI64ArrayAttr *mapIndices;
   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                    SmallVectorImpl<Type> &types, ArrayAttr &syms,
+                   UnitAttr &needsBarrier,
                    DenseI64ArrayAttr *mapIndices = nullptr)
-      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
+      : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
+        mapIndices(mapIndices) {}
 };
 
 struct ReductionParseArgs {
@@ -613,6 +616,10 @@ struct AllRegionParseArgs {
 };
 } // namespace
 
+static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
+  return "private_barrier";
+}
+
 static ParseResult parseClauseWithRegionArgs(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
@@ -620,7 +627,8 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
     ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
     DenseBoolArrayAttr *byref = nullptr,
-    ReductionModifierAttr *modifier = nullptr) {
+    ReductionModifierAttr *modifier = nullptr,
+    UnitAttr *needsBarrier = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
   SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
@@ -688,6 +696,12 @@ static ParseResult parseClauseWithRegionArgs(
   if (parser.parseRParen())
     return failure();
 
+  if (needsBarrier) {
+    if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling())
+            .succeeded())
+      *needsBarrier = mlir::UnitAttr::get(parser.getContext());
+  }
+
   auto *argsBegin = regionPrivateArgs.begin();
   MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
                                argsBegin + regionArgOffset + types.size());
@@ -735,7 +749,8 @@ static ParseResult parseBlockArgClause(
 
     if (failed(parseClauseWithRegionArgs(
             parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
-            &privateArgs->syms, privateArgs->mapIndices)))
+            &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
+            /*modifier=*/nullptr, &privateArgs->needsBarrier)))
       return failure();
   }
   return success();
@@ -824,7 +839,7 @@ static ParseResult parseTargetOpRegion(
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    DenseI64ArrayAttr &privateMaps) {
+    UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
@@ -832,7 +847,7 @@ static ParseResult parseTargetOpRegion(
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
-                           &privateMaps);
+                           privateNeedsBarrier, &privateMaps);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -842,11 +857,13 @@ static ParseResult parseInReductionPrivateRegion(
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    UnitAttr &privateNeedsBarrier) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -857,14 +874,15 @@ static ParseResult parseInReductionPrivateReductionRegion(
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    ReductionModifierAttr &reductionMod,
+    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
@@ -873,9 +891,11 @@ static ParseResult parseInReductionPrivateReductionRegion(
 static ParseResult parsePrivateRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    UnitAttr &privateNeedsBarrier) {
   AllRegionParseArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -883,12 +903,13 @@ static ParseResult parsePrivateReductionRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    ReductionModifierAttr &reductionMod,
+    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
@@ -931,10 +952,12 @@ struct PrivatePrintArgs {
   ValueRange vars;
   TypeRange types;
   ArrayAttr syms;
+  UnitAttr needsBarrier;
   DenseI64ArrayAttr mapIndices;
   PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
-                   DenseI64ArrayAttr mapIndices)
-      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
+                   UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
+      : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
+        mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
   ValueRange vars;
@@ -964,7 +987,7 @@ static void printClauseWithRegionArgs(
     ValueRange argsSubrange, ValueRange operands, TypeRange types,
     ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
     DenseBoolArrayAttr byref = nullptr,
-    ReductionModifierAttr modifier = nullptr) {
+    ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
   if (argsSubrange.empty())
     return;
 
@@ -1006,6 +1029,9 @@ static void printClauseWithRegionArgs(
   p << " : ";
   llvm::interleaveComma(types, p);
   p << ") ";
+
+  if (needsBarrier)
+    p << getPrivateNeedsBarrierSpelling() << " ";
 }
 
 static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
@@ -1020,9 +1046,10 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                 StringRef clauseName, ValueRange argsSubrange,
                                 std::optional<PrivatePrintArgs> privateArgs) {
   if (privateArgs)
-    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
-                              privateArgs->vars, privateArgs->types,
-                              privateArgs->syms, privateArgs->mapIndices);
+    printClauseWithRegionArgs(
+        p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
+        privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
+        /*modifier=*/nullptr, privateArgs->needsBarrier);
 }
 
 static void
@@ -1068,23 +1095,23 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 // These parseXyz functions correspond to the custom<Xyz> definitions
 // in the .td file(s).
-static void
-printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region,
-                    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
-                    ValueRange hostEvalVars, TypeRange hostEvalTypes,
-                    ValueRange inReductionVars, TypeRange inReductionTypes,
-                    DenseBoolArrayAttr inReductionByref,
-                    ArrayAttr inReductionSyms, ValueRange mapVars,
-                    TypeRange mapTypes, ValueRange privateVars,
-                    TypeRange privateTypes, ArrayAttr privateSyms,
-                    DenseI64ArrayAttr privateMaps) {
+static void printTargetOpRegion(
+    OpAsmPrinter &p, Operation *op, Region &region,
+    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
+    ValueRange hostEvalVars, TypeRange hostEvalTypes,
+    ValueRange inReductionVars, TypeRange inReductionTypes,
+    DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
+    ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
+    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
+    DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1092,11 +1119,12 @@ static void printInReductionPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms) {
+    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   printBlockArgRegion(p, op, region, args);
 }
@@ -1105,13 +1133,15 @@ static void printInReductionPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
-    ValueRange reductionVars, TypeRange reductionTypes,
-    DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
+    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
+    ReductionModifierAttr reductionMod, ValueRange reductionVars,
+    TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
+    ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, reductionMod);
@@ -1120,21 +1150,24 @@ static void printInReductionPrivateReductionRegion(
 
 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange privateVars, TypeRange privateTypes,
-                               ArrayAttr privateSyms) {
+                               ArrayAttr privateSyms,
+                               UnitAttr privateNeedsBarrier) {
   AllRegionPrintArgs ar...
[truncated]

Copy link
Contributor

@NimishMishra NimishMishra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Some minor nits.

OptionalAttr<SymbolRefArrayAttr>:$private_syms,
// Set this attribute if a barrier is needed after initialization and
// copying of lastprivate variables.
UnitAttr:$private_needs_barrier
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit: This attribute is already within private clause. Is private_needs_barrier a better name? In OpenMPDialect.cpp too, needsBarrier is used, which is self-sufficient I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The members for each clause are all combined into a structure containing elements for all arguments to the operation. Therefore there's a convention to prefix clause data with the name of the clause to avoid name collisions (e.g. if a needs_barrier attribute was added to REDUCTION).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name of the clause to avoid name collisions (e.g. if a needs_barrier attribute was added to REDUCTION).

Oh okay thanks. Makes sense

@@ -824,15 +839,15 @@ static ParseResult parseTargetOpRegion(
SmallVectorImpl<Type> &mapTypes,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
DenseI64ArrayAttr &privateMaps) {
UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Maybe we can make this argument consistent. parseClauseWithRegionArgs has needsBarrier; here we have privateNeedsBarrier`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The distinction is generally where this could be confused with future needsBarrier members for other clauses in that context. For example, there is no need for a "private" prefix inside of PrivateParseArgs but there is ambiguity in parseTargetOpRegion because there could one day be needsBarrier members for other clauses on TARGET.

Similar naming choices are used for privateSyms and privateVars. I believe my current usage is consistent with those other variables.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay makes sense. Thanks

// CHECK: omp.parallel private(
// CHECK-SAME: @x.privatizer %[[ARG0]] -> %[[ARG0_PRIV:[^[:space:]]+]],
// CHECK-SAME: @y.privatizer %[[ARG1]] -> %[[ARG1_PRIV:[^[:space:]]+]] : !llvm.ptr, !llvm.ptr)
// CHECK-SAME: private_barrier
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For printing the MLIR, is it better to dump omp.barrier as normal? This barrier and normal omp.barrier are indistinguishable. Should we, from the perspective of MLIR printing, introduce new syntax (private_barrier) or keep omp.barrier ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is nowhere in MLIR to print the omp.barrier that accurately describes its location in generated code, otherwise I could have moved the omp.barrier operation instead of adding this new attribute.

This barrier is inserted between the copy regions of firstprivate arguments and the initialization call for the openmp construct. If the omp.barrier was placed before the construct this would indicate the barrier came before firstprivate copying. If it was placed inside the nested region it would come after the start of the construct (which I tried but it turns out not to work because some constructs won't run on all active threads e.g. a sections construct with fewer sections than the number of available threads).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is nowhere in MLIR to print the omp.barrier that accurately describes its location in generated code, otherwise I could have moved the omp.barrier operation instead of adding this new attribute.

Okay understood. Thanks

@tblah
Copy link
Contributor Author

tblah commented May 16, 2025

Thanks for taking a look Nimish, please let me know whether my explanations are satisfying or if I should go ahead and make some changes.

@NimishMishra
Copy link
Contributor

Thanks for taking a look Nimish, please let me know whether my explanations are satisfying or if I should go ahead and make some changes.

Thanks @tblah. No your explanations make sense; I did not foresee name confusions with future constructs. Let us leave it the way it is then.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants