Skip to content

[Clang] [OpenMP] Support NOWAIT with optional argument #135030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
58 changes: 51 additions & 7 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -2216,18 +2216,62 @@ class OMPOrderedClause final
/// This represents 'nowait' clause in the '#pragma omp ...' directive.
///
/// \code
/// #pragma omp for nowait
/// #pragma omp for nowait (cond)
/// \endcode
/// In this example directive '#pragma omp for' has 'nowait' clause.
class OMPNowaitClause final : public OMPNoChildClause<llvm::omp::OMPC_nowait> {
/// In this example directive '#pragma omp for' has simple 'nowait' clause with
/// condition 'cond'.
class OMPNowaitClause final : public OMPClause {
friend class OMPClauseReader;

/// Location of '('.
SourceLocation LParenLoc;

/// Condition of the 'nowait' clause.
Stmt *Condition = nullptr;

/// Set condition.
void setCondition(Expr *Cond) { Condition = Cond; }

public:
/// Build 'nowait' clause.
/// Build 'nowait' clause with condition \a Cond.
///
/// \param Cond Condition of the clause.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
OMPNowaitClause(SourceLocation StartLoc = SourceLocation(),
SourceLocation EndLoc = SourceLocation())
: OMPNoChildClause(StartLoc, EndLoc) {}
OMPNowaitClause(Expr *Cond, SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_nowait, StartLoc, EndLoc),
LParenLoc(LParenLoc), Condition(Cond) {}

/// Build an empty clause.
OMPNowaitClause()
: OMPClause(llvm::omp::OMPC_nowait, SourceLocation(), SourceLocation()) {}

/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }

/// Returns condition.
Expr *getCondition() const { return cast_or_null<Expr>(Condition); }

child_range children() { return child_range(&Condition, &Condition + 1); }

const_child_range children() const {
return const_child_range(&Condition, &Condition + 1);
}

child_range used_children();
const_child_range used_children() const {
auto Children = const_cast<OMPNowaitClause *>(this)->used_children();
return const_child_range(Children.begin(), Children.end());
}

static bool classof(const OMPClause *T) {
return T->getClauseKind() == llvm::omp::OMPC_nowait;
}
};

/// This represents 'untied' clause in the '#pragma omp ...' directive.
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3482,7 +3482,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPOrderedClause(OMPOrderedClause *C) {
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPNowaitClause(OMPNowaitClause *) {
bool RecursiveASTVisitor<Derived>::VisitOMPNowaitClause(OMPNowaitClause *C) {
TRY_TO(TraverseStmt(C->getCondition()));
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -11530,6 +11530,8 @@ def note_omp_nested_teams_construct_here : Note<
"nested teams construct here">;
def note_omp_nested_statement_here : Note<
"%select{statement|directive}0 outside teams construct here">;
def err_omp_nowait_with_arg_unsupported : Error<
"'nowait' clause with a conditional expression requires OpenMP6.0">;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"'nowait' clause with a conditional expression requires OpenMP6.0">;
"'nowait' clause with a conditional expression requires OpenMP version 6.0">;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this error even necessary? The parser for OpenMP should not accept the form nowait(cond) unless -fopenmp-version=6.0 is set on the command line.

def err_omp_single_copyprivate_with_nowait : Error<
"the 'copyprivate' clause must not be used with the 'nowait' clause">;
def err_omp_nowait_clause_without_depend: Error<
Expand Down
6 changes: 4 additions & 2 deletions clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1001,8 +1001,10 @@ class SemaOpenMP : public SemaBase {
OMPClause *ActOnOpenMPClause(OpenMPClauseKind Kind, SourceLocation StartLoc,
SourceLocation EndLoc);
/// Called on well-formed 'nowait' clause.
OMPClause *ActOnOpenMPNowaitClause(SourceLocation StartLoc,
SourceLocation EndLoc);
OMPClause *
ActOnOpenMPNowaitClause(SourceLocation StartLoc, SourceLocation EndLoc,
SourceLocation LParenLoc = SourceLocation(),
Expr *Condition = nullptr);
/// Called on well-formed 'untied' clause.
OMPClause *ActOnOpenMPUntiedClause(SourceLocation StartLoc,
SourceLocation EndLoc);
Expand Down
17 changes: 16 additions & 1 deletion clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ OMPClause::child_range OMPIfClause::used_children() {
return child_range(&Condition, &Condition + 1);
}

/*OMPClause::child_range OMPNowaitClause::used_children() {
return child_range(&Condition, &Condition + 1);
}*/
OMPClause::child_range OMPNowaitClause::used_children() {
if (Condition)
return child_range(&Condition, &Condition + 1);
Stmt *Null = nullptr;
return child_range(&Null, &Null);
}

OMPClause::child_range OMPGrainsizeClause::used_children() {
if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt()))
return child_range(C, C + 1);
Expand Down Expand Up @@ -1995,8 +2005,13 @@ void OMPClausePrinter::VisitOMPOrderedClause(OMPOrderedClause *Node) {
}
}

void OMPClausePrinter::VisitOMPNowaitClause(OMPNowaitClause *) {
void OMPClausePrinter::VisitOMPNowaitClause(OMPNowaitClause *Node) {
OS << "nowait";
if (auto *Cond = Node->getCondition()) {
OS << "(";
Cond->printPretty(OS, nullptr, Policy, 0);
OS << ")";
}
}

void OMPClausePrinter::VisitOMPUntiedClause(OMPUntiedClause *) {
Expand Down
5 changes: 4 additions & 1 deletion clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,10 @@ void OMPClauseProfiler::VisitOMPOrderedClause(const OMPOrderedClause *C) {
Profiler->VisitStmt(Num);
}

void OMPClauseProfiler::VisitOMPNowaitClause(const OMPNowaitClause *) {}
void OMPClauseProfiler::VisitOMPNowaitClause(const OMPNowaitClause *C) {
if (C->getCondition())
Profiler->VisitStmt(C->getCondition());
}

void OMPClauseProfiler::VisitOMPUntiedClause(const OMPUntiedClause *) {}

Expand Down
5 changes: 3 additions & 2 deletions clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3211,6 +3211,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
case OMPC_nowait:
case OMPC_priority:
case OMPC_grainsize:
case OMPC_num_tasks:
Expand Down Expand Up @@ -3258,7 +3259,8 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
ErrorFound = true;
}

if ((CKind == OMPC_ordered || CKind == OMPC_partial) &&
if ((CKind == OMPC_ordered || CKind == OMPC_nowait ||
CKind == OMPC_partial) &&
PP.LookAhead(/*N=*/0).isNot(tok::l_paren))
Clause = ParseOpenMPClause(CKind, WrongDirective);
else if (CKind == OMPC_grainsize || CKind == OMPC_num_tasks)
Expand Down Expand Up @@ -3320,7 +3322,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_holds:
Clause = ParseOpenMPSingleExprClause(CKind, WrongDirective);
break;
case OMPC_nowait:
case OMPC_untied:
case OMPC_mergeable:
case OMPC_read:
Expand Down
31 changes: 28 additions & 3 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15446,6 +15446,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_ordered:
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
break;
case OMPC_nowait:
Res = ActOnOpenMPNowaitClause(StartLoc, EndLoc, LParenLoc, Expr);
break;
case OMPC_priority:
Res = ActOnOpenMPPriorityClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
Expand Down Expand Up @@ -15500,7 +15503,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_aligned:
case OMPC_copyin:
case OMPC_copyprivate:
case OMPC_nowait:
case OMPC_untied:
case OMPC_mergeable:
case OMPC_threadprivate:
Expand Down Expand Up @@ -16959,9 +16961,32 @@ OMPClause *SemaOpenMP::ActOnOpenMPClause(OpenMPClauseKind Kind,
}

OMPClause *SemaOpenMP::ActOnOpenMPNowaitClause(SourceLocation StartLoc,
SourceLocation EndLoc) {
SourceLocation EndLoc,
SourceLocation LParenLoc,
Expr *Condition) {
Expr *ValExpr = Condition;
if (Condition && LParenLoc.isValid()) {
if (!Condition->isValueDependent() && !Condition->isTypeDependent() &&
!Condition->isInstantiationDependent() &&
!Condition->containsUnexpandedParameterPack()) {
ExprResult Val = SemaRef.CheckBooleanCondition(StartLoc, Condition);
if (Val.isInvalid())
return nullptr;

QualType T = ValExpr->getType();
if (T->isFloatingType()) {
SemaRef.Diag(ValExpr->getExprLoc(), diag::err_omp_clause_floating_type_arg)
<< getOpenMPClauseName(OMPC_nowait);
}

ValExpr = Val.get();
}
} else {
ValExpr = nullptr;
}
DSAStack->setNowaitRegion();
return new (getASTContext()) OMPNowaitClause(StartLoc, EndLoc);
return new (getASTContext())
OMPNowaitClause(ValExpr, StartLoc, LParenLoc, EndLoc);
}

OMPClause *SemaOpenMP::ActOnOpenMPUntiedClause(SourceLocation StartLoc,
Expand Down
21 changes: 19 additions & 2 deletions clang/lib/Sema/TreeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1857,6 +1857,17 @@ class TreeTransform {
LParenLoc, Num);
}

/// Build a new OpenMP 'nowait' clause.
///
/// By default, performs semantic analysis to build the new OpenMP clause.
/// Subclasses may override this routine to provide different behavior.
OMPClause *RebuildOMPNowaitClause(Expr *Condition, SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
return getSema().OpenMP().ActOnOpenMPNowaitClause(StartLoc, EndLoc,
LParenLoc, Condition);
}

/// Build a new OpenMP 'private' clause.
///
/// By default, performs semantic analysis to build the new OpenMP clause.
Expand Down Expand Up @@ -10589,8 +10600,14 @@ TreeTransform<Derived>::TransformOMPDetachClause(OMPDetachClause *C) {
template <typename Derived>
OMPClause *
TreeTransform<Derived>::TransformOMPNowaitClause(OMPNowaitClause *C) {
// No need to rebuild this clause, no template-dependent parameters.
return C;
ExprResult Cond;
if (auto *Condition = C->getCondition()) {
Cond = getDerived().TransformExpr(Condition);
if (Cond.isInvalid())
return nullptr;
}
return getDerived().RebuildOMPNowaitClause(Cond.get(), C->getBeginLoc(),
C->getLParenLoc(), C->getEndLoc());
}

template <typename Derived>
Expand Down
5 changes: 4 additions & 1 deletion clang/lib/Serialization/ASTReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11478,7 +11478,10 @@ void OMPClauseReader::VisitOMPDetachClause(OMPDetachClause *C) {
C->setLParenLoc(Record.readSourceLocation());
}

void OMPClauseReader::VisitOMPNowaitClause(OMPNowaitClause *) {}
void OMPClauseReader::VisitOMPNowaitClause(OMPNowaitClause *C) {
C->setCondition(Record.readSubExpr());
C->setLParenLoc(Record.readSourceLocation());
}

void OMPClauseReader::VisitOMPUntiedClause(OMPUntiedClause *) {}

Expand Down
5 changes: 4 additions & 1 deletion clang/lib/Serialization/ASTWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7816,7 +7816,10 @@ void OMPClauseWriter::VisitOMPOrderedClause(OMPOrderedClause *C) {
Record.AddSourceLocation(C->getLParenLoc());
}

void OMPClauseWriter::VisitOMPNowaitClause(OMPNowaitClause *) {}
void OMPClauseWriter::VisitOMPNowaitClause(OMPNowaitClause *C) {
Record.AddStmt(C->getCondition());
Record.AddSourceLocation(C->getLParenLoc());
}

void OMPClauseWriter::VisitOMPUntiedClause(OMPUntiedClause *) {}

Expand Down
56 changes: 56 additions & 0 deletions clang/test/OpenMP/nowait_ast_print.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Check no warnings/errors
// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=60 -fsyntax-only -verify %s
// expected-no-diagnostics

// Check AST and unparsing
// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=60 -ast-dump %s | FileCheck %s --check-prefix=DUMP
// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=60 -ast-print %s | FileCheck %s --check-prefix=PRINT

// Check same results after serialization round-trip
// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=60 -emit-pch -o %t %s
// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=60 -include-pch %t -ast-dump-all %s | FileCheck %s --check-prefix=DUMP
// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=60 -include-pch %t -ast-print %s | FileCheck %s --check-prefix=PRINT

#ifndef HEADER
#define HEADER

void nowait() {
int A=1;

// DUMP: OMPTargetDirective
// DUMP-NEXT: OMPNowaitClause
// DUMP-NEXT: <<<NULL>>>
// PRINT: #pragma omp target nowait
#pragma omp target nowait
{
}

// DUMP: OMPTargetDirective
// DUMP-NEXT: OMPNowaitClause
// DUMP-NEXT: XXBoolLiteralExpr {{.*}} 'bool' false
// PRINT: #pragma omp target nowait(false)
#pragma omp target nowait(false)
{
}

// DUMP: OMPTargetDirective
// DUMP-NEXT: OMPNowaitClause
// DUMP-NEXT: XXBoolLiteralExpr {{.*}} 'bool' true
// PRINT: #pragma omp target nowait(true)
#pragma omp target nowait(true)
{
}

// DUMP: OMPTargetDirective
// DUMP-NEXT: OMPNowaitClause
// DUMP-NEXT: BinaryOperator {{.*}} 'bool' '>'
// DUMP-NEXT: ImplicitCastExpr {{.*}} 'int' <LValueToRValue>
// DUMP-NEXT: DeclRefExpr {{.*}} 'int' lvalue Var {{.*}} 'A' 'int'
// DUMP-NEXT: IntegerLiteral {{.*}} 'int' 5
// PRINT: #pragma omp target nowait(A > 5)
#pragma omp target nowait(A>5)
{
}

}
#endif
6 changes: 3 additions & 3 deletions clang/test/OpenMP/target_enter_data_nowait_messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ int main(int argc, char **argv) {
{}
#pragma omp target enter nowait data map(to: i) // expected-error {{expected an OpenMP directive}}
{}
#pragma omp target enter data nowait() map(to: i) // expected-warning {{extra tokens at the end of '#pragma omp target enter data' are ignored}} expected-error {{expected at least one 'map' clause for '#pragma omp target enter data'}}
#pragma omp target enter data nowait() map(to: i) // expected-error {{expected expression}}
{}
#pragma omp target enter data map(to: i) nowait( // expected-warning {{extra tokens at the end of '#pragma omp target enter data' are ignored}}
#pragma omp target enter data map(to: i) nowait( // expected-error {{expected expression}} // expected-error {{expected ')'}} // expected-note {{to match this '('}}
{}
#pragma omp target enter data map(to: i) nowait (argc)) // expected-warning {{extra tokens at the end of '#pragma omp target enter data' are ignored}}
{}
#pragma omp target enter data map(to: i) nowait device (-10u)
{}
#pragma omp target enter data map(to: i) nowait (3.14) device (-10u) // expected-warning {{extra tokens at the end of '#pragma omp target enter data' are ignored}}
#pragma omp target enter data map(to: i) nowait (3.14) device (-10u) // expected-error {{arguments of OpenMP clause 'nowait' with bitwise operators cannot be of floating type}}
{}
#pragma omp target enter data map(to: i) nowait nowait // expected-error {{directive '#pragma omp target enter data' cannot contain more than one 'nowait' clause}}
{}
Expand Down
6 changes: 3 additions & 3 deletions clang/test/OpenMP/target_exit_data_nowait_messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ int main(int argc, char **argv) {
#pragma omp nowait target exit data map(from: i) // expected-error {{expected an OpenMP directive}}
#pragma omp target nowait exit data map(from: i) // expected-warning {{extra tokens at the end of '#pragma omp target' are ignored}}
#pragma omp target exit nowait data map(from: i) // expected-error {{expected an OpenMP directive}}
#pragma omp target exit data nowait() map(from: i) // expected-warning {{extra tokens at the end of '#pragma omp target exit data' are ignored}} expected-error {{expected at least one 'map' clause for '#pragma omp target exit data'}}
#pragma omp target exit data map(from: i) nowait( // expected-warning {{extra tokens at the end of '#pragma omp target exit data' are ignored}}
#pragma omp target exit data nowait() map(from: i) // expected-error {{expected expression}}
#pragma omp target exit data map(from: i) nowait( // expected-error {{expected expression}} // expected-error {{expected ')'}} // expected-note {{to match this '('}}
#pragma omp target exit data map(from: i) nowait (argc)) // expected-warning {{extra tokens at the end of '#pragma omp target exit data' are ignored}}
#pragma omp target exit data map(from: i) nowait device (-10u)
#pragma omp target exit data map(from: i) nowait (3.14) device (-10u) // expected-warning {{extra tokens at the end of '#pragma omp target exit data' are ignored}}
#pragma omp target exit data map(from: i) nowait (3.14) device (-10u) // expected-error {{arguments of OpenMP clause 'nowait' with bitwise operators cannot be of floating type}}
#pragma omp target exit data map(from: i) nowait nowait // expected-error {{directive '#pragma omp target exit data' cannot contain more than one 'nowait' clause}}
#pragma omp target exit data nowait map(from: i) nowait // expected-error {{directive '#pragma omp target exit data' cannot contain more than one 'nowait' clause}}
return 0;
Expand Down
4 changes: 2 additions & 2 deletions clang/test/OpenMP/target_nowait_messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ void foo() {
}

int main(int argc, char **argv) {
#pragma omp target nowait( // expected-warning {{extra tokens at the end of '#pragma omp target' are ignored}}
#pragma omp target nowait(// expected-error {{expected expression}} // expected-error {{expected ')'}} // expected-note {{to match this '('}}
foo();
#pragma omp target nowait (argc)) // expected-warning {{extra tokens at the end of '#pragma omp target' are ignored}}
foo();
#pragma omp target nowait device (-10u)
foo();
#pragma omp target nowait (3.14) device (-10u) // expected-warning {{extra tokens at the end of '#pragma omp target' are ignored}}
#pragma omp target nowait (3.14) device (-10u) // expected-error {{arguments of OpenMP clause 'nowait' with bitwise operators cannot be of floating type}}
foo();

return 0;
Expand Down
Loading
Loading