Skip to content

[llvm][mlir][OpenMP] Support translation for linear clause in omp.wsloop #139386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

NimishMishra
Copy link
Contributor

This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).

@llvmbot
Copy link
Member

llvmbot commented May 10, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-flang-openmp

Author: None (NimishMishra)

Changes

This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).


Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139386.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+34)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+3-2)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-2)
  • (added) flang/test/Lower/OpenMP/wsloop-linear.f90 (+57)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+3)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+183-2)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+88)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-13)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  return findRepeatableClause<
+      omp::clause::Linear>([&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      const mlir::Value variable = converter.getSymbolAddress(*sym);
+      result.linearVars.push_back(variable);
+    }
+    if (objects.size()) {
+      if (auto &mod =
+              std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
+                  clause.t)) {
+        mlir::Value operand =
+            fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
+        result.linearStepVars.append(objects.size(), operand);
+      } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
+                     clause.t)) {
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        TODO(currentLocation, "Linear modifiers not yet implemented");
+      } else {
+        // If nothing is present, add the default step of 1.
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        mlir::Value operand = firOpBuilder.createIntegerConstant(
+            currentLocation, firOpBuilder.getI32Type(), 1);
+        result.linearStepVars.append(objects.size(), operand);
+      }
+    }
+  });
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
   bool processIsDevicePtr(
       mlir::omp::IsDevicePtrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
   // so, we won't need to explicitely handle block objects (or forget to do
   // so).
   for (auto *sym : explicitlyPrivatizedSymbols)
-    allPrivatizedSymbols.insert(sym);
+    if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+      allPrivatizedSymbols.insert(sym);
 }
 
 bool DataSharingProcessor::needBarrier() {
   // Emit implicit barrier to synchronize threads and avoid data races on
   // initialization of firstprivate variables and post-update of lastprivate
   // variables.
-  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
   for (const semantics::Symbol *sym : allPrivatizedSymbols) {
     if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
         (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
+  cp.processLinear(clauseOps);
   cp.processOrder(clauseOps);
   cp.processOrdered(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
   cp.processSchedule(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+    implicit none
+    integer :: x, y, i
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+    implicit none
+    integer :: x, y, i
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:4)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32   
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+    implicit none
+    integer :: x, y, i, a
+    !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:a+4)
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
   BasicBlock *Latch = nullptr;
   BasicBlock *Exit = nullptr;
 
+  // Hold the MLIR value for the `lastiter` of the canonical loop.
+  Value *LastIter = nullptr;
+
   /// Add the control blocks of this loop to \p BBs.
   ///
   /// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
   void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
 
 public:
+  /// Sets the last iteration variable for this loop.
+  void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); }
+
+  /// Returns the last iteration variable for this loop.
+  /// Certain use-cases (like translation of linear clause) may access
+  /// this variable even after a loop transformation. Hence, do not guard
+  /// this getter function by `isValid`. It is the responsibility of the
+  /// callee to ensure this functionality is not invoked by a non-outlined
+  /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+  /// invoked and `LastIter` will be by default `nullptr`).
+  Value *getLastIter() { return LastIter; }
+
   /// Returns whether this object currently represents the IR of a loop. If
   /// returning false, it may have been consumed by a loop transformation or not
   /// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Value *PUpperBound =
       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
   Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
 
 char PreviouslyReportedError::ID = 0;
 
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  SmallVector<llvm::Value *> linearOrigVal;
+  SmallVector<llvm::Value *> linearSteps;
+  llvm::BasicBlock *linearFinalizationBB;
+  llvm::BasicBlock *linearExitBB;
+  llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+  // Allocate space for linear variabes
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+            moduleTranslation.lookupValue(linearVar))) {
+      linearPreconditionVars.push_back(builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+      linearOrigVars.push_back(linearVarAlloca);
+    }
+  }
+
+  // Initialize linear step
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        moduleTranslation.getOpenMPBuilder()->createBarrier(
+            builder.saveIP(), llvm::omp::OMPD_barrier);
+    return afterBarrierIP;
+  }
+
+  // Emit IR for updating Linear variables
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+                             linearPreconditionVars[index]);
+      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad = builder.CreateLoad(
+        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+    llvm::Value *isLast =
+        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+                          llvm::ConstantInt::get(
+                              llvm::Type::getInt32Ty(builder.getContext()), 0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  //  with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+                      size_t varIndex) {
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : linearOrigVal[varIndex]->users())
+      users.push_back(user);
+    for (auto *user : users) {
+      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+        if (userInst->getParent()->getName().str() == BBName)
+          user->replaceUsesOfWith(linearOrigVal[varIndex],
+                                  linearLoopBodyTemps[varIndex]);
+      }
+    }
+  }
+};
+
 } // namespace
 
 /// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
-        checkLinear(op, result);
         checkOrder(op, result);
         checkReduction(op, result);
       })
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                            llvm::omp::Directive::OMPD_for);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+  // Initialize linear variables and linear step
+  LinearClauseProcessor linearClauseProcessor;
+  if (wsloopOp.getLinearVars().size()) {
+    for (mlir::Value linearVar : wsloopOp.getLinearVars())
+      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+                                            linearVar);
+    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+  }
+
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
 
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
+  // Emit Initialization and Update IR for linear variables
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                            loopInfo->getPreheader());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                          loopInfo->getIndVar());
+    linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                      loopInfo->getExit());
+  }
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  // Emit finalization and in-place rewrites for linear vars.
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+    assert(loopInfo->getLastIter() &&
+           "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 10, 2025

@llvm/pr-subscribers-mlir-llvm

Author: None (NimishMishra)

Changes

This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).


Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139386.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+34)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+3-2)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-2)
  • (added) flang/test/Lower/OpenMP/wsloop-linear.f90 (+57)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+3)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+183-2)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+88)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-13)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  return findRepeatableClause<
+      omp::clause::Linear>([&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      const mlir::Value variable = converter.getSymbolAddress(*sym);
+      result.linearVars.push_back(variable);
+    }
+    if (objects.size()) {
+      if (auto &mod =
+              std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
+                  clause.t)) {
+        mlir::Value operand =
+            fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
+        result.linearStepVars.append(objects.size(), operand);
+      } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
+                     clause.t)) {
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        TODO(currentLocation, "Linear modifiers not yet implemented");
+      } else {
+        // If nothing is present, add the default step of 1.
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        mlir::Value operand = firOpBuilder.createIntegerConstant(
+            currentLocation, firOpBuilder.getI32Type(), 1);
+        result.linearStepVars.append(objects.size(), operand);
+      }
+    }
+  });
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
   bool processIsDevicePtr(
       mlir::omp::IsDevicePtrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
   // so, we won't need to explicitely handle block objects (or forget to do
   // so).
   for (auto *sym : explicitlyPrivatizedSymbols)
-    allPrivatizedSymbols.insert(sym);
+    if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+      allPrivatizedSymbols.insert(sym);
 }
 
 bool DataSharingProcessor::needBarrier() {
   // Emit implicit barrier to synchronize threads and avoid data races on
   // initialization of firstprivate variables and post-update of lastprivate
   // variables.
-  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
   for (const semantics::Symbol *sym : allPrivatizedSymbols) {
     if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
         (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
+  cp.processLinear(clauseOps);
   cp.processOrder(clauseOps);
   cp.processOrdered(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
   cp.processSchedule(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+    implicit none
+    integer :: x, y, i
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+    implicit none
+    integer :: x, y, i
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:4)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32   
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+    implicit none
+    integer :: x, y, i, a
+    !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:a+4)
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
   BasicBlock *Latch = nullptr;
   BasicBlock *Exit = nullptr;
 
+  // Hold the MLIR value for the `lastiter` of the canonical loop.
+  Value *LastIter = nullptr;
+
   /// Add the control blocks of this loop to \p BBs.
   ///
   /// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
   void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
 
 public:
+  /// Sets the last iteration variable for this loop.
+  void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); }
+
+  /// Returns the last iteration variable for this loop.
+  /// Certain use-cases (like translation of linear clause) may access
+  /// this variable even after a loop transformation. Hence, do not guard
+  /// this getter function by `isValid`. It is the responsibility of the
+  /// callee to ensure this functionality is not invoked by a non-outlined
+  /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+  /// invoked and `LastIter` will be by default `nullptr`).
+  Value *getLastIter() { return LastIter; }
+
   /// Returns whether this object currently represents the IR of a loop. If
   /// returning false, it may have been consumed by a loop transformation or not
   /// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Value *PUpperBound =
       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
   Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
 
 char PreviouslyReportedError::ID = 0;
 
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  SmallVector<llvm::Value *> linearOrigVal;
+  SmallVector<llvm::Value *> linearSteps;
+  llvm::BasicBlock *linearFinalizationBB;
+  llvm::BasicBlock *linearExitBB;
+  llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+  // Allocate space for linear variabes
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+            moduleTranslation.lookupValue(linearVar))) {
+      linearPreconditionVars.push_back(builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+      linearOrigVars.push_back(linearVarAlloca);
+    }
+  }
+
+  // Initialize linear step
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        moduleTranslation.getOpenMPBuilder()->createBarrier(
+            builder.saveIP(), llvm::omp::OMPD_barrier);
+    return afterBarrierIP;
+  }
+
+  // Emit IR for updating Linear variables
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+                             linearPreconditionVars[index]);
+      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad = builder.CreateLoad(
+        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+    llvm::Value *isLast =
+        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+                          llvm::ConstantInt::get(
+                              llvm::Type::getInt32Ty(builder.getContext()), 0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  //  with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+                      size_t varIndex) {
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : linearOrigVal[varIndex]->users())
+      users.push_back(user);
+    for (auto *user : users) {
+      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+        if (userInst->getParent()->getName().str() == BBName)
+          user->replaceUsesOfWith(linearOrigVal[varIndex],
+                                  linearLoopBodyTemps[varIndex]);
+      }
+    }
+  }
+};
+
 } // namespace
 
 /// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
-        checkLinear(op, result);
         checkOrder(op, result);
         checkReduction(op, result);
       })
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                            llvm::omp::Directive::OMPD_for);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+  // Initialize linear variables and linear step
+  LinearClauseProcessor linearClauseProcessor;
+  if (wsloopOp.getLinearVars().size()) {
+    for (mlir::Value linearVar : wsloopOp.getLinearVars())
+      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+                                            linearVar);
+    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+  }
+
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
 
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
+  // Emit Initialization and Update IR for linear variables
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                            loopInfo->getPreheader());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                          loopInfo->getIndVar());
+    linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                      loopInfo->getExit());
+  }
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  // Emit finalization and in-place rewrites for linear vars.
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+    assert(loopInfo->getLastIter() &&
+           "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 10, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (NimishMishra)

Changes

This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).


Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139386.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+34)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+3-2)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-2)
  • (added) flang/test/Lower/OpenMP/wsloop-linear.f90 (+57)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+3)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+183-2)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+88)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-13)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  return findRepeatableClause<
+      omp::clause::Linear>([&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      const mlir::Value variable = converter.getSymbolAddress(*sym);
+      result.linearVars.push_back(variable);
+    }
+    if (objects.size()) {
+      if (auto &mod =
+              std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
+                  clause.t)) {
+        mlir::Value operand =
+            fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
+        result.linearStepVars.append(objects.size(), operand);
+      } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
+                     clause.t)) {
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        TODO(currentLocation, "Linear modifiers not yet implemented");
+      } else {
+        // If nothing is present, add the default step of 1.
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        mlir::Value operand = firOpBuilder.createIntegerConstant(
+            currentLocation, firOpBuilder.getI32Type(), 1);
+        result.linearStepVars.append(objects.size(), operand);
+      }
+    }
+  });
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
   bool processIsDevicePtr(
       mlir::omp::IsDevicePtrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
   // so, we won't need to explicitely handle block objects (or forget to do
   // so).
   for (auto *sym : explicitlyPrivatizedSymbols)
-    allPrivatizedSymbols.insert(sym);
+    if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+      allPrivatizedSymbols.insert(sym);
 }
 
 bool DataSharingProcessor::needBarrier() {
   // Emit implicit barrier to synchronize threads and avoid data races on
   // initialization of firstprivate variables and post-update of lastprivate
   // variables.
-  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
   for (const semantics::Symbol *sym : allPrivatizedSymbols) {
     if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
         (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
+  cp.processLinear(clauseOps);
   cp.processOrder(clauseOps);
   cp.processOrdered(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
   cp.processSchedule(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+    implicit none
+    integer :: x, y, i
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+    implicit none
+    integer :: x, y, i
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:4)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32   
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+    implicit none
+    integer :: x, y, i, a
+    !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:a+4)
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
   BasicBlock *Latch = nullptr;
   BasicBlock *Exit = nullptr;
 
+  // Hold the MLIR value for the `lastiter` of the canonical loop.
+  Value *LastIter = nullptr;
+
   /// Add the control blocks of this loop to \p BBs.
   ///
   /// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
   void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
 
 public:
+  /// Sets the last iteration variable for this loop.
+  void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); }
+
+  /// Returns the last iteration variable for this loop.
+  /// Certain use-cases (like translation of linear clause) may access
+  /// this variable even after a loop transformation. Hence, do not guard
+  /// this getter function by `isValid`. It is the responsibility of the
+  /// callee to ensure this functionality is not invoked by a non-outlined
+  /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+  /// invoked and `LastIter` will be by default `nullptr`).
+  Value *getLastIter() { return LastIter; }
+
   /// Returns whether this object currently represents the IR of a loop. If
   /// returning false, it may have been consumed by a loop transformation or not
   /// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Value *PUpperBound =
       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
   Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
 
 char PreviouslyReportedError::ID = 0;
 
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  SmallVector<llvm::Value *> linearOrigVal;
+  SmallVector<llvm::Value *> linearSteps;
+  llvm::BasicBlock *linearFinalizationBB;
+  llvm::BasicBlock *linearExitBB;
+  llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+  // Allocate space for linear variabes
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+            moduleTranslation.lookupValue(linearVar))) {
+      linearPreconditionVars.push_back(builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+      linearOrigVars.push_back(linearVarAlloca);
+    }
+  }
+
+  // Initialize linear step
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        moduleTranslation.getOpenMPBuilder()->createBarrier(
+            builder.saveIP(), llvm::omp::OMPD_barrier);
+    return afterBarrierIP;
+  }
+
+  // Emit IR for updating Linear variables
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+                             linearPreconditionVars[index]);
+      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad = builder.CreateLoad(
+        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+    llvm::Value *isLast =
+        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+                          llvm::ConstantInt::get(
+                              llvm::Type::getInt32Ty(builder.getContext()), 0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  //  with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+                      size_t varIndex) {
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : linearOrigVal[varIndex]->users())
+      users.push_back(user);
+    for (auto *user : users) {
+      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+        if (userInst->getParent()->getName().str() == BBName)
+          user->replaceUsesOfWith(linearOrigVal[varIndex],
+                                  linearLoopBodyTemps[varIndex]);
+      }
+    }
+  }
+};
+
 } // namespace
 
 /// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
-        checkLinear(op, result);
         checkOrder(op, result);
         checkReduction(op, result);
       })
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                            llvm::omp::Directive::OMPD_for);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+  // Initialize linear variables and linear step
+  LinearClauseProcessor linearClauseProcessor;
+  if (wsloopOp.getLinearVars().size()) {
+    for (mlir::Value linearVar : wsloopOp.getLinearVars())
+      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+                                            linearVar);
+    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+  }
+
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
 
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
+  // Emit Initialization and Update IR for linear variables
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                            loopInfo->getPreheader());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                          loopInfo->getIndVar());
+    linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                      loopInfo->getExit());
+  }
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  // Emit finalization and in-place rewrites for linear vars.
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+    assert(loopInfo->getLastIter() &&
+           "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]

@NimishMishra
Copy link
Contributor Author

The following test case from OpenMP examples document compiles successfully with a combination of PR #139385 and this PR:

program linear_loop
 use omp_lib
 implicit none
 integer, parameter :: N = 100
 real :: a(N), b(N/2)
 integer :: i, j

 do i = 1, N
 a(i) = i
 end do

 j = 0
 !$omp parallel
 !$omp do linear(j:1)
 do i = 1, N, 2
 j = j + 1
 b(j) = a(i) * 2.0
 end do
 !$omp end parallel

 print *, j, b(1), b(j)
 ! print out: 50 2.0 198.0

 end program

Flang output: 50 2. 198. as expected.

@NimishMishra
Copy link
Contributor Author

This PR is stacked over #139385 to allow for easy testing. I intend to merge this PR only after #139385 is merged.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the progress so far

linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
linearLoopBodyTemps.push_back(linearLoopBodyTemp);
linearOrigVars.push_back(linearVarAlloca);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't there be some sort of error reporting to catch cases where the linear variable is not an alloca?

What about if the variable was passed as a function argument?

Nothing in the OpenMP MLIR dialect currently guarantees or documents that these variables are passed by-address into the clause, that's just how flang happens to work. I would be open to adding this requirement (as we have done this already for privatization and some reductions). But if you depend on this it does need to be documented. This is supposed to be generic code which supports any valid OpenMP MLIR code.

builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),

linearPreconditionVars[index]);
auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works for integer variables. For floating point you would need fmul. Same for add and fadd.

// Emit condition to check whether last logical iteration is being executed
builder.SetInsertPoint(linearFinalizationBB->getTerminator());
llvm::Value *loopLastIterLoad = builder.CreateLoad(
llvm::Type::getInt32Ty(builder.getContext()), lastIter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this always i32?


// Initialize linear variables and linear step
LinearClauseProcessor linearClauseProcessor;
if (wsloopOp.getLinearVars().size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This if is unnecessary. The for loops will not execute if there are no linear vars.


// Linear variable finalization is conditional on the last logical iteration.
// Create BB splits to manage the same.
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me at least, this isn't what I expect from the term "outline". For me, "outline" would mean moving blocks into a different function. Perhaps a better name would be splitLinearFiniBlock.

Feel free to ignore if others disagree.

@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
BasicBlock *Latch = nullptr;
BasicBlock *Exit = nullptr;

// Hold the MLIR value for the `lastiter` of the canonical loop.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Hold the MLIR value for the `lastiter` of the canonical loop.
// Hold the LLVM value for the `lastiter` of the canonical loop.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants