Skip to content

[llvm][mlir][OpenMP] Support translation for linear clause in omp.wsloop #139386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
});
}

/// Lowers the `linear` clauses handled by this processor into \p result.
///
/// For each clause, the address of every listed object is appended to
/// `result.linearVars`. When the clause lists at least one object, the same
/// step value is appended to `result.linearStepVars` once per object:
///   - the lowered step expression, if a step modifier is present;
///   - otherwise an i32 constant 1, if no linear modifier is present either.
/// A linear modifier without a step is reported through `TODO`.
///
/// \returns whatever `findRepeatableClause` returns (presumably whether any
/// `linear` clause was found -- confirm against its declaration).
bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
  lower::StatementContext stmtCtx;

  auto lowerLinearClause = [&](const omp::clause::Linear &clause,
                               const parser::CharBlock &) {
    const auto &objects = std::get<omp::ObjectList>(clause.t);

    // Record the address of each variable named in the clause.
    for (const omp::Object &object : objects) {
      semantics::Symbol *symbol = object.sym();
      result.linearVars.push_back(converter.getSymbolAddress(*symbol));
    }

    // Nothing listed: there is no step to attach.
    const auto numObjects = objects.size();
    if (numObjects == 0)
      return;

    const auto &stepModifier =
        std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
            clause.t);
    const auto &linearModifier =
        std::get<std::optional<omp::clause::Linear::LinearModifier>>(clause.t);

    if (stepModifier) {
      // Explicit step: lower the expression once and share it across all
      // objects of this clause.
      mlir::Value stepValue = fir::getBase(
          converter.genExprValue(toEvExpr(*stepModifier), stmtCtx));
      result.linearStepVars.append(numObjects, stepValue);
    } else if (linearModifier) {
      mlir::Location loc = converter.getCurrentLocation();
      TODO(loc, "Linear modifiers not yet implemented");
    } else {
      // Neither modifier is present: fall back to the default step of 1.
      fir::FirOpBuilder &builder = converter.getFirOpBuilder();
      mlir::Location loc = converter.getCurrentLocation();
      mlir::Value defaultStep =
          builder.createIntegerConstant(loc, builder.getI32Type(), 1);
      result.linearStepVars.append(numObjects, defaultStep);
    }
  };

  return findRepeatableClause<omp::clause::Linear>(lowerLinearClause);
}

bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::Link>(
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class ClauseProcessor {
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processLinear(mlir::omp::LinearClauseOps &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;

Expand Down
5 changes: 3 additions & 2 deletions flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
// so, we won't need to explicitly handle block objects (or forget to do
// so).
for (auto *sym : explicitlyPrivatizedSymbols)
allPrivatizedSymbols.insert(sym);
if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
allPrivatizedSymbols.insert(sym);
}

bool DataSharingProcessor::needBarrier() {
// Emit implicit barrier to synchronize threads and avoid data races on
// initialization of firstprivate variables and post-update of lastprivate
// variables.
// Emit implicit barrier for linear clause. Maybe on somewhere else.
// Emit implicit barrier for linear clause in the OpenMPIRBuilder.
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
(sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1841,13 +1841,13 @@ static void genWsloopClauses(
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processNowait(clauseOps);
cp.processLinear(clauseOps);
cp.processOrder(clauseOps);
cp.processOrdered(clauseOps);
cp.processReduction(loc, clauseOps, reductionSyms);
cp.processSchedule(stmtCtx, clauseOps);

cp.processTODO<clause::Allocate, clause::Linear>(
loc, llvm::omp::Directive::OMPD_do);
cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
}

//===----------------------------------------------------------------------===//
Expand Down
57 changes: 57 additions & 0 deletions flang/test/Lower/OpenMP/wsloop-linear.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
! This test checks lowering of OpenMP DO Directive (Worksharing)
! with linear clause

! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s

!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[const:.*]] = arith.constant 1 : i32
! Case 1: linear clause with no explicit step on a worksharing DO loop.
subroutine simple_linear
implicit none
integer :: x, y, i
! The wsloop is expected to carry x as a linear variable whose step is the
! constant 1 captured just before this subroutine (the default step).
!CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
!$omp do linear(x)
! The loop body is expected to load x through the original declaration and
! add the literal 2 to it.
!CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
!CHECK: %[[const:.*]] = arith.constant 2 : i32
!CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
do i = 1, 10
y = x + 2
end do
!$omp end do
end subroutine


!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! Case 2: linear clause with an explicit constant step (x:4).
subroutine linear_step
implicit none
integer :: x, y, i
! The step is expected to be materialized as the constant 4 and used as the
! linear step operand of the wsloop.
!CHECK: %[[const:.*]] = arith.constant 4 : i32
!CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
!$omp do linear(x:4)
! The loop body is expected to load x through the original declaration and
! add the literal 2 to it.
!CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
!CHECK: %[[const:.*]] = arith.constant 2 : i32
!CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
do i = 1, 10
y = x + 2
end do
!$omp end do
end subroutine

!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! Case 3: linear clause whose step is a non-constant expression (x:a+4).
subroutine linear_expr
implicit none
integer :: x, y, i, a
! The step expression is expected to be evaluated before the wsloop (load of
! a, plus the constant 4) and its result used as the linear step operand.
!CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
!CHECK: %[[const:.*]] = arith.constant 4 : i32
!CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
!CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
!$omp do linear(x:a+4)
do i = 1, 10
y = x + 2
end do
!$omp end do
end subroutine
15 changes: 15 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
BasicBlock *Latch = nullptr;
BasicBlock *Exit = nullptr;

// Hold the MLIR value for the `lastiter` of the canonical loop.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Hold the MLIR value for the `lastiter` of the canonical loop.
// Hold the LLVM value for the `lastiter` of the canonical loop.

Value *LastIter = nullptr;

/// Add the control blocks of this loop to \p BBs.
///
/// This does not include any block from the body, including the one returned
Expand Down Expand Up @@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);

public:
/// Sets the last iteration variable for this loop.
///
/// \param IterVar The value to record; it is stored in `LastIter` and handed
///                back by `getLastIter()`. The pointer is copied as-is: a raw
///                pointer is trivially copyable, so no `std::move` is needed.
void setLastIter(Value *IterVar) { LastIter = IterVar; }

/// Returns the last iteration variable for this loop.
/// Certain use-cases (like translation of linear clause) may access
/// this variable even after a loop transformation. Hence, do not guard
/// this getter function by `isValid`. It is the responsibility of the
/// caller to ensure this functionality is not invoked by a non-outlined
/// CanonicalLoopInfo object (in which case, `setLastIter` will never be
/// invoked and `LastIter` will be by default `nullptr`).
Value *getLastIter() { return LastIter; }

/// Returns whether this object currently represents the IR of a loop. If
/// returning false, it may have been consumed by a loop transformation or not
/// been initialized. Do not use in this case;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
CLI->setLastIter(PLastIter);

// At the end of the preheader, prepare for calling the "init" function by
// storing the current loop bounds into the allocated space. A canonical loop
Expand Down Expand Up @@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
Value *PUpperBound =
Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
CLI->setLastIter(PLastIter);

// Set up the source location value for the OpenMP runtime.
Builder.restoreIP(CLI->getPreheaderIP());
Expand Down Expand Up @@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
CLI->setLastIter(PLastIter);

// At the end of the preheader, prepare for calling the "init" function by
// storing the current loop bounds into the allocated space. A canonical loop
Expand Down
Loading
Loading