diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index f8e880ea43b75..16c14ef085d6d 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1102,7 +1102,10 @@ class OpenMP_PrivateClauseSkip< let arguments = (ins Variadic:$private_vars, - OptionalAttr:$private_syms + OptionalAttr:$private_syms, + // Set this attribute if a barrier is needed after initialization and + // copying of lastprivate variables. + UnitAttr:$private_needs_barrier ); // TODO: Add description. diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 5a79fbf77a268..036c6a6e350a8 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -213,8 +213,8 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, - $reduction_syms) attr-dict + $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars, + type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; let hasVerifier = 1; @@ -258,8 +258,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, - $reduction_syms) attr-dict + $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars, + type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; let hasVerifier = 1; @@ -317,8 +317,8 @@ def SectionsOp : OpenMP_Op<"sections", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, - $reduction_syms) attr-dict + $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars, + type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; let hasVerifier = 1; @@ -350,7 +350,7 @@ def SingleOp : OpenMP_Op<"single", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms) attr-dict + $private_syms, $private_needs_barrier) attr-dict }]; let hasVerifier = 1; @@ -505,8 +505,8 @@ def LoopOp : OpenMP_Op<"loop", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, - $reduction_syms) attr-dict + $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars, + type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; let builders = [ @@ -557,8 +557,8 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, - $reduction_syms) attr-dict + $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars, + type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; let hasVerifier = 1; @@ -611,8 +611,8 @@ def SimdOp : OpenMP_Op<"simd", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref, - $reduction_syms) attr-dict + $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars, + type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict }]; let hasVerifier = 1; @@ -690,7 +690,7 @@ def DistributeOp : OpenMP_Op<"distribute", traits = [ let assemblyFormat = clausesAssemblyFormat # [{ custom($region, $private_vars, type($private_vars), - $private_syms) attr-dict + $private_syms, $private_needs_barrier) attr-dict }]; let hasVerifier = 1; @@ -740,7 +740,7 @@ def TaskOp custom( $region, $in_reduction_vars, type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms, $private_vars, - type($private_vars), $private_syms) attr-dict + type($private_vars), $private_syms, $private_needs_barrier) attr-dict }]; let hasVerifier = 1; @@ -816,8 +816,9 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [ custom( $region, $in_reduction_vars, type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms, $private_vars, - type($private_vars), $private_syms, $reduction_mod, $reduction_vars, - type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict + type($private_vars), $private_syms, $private_needs_barrier, + $reduction_mod, $reduction_vars, type($reduction_vars), + $reduction_byref, $reduction_syms) attr-dict }]; let extraClassDeclaration = [{ @@ -1324,7 +1325,7 @@ def TargetOp : OpenMP_Op<"target", traits = [ $host_eval_vars, type($host_eval_vars), $in_reduction_vars, type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars), $private_vars, type($private_vars), - $private_syms, $private_maps) attr-dict + $private_syms, $private_needs_barrier, $private_maps) attr-dict }]; let hasVerifier = 1; diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 233739e1d6d91..71786e856c6db 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -450,6 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern { /* num_threads = */ numThreadsVar, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, + /* private_needs_barrier = */ nullptr, /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{}, /* reduction_mod = */ nullptr, /* reduction_vars = */ llvm::SmallVector{}, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 2bf7aaa46db11..57a54f21fe9de 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -581,11 +581,14 @@ struct PrivateParseArgs { llvm::SmallVectorImpl &vars; llvm::SmallVectorImpl &types; ArrayAttr &syms; + UnitAttr &needsBarrier; DenseI64ArrayAttr *mapIndices; PrivateParseArgs(SmallVectorImpl &vars, SmallVectorImpl &types, ArrayAttr &syms, + UnitAttr &needsBarrier, DenseI64ArrayAttr *mapIndices = nullptr) - : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {} + : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier), + mapIndices(mapIndices) {} }; struct ReductionParseArgs { @@ -613,6 +616,10 @@ struct AllRegionParseArgs { }; } // namespace +static inline constexpr StringRef getPrivateNeedsBarrierSpelling() { + return "private_barrier"; +} + static ParseResult parseClauseWithRegionArgs( OpAsmParser &parser, SmallVectorImpl &operands, @@ -620,7 +627,8 @@ static ParseResult parseClauseWithRegionArgs( SmallVectorImpl ®ionPrivateArgs, ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr, DenseBoolArrayAttr *byref = nullptr, - ReductionModifierAttr *modifier = nullptr) { + ReductionModifierAttr *modifier = nullptr, + UnitAttr *needsBarrier = nullptr) { SmallVector symbolVec; SmallVector mapIndicesVec; SmallVector isByRefVec; @@ -688,6 +696,12 @@ static ParseResult parseClauseWithRegionArgs( if (parser.parseRParen()) return failure(); + if (needsBarrier) { + if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling()) + .succeeded()) + *needsBarrier = mlir::UnitAttr::get(parser.getContext()); + } + auto *argsBegin = regionPrivateArgs.begin(); MutableArrayRef argsSubrange(argsBegin + regionArgOffset, argsBegin + regionArgOffset + types.size()); @@ -735,7 +749,8 @@ static ParseResult parseBlockArgClause( if (failed(parseClauseWithRegionArgs( parser, privateArgs->vars, privateArgs->types, entryBlockArgs, - &privateArgs->syms, privateArgs->mapIndices))) + &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr, + /*modifier=*/nullptr, &privateArgs->needsBarrier))) return failure(); } return success(); @@ -824,7 +839,7 @@ static ParseResult parseTargetOpRegion( SmallVectorImpl &mapTypes, llvm::SmallVectorImpl &privateVars, llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms, - DenseI64ArrayAttr &privateMaps) { + UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) { AllRegionParseArgs args; args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes); args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes); @@ -832,7 +847,7 @@ static ParseResult parseTargetOpRegion( inReductionByref, inReductionSyms); args.mapArgs.emplace(mapVars, mapTypes); args.privateArgs.emplace(privateVars, privateTypes, privateSyms, - &privateMaps); + privateNeedsBarrier, &privateMaps); return parseBlockArgRegion(parser, region, args); } @@ -842,11 +857,13 @@ static ParseResult parseInReductionPrivateRegion( SmallVectorImpl &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl &privateVars, - llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms) { + llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms, + UnitAttr &privateNeedsBarrier) { AllRegionParseArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier); return parseBlockArgRegion(parser, region, args); } @@ -857,14 +874,15 @@ static ParseResult parseInReductionPrivateReductionRegion( DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl &privateVars, llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms, - ReductionModifierAttr &reductionMod, + UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl &reductionVars, SmallVectorImpl &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms) { AllRegionParseArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, reductionSyms, &reductionMod); return parseBlockArgRegion(parser, region, args); @@ -873,9 +891,11 @@ static ParseResult parseInReductionPrivateReductionRegion( static ParseResult parsePrivateRegion( OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl &privateVars, - llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms) { + llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms, + UnitAttr &privateNeedsBarrier) { AllRegionParseArgs args; - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier); return parseBlockArgRegion(parser, region, args); } @@ -883,12 +903,13 @@ static ParseResult parsePrivateReductionRegion( OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl &privateVars, llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms, - ReductionModifierAttr &reductionMod, + UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl &reductionVars, SmallVectorImpl &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms) { AllRegionParseArgs args; - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, reductionSyms, &reductionMod); return parseBlockArgRegion(parser, region, args); @@ -931,10 +952,12 @@ struct PrivatePrintArgs { ValueRange vars; TypeRange types; ArrayAttr syms; + UnitAttr needsBarrier; DenseI64ArrayAttr mapIndices; PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms, - DenseI64ArrayAttr mapIndices) - : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {} + UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices) + : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier), + mapIndices(mapIndices) {} }; struct ReductionPrintArgs { ValueRange vars; @@ -964,7 +987,7 @@ static void printClauseWithRegionArgs( ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr, DenseBoolArrayAttr byref = nullptr, - ReductionModifierAttr modifier = nullptr) { + ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) { if (argsSubrange.empty()) return; @@ -1006,6 +1029,9 @@ static void printClauseWithRegionArgs( p << " : "; llvm::interleaveComma(types, p); p << ") "; + + if (needsBarrier) + p << getPrivateNeedsBarrierSpelling() << " "; } static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, @@ -1020,9 +1046,10 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional privateArgs) { if (privateArgs) - printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, - privateArgs->vars, privateArgs->types, - privateArgs->syms, privateArgs->mapIndices); + printClauseWithRegionArgs( + p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types, + privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr, + /*modifier=*/nullptr, privateArgs->needsBarrier); } static void @@ -1068,23 +1095,23 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, // These parseXyz functions correspond to the custom definitions // in the .td file(s). -static void -printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region ®ion, - ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, - ValueRange hostEvalVars, TypeRange hostEvalTypes, - ValueRange inReductionVars, TypeRange inReductionTypes, - DenseBoolArrayAttr inReductionByref, - ArrayAttr inReductionSyms, ValueRange mapVars, - TypeRange mapTypes, ValueRange privateVars, - TypeRange privateTypes, ArrayAttr privateSyms, - DenseI64ArrayAttr privateMaps) { +static void printTargetOpRegion( + OpAsmPrinter &p, Operation *op, Region ®ion, + ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, + ValueRange hostEvalVars, TypeRange hostEvalTypes, + ValueRange inReductionVars, TypeRange inReductionTypes, + DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, + ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, + TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, + DenseI64ArrayAttr privateMaps) { AllRegionPrintArgs args; args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes); args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes); args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.mapArgs.emplace(mapVars, mapTypes); - args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier, privateMaps); printBlockArgRegion(p, op, region, args); } @@ -1092,11 +1119,12 @@ static void printInReductionPrivateRegion( OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, - ArrayAttr privateSyms) { + ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) { AllRegionPrintArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier, /*mapIndices=*/nullptr); printBlockArgRegion(p, op, region, args); } @@ -1105,13 +1133,15 @@ static void printInReductionPrivateReductionRegion( OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, - ArrayAttr privateSyms, ReductionModifierAttr reductionMod, - ValueRange reductionVars, TypeRange reductionTypes, - DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) { + ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, + ReductionModifierAttr reductionMod, ValueRange reductionVars, + TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, + ArrayAttr reductionSyms) { AllRegionPrintArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier, /*mapIndices=*/nullptr); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, reductionSyms, reductionMod); @@ -1120,21 +1150,24 @@ static void printInReductionPrivateReductionRegion( static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, - ArrayAttr privateSyms) { + ArrayAttr privateSyms, + UnitAttr privateNeedsBarrier) { AllRegionPrintArgs args; args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier, /*mapIndices=*/nullptr); printBlockArgRegion(p, op, region, args); } static void printPrivateReductionRegion( OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, - TypeRange privateTypes, ArrayAttr privateSyms, + TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) { AllRegionPrintArgs args; args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + privateNeedsBarrier, /*mapIndices=*/nullptr); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, reductionSyms, reductionMod); @@ -1884,7 +1917,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit, + makeArrayAttr(ctx, clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.threadLimit, /*private_maps=*/nullptr); } @@ -2149,7 +2183,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), - /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr, + /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, + /*proc_bind_kind=*/nullptr, /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr); state.addAttributes(attributes); @@ -2161,8 +2196,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, clauses.ifExpr, clauses.numThreads, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.procBindKind, clauses.reductionMod, - clauses.reductionVars, + clauses.privateNeedsBarrier, clauses.procBindKind, + clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms)); } @@ -2266,11 +2301,12 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) { void TeamsOp::build(OpBuilder &builder, OperationState &state, const TeamsOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO Store clauses in op: privateVars, privateSyms. + // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper, /*private_vars=*/{}, /*private_syms=*/nullptr, - clauses.reductionMod, clauses.reductionVars, + /*private_needs_barrier=*/nullptr, clauses.reductionMod, + clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit); @@ -2327,11 +2363,11 @@ OperandRange SectionOp::getReductionVars() { void SectionsOp::build(OpBuilder &builder, OperationState &state, const SectionsOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO Store clauses in op: privateVars, privateSyms. + // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, clauses.nowait, /*private_vars=*/{}, - /*private_syms=*/nullptr, clauses.reductionMod, - clauses.reductionVars, + /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, + clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms)); } @@ -2363,11 +2399,12 @@ LogicalResult SectionsOp::verifyRegions() { void SingleOp::build(OpBuilder &builder, OperationState &state, const SingleOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO Store clauses in op: privateVars, privateSyms. + // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, clauses.copyprivateVars, makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait, - /*private_vars=*/{}, /*private_syms=*/nullptr); + /*private_vars=*/{}, /*private_syms=*/nullptr, + /*private_needs_barrier=*/nullptr); } LogicalResult SingleOp::verify() { @@ -2443,8 +2480,9 @@ void LoopOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), clauses.order, - clauses.orderMod, clauses.reductionMod, clauses.reductionVars, + makeArrayAttr(ctx, clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.order, clauses.orderMod, + clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms)); } @@ -2472,6 +2510,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr, /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr, + /*private_needs_barrier=*/false, /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr, @@ -2483,18 +2522,17 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, void WsloopOp::build(OpBuilder &builder, OperationState &state, const WsloopOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars, - // privateSyms. - WsloopOp::build(builder, state, - /*allocate_vars=*/{}, /*allocator_vars=*/{}, - clauses.linearVars, clauses.linearStepVars, clauses.nowait, - clauses.order, clauses.orderMod, clauses.ordered, - clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), - clauses.scheduleKind, clauses.scheduleChunk, - clauses.scheduleMod, clauses.scheduleSimd); + // TODO: Store clauses in op: allocateVars, allocatorVars + WsloopOp::build( + builder, state, + /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars, + clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod, + clauses.ordered, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, + clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind, + clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd); } LogicalResult WsloopOp::verify() { @@ -2534,14 +2572,14 @@ LogicalResult WsloopOp::verifyRegions() { void SimdOp::build(OpBuilder &builder, OperationState &state, const SimdOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO Store clauses in op: linearVars, linearStepVars, privateVars, - // privateSyms. + // TODO Store clauses in op: linearVars, linearStepVars SimdOp::build(builder, state, clauses.alignedVars, makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, /*linear_vars=*/{}, /*linear_step_vars=*/{}, clauses.nontemporalVars, clauses.order, clauses.orderMod, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.reductionMod, clauses.reductionVars, + clauses.privateNeedsBarrier, clauses.reductionMod, + clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, clauses.simdlen); @@ -2591,7 +2629,8 @@ void DistributeOp::build(OpBuilder &builder, OperationState &state, clauses.allocatorVars, clauses.distScheduleStatic, clauses.distScheduleChunkSize, clauses.order, clauses.orderMod, clauses.privateVars, - makeArrayAttr(builder.getContext(), clauses.privateSyms)); + makeArrayAttr(builder.getContext(), clauses.privateSyms), + clauses.privateNeedsBarrier); } LogicalResult DistributeOp::verify() { @@ -2747,7 +2786,8 @@ void TaskOp::build(OpBuilder &builder, OperationState &state, makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable, clauses.priority, /*private_vars=*/clauses.privateVars, /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms), - clauses.untied, clauses.eventHandle); + clauses.privateNeedsBarrier, clauses.untied, + clauses.eventHandle); } LogicalResult TaskOp::verify() { @@ -2786,18 +2826,18 @@ LogicalResult TaskgroupOp::verify() { void TaskloopOp::build(OpBuilder &builder, OperationState &state, const TaskloopOperands &clauses) { MLIRContext *ctx = builder.getContext(); - TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.final, clauses.grainsizeMod, clauses.grainsize, - clauses.ifExpr, clauses.inReductionVars, - makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), - makeArrayAttr(ctx, clauses.inReductionSyms), - clauses.mergeable, clauses.nogroup, clauses.numTasksMod, - clauses.numTasks, clauses.priority, - /*private_vars=*/clauses.privateVars, - /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms), - clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied); + TaskloopOp::build( + builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr, + clauses.inReductionVars, + makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), + makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable, + clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority, + /*private_vars=*/clauses.privateVars, + /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied); } LogicalResult TaskloopOp::verify() { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index b7e16b7ec35e2..3eef3799c4b45 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -2872,6 +2872,23 @@ func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { return } +// CHECK-LABEL: parallel_op_privatizers_barrier +// CHECK-SAME: (%[[ARG0:[^[:space:]]+]]: !llvm.ptr, %[[ARG1:[^[:space:]]+]]: !llvm.ptr) +func.func @parallel_op_privatizers_barrier(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { + // CHECK: omp.parallel private( + // CHECK-SAME: @x.privatizer %[[ARG0]] -> %[[ARG0_PRIV:[^[:space:]]+]], + // CHECK-SAME: @y.privatizer %[[ARG1]] -> %[[ARG1_PRIV:[^[:space:]]+]] : !llvm.ptr, !llvm.ptr) + // CHECK-SAME: private_barrier + omp.parallel private(@x.privatizer %arg0 -> %arg2, @y.privatizer %arg1 -> %arg3 : !llvm.ptr, !llvm.ptr) private_barrier { + // CHECK: llvm.load %[[ARG0_PRIV]] + %0 = llvm.load %arg2 : !llvm.ptr -> i32 + // CHECK: llvm.load %[[ARG1_PRIV]] + %1 = llvm.load %arg3 : !llvm.ptr -> i32 + omp.terminator + } + return +} + // CHECK-LABEL: omp.private {type = private} @a.privatizer : !llvm.ptr init { omp.private {type = private} @a.privatizer : !llvm.ptr init { // CHECK: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}):