Skip to content

[flang] optionally add lifetime markers to alloca created in stack-arrays #140901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ llvm::SmallVector<mlir::Value>
elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);

/// Get the address space which should be used for allocas
uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);
uint64_t getAllocaAddressSpace(const mlir::DataLayout *dataLayout);

/// The two vectors of MLIR values have the following property:
/// \p extents1[i] must have the same value as \p extents2[i]
Expand Down Expand Up @@ -913,6 +913,18 @@ void genDimInfoFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
llvm::SmallVectorImpl<mlir::Value> *extents,
llvm::SmallVectorImpl<mlir::Value> *strides);

/// Generate an LLVM dialect lifetime start marker at the current insertion
/// point given an fir.alloca and its constant size in bytes. Returns the value
/// to be passed to the lifetime end marker.
mlir::Value genLifetimeStart(mlir::OpBuilder &builder, mlir::Location loc,
fir::AllocaOp alloc, int64_t size,
const mlir::DataLayout *dl);

/// Generate an LLVM dialect lifetime end marker at the current insertion point
/// given an llvm.ptr value and the constant size in bytes of its storage.
void genLifetimeEnd(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value mem, int64_t size);

} // namespace fir::factory

#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
12 changes: 12 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ static constexpr llvm::StringRef getInternalFuncNameAttrName() {
return "fir.internal_name";
}

/// Attribute to mark alloca that have been given a lifetime marker so that
/// later pass do not try adding new ones.
static constexpr llvm::StringRef getHasLifetimeMarkerAttrName() {
return "fir.has_lifetime";
}

/// Does the function, \p func, have a host-associations tuple argument?
/// Some internal procedures may have access to host procedure variables.
bool hasHostAssociationArgument(mlir::func::FuncOp func);
Expand Down Expand Up @@ -221,6 +227,12 @@ inline bool hasBindcAttr(mlir::Operation *op) {
return hasProcedureAttr<fir::FortranProcedureFlagsEnum::bind_c>(op);
}

/// Get the allocation size of a given alloca if it has compile time constant
/// size.
std::optional<int64_t> getAllocaByteSize(fir::AllocaOp alloca,
const mlir::DataLayout &dl,
const fir::KindMapping &kindMap);

/// Return true, if \p rebox operation keeps the input array
/// continuous if it is initially continuous.
/// When \p checkWhole is false, then the checking is only done
Expand Down
4 changes: 3 additions & 1 deletion flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,9 @@ def StackArrays : Pass<"stack-arrays", "mlir::func::FuncOp"> {
Convert heap allocations for arrays, even those of unknown size, into stack
allocations.
}];
let dependentDialects = [ "fir::FIROpsDialect" ];
let dependentDialects = [
"fir::FIROpsDialect", "mlir::DLTIDialect", "mlir::LLVM::LLVMDialect"
];
}

def StackReclaim : Pass<"stack-reclaim"> {
Expand Down
20 changes: 19 additions & 1 deletion flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1868,7 +1868,8 @@ void fir::factory::setInternalLinkage(mlir::func::FuncOp func) {
func->setAttr("llvm.linkage", linkage);
}

uint64_t fir::factory::getAllocaAddressSpace(mlir::DataLayout *dataLayout) {
uint64_t
fir::factory::getAllocaAddressSpace(const mlir::DataLayout *dataLayout) {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getAllocaMemorySpace())
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
Expand Down Expand Up @@ -1940,3 +1941,20 @@ void fir::factory::genDimInfoFromBox(
strides->push_back(dimInfo.getByteStride());
}
}

mlir::Value fir::factory::genLifetimeStart(mlir::OpBuilder &builder,
mlir::Location loc,
fir::AllocaOp alloc, int64_t size,
const mlir::DataLayout *dl) {
mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get(
alloc.getContext(), getAllocaAddressSpace(dl));
mlir::Value cast =
builder.create<fir::ConvertOp>(loc, ptrTy, alloc.getResult());
builder.create<mlir::LLVM::LifetimeStartOp>(loc, size, cast);
return cast;
}

void fir::factory::genLifetimeEnd(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value cast, int64_t size) {
builder.create<mlir::LLVM::LifetimeEndOp>(loc, size, cast);
}
13 changes: 13 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4804,6 +4804,19 @@ bool fir::reboxPreservesContinuity(fir::ReboxOp rebox, bool checkWhole) {
return false;
}

std::optional<int64_t> fir::getAllocaByteSize(fir::AllocaOp alloca,
const mlir::DataLayout &dl,
const fir::KindMapping &kindMap) {
mlir::Type type = alloca.getInType();
// TODO: should use the constant operands when all info is not available in
// the type.
if (!alloca.isDynamic())
if (auto sizeAndAlignment =
getTypeSizeAndAlignment(alloca.getLoc(), type, dl, kindMap))
return sizeAndAlignment->first;
return std::nullopt;
}

//===----------------------------------------------------------------------===//
// DeclareOp
//===----------------------------------------------------------------------===//
Expand Down
100 changes: 78 additions & 22 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -48,6 +51,11 @@ static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
"to 0 for no limit."),
llvm::cl::init(1000), llvm::cl::Hidden);

static llvm::cl::opt<bool> emitLifetimeMarkers(
"stack-arrays-lifetime",
llvm::cl::desc("Add lifetime markers to generated constant size allocas"),
llvm::cl::init(false), llvm::cl::Hidden);

namespace {

/// The state of an SSA value at each program point
Expand Down Expand Up @@ -189,8 +197,11 @@ class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
explicit AllocMemConversion(
mlir::MLIRContext *ctx,
const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
: OpRewritePattern(ctx), candidateOps{candidateOps} {}
const StackArraysAnalysisWrapper::AllocMemMap &candidateOps,
std::optional<mlir::DataLayout> &dl,
std::optional<fir::KindMapping> &kindMap)
: OpRewritePattern(ctx), candidateOps{candidateOps}, dl{dl},
kindMap{kindMap} {}

llvm::LogicalResult
matchAndRewrite(fir::AllocMemOp allocmem,
Expand All @@ -206,6 +217,9 @@ class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
/// Handle to the DFA (already run)
const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

const std::optional<mlir::DataLayout> &dl;
const std::optional<fir::KindMapping> &kindMap;

/// If we failed to find an insertion point not inside a loop, see if it would
/// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
static InsertionPoint findAllocaLoopInsertionPoint(
Expand All @@ -218,8 +232,12 @@ class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
mlir::PatternRewriter &rewriter) const;

/// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
void insertStackSaveRestore(fir::AllocMemOp oldAlloc,
mlir::PatternRewriter &rewriter) const;
/// Emit lifetime markers for newAlloc between oldAlloc and each freemem.
/// If the allocation is dynamic, no life markers are emitted.
void insertLifetimeMarkers(fir::AllocMemOp oldAlloc, fir::AllocaOp newAlloc,
mlir::PatternRewriter &rewriter) const;
};

class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
Expand Down Expand Up @@ -740,14 +758,34 @@ AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,

llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
oldAlloc.getTypeparams(),
oldAlloc.getShape());
auto alloca = rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
oldAlloc.getTypeparams(),
oldAlloc.getShape());
if (emitLifetimeMarkers)
insertLifetimeMarkers(oldAlloc, alloca, rewriter);

return alloca;
}

static void
visitFreeMemOp(fir::AllocMemOp oldAlloc,
const std::function<void(mlir::Operation *)> &callBack) {
for (mlir::Operation *user : oldAlloc->getUsers()) {
if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
for (mlir::Operation *user : declareOp->getUsers()) {
if (mlir::isa<fir::FreeMemOp>(user))
callBack(user);
}
}

if (mlir::isa<fir::FreeMemOp>(user))
callBack(user);
}
}

void AllocMemConversion::insertStackSaveRestore(
fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
auto oldPoint = rewriter.saveInsertionPoint();
fir::AllocMemOp oldAlloc, mlir::PatternRewriter &rewriter) const {
mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder{rewriter, mod};

Expand All @@ -758,21 +796,30 @@ void AllocMemConversion::insertStackSaveRestore(
builder.setInsertionPoint(user);
builder.genStackRestore(user->getLoc(), sp);
};
visitFreeMemOp(oldAlloc, createStackRestoreCall);
}

for (mlir::Operation *user : oldAlloc->getUsers()) {
if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
for (mlir::Operation *user : declareOp->getUsers()) {
if (mlir::isa<fir::FreeMemOp>(user))
createStackRestoreCall(user);
}
}

if (mlir::isa<fir::FreeMemOp>(user)) {
createStackRestoreCall(user);
}
void AllocMemConversion::insertLifetimeMarkers(
fir::AllocMemOp oldAlloc, fir::AllocaOp newAlloc,
mlir::PatternRewriter &rewriter) const {
if (!dl || !kindMap)
return;
llvm::StringRef attrName = fir::getHasLifetimeMarkerAttrName();
// Do not add lifetime markers if the alloca already has any.
if (newAlloc->hasAttr(attrName))
return;
if (std::optional<int64_t> size =
fir::getAllocaByteSize(newAlloc, *dl, *kindMap)) {
mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPoint(oldAlloc);
mlir::Value ptr = fir::factory::genLifetimeStart(
rewriter, newAlloc.getLoc(), newAlloc, *size, &*dl);
visitFreeMemOp(oldAlloc, [&](mlir::Operation *op) {
rewriter.setInsertionPoint(op);
fir::factory::genLifetimeEnd(rewriter, op->getLoc(), ptr, *size);
});
newAlloc->setAttr(attrName, rewriter.getUnitAttr());
}

rewriter.restoreInsertionPoint(oldPoint);
}

StackArraysPass::StackArraysPass(const StackArraysPass &pass)
Expand Down Expand Up @@ -809,7 +856,16 @@ void StackArraysPass::runOnOperation() {
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

patterns.insert<AllocMemConversion>(&context, *candidateOps);
auto module = func->getParentOfType<mlir::ModuleOp>();
std::optional<mlir::DataLayout> dl =
module ? fir::support::getOrSetMLIRDataLayout(
module, /*allowDefaultLayout=*/false)
: std::nullopt;
std::optional<fir::KindMapping> kindMap;
if (module)
kindMap = fir::getKindMapping(module);

patterns.insert<AllocMemConversion>(&context, *candidateOps, dl, kindMap);
if (mlir::failed(mlir::applyOpPatternsGreedily(
opsToConvert, std::move(patterns), config))) {
mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
Expand Down
96 changes: 96 additions & 0 deletions flang/test/Transforms/stack-arrays-lifetime.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Test insertion of llvm.lifetime for allocmem turn into alloca with constant size.
// RUN: fir-opt --stack-arrays -stack-arrays-lifetime %s | FileCheck %s

module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"} {

func.func @_QPcst_alloca(%arg0: !fir.ref<!fir.array<100000xf32>> {fir.bindc_name = "x"}) {
%c1 = arith.constant 1 : index
%c100000 = arith.constant 100000 : index
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.shape %c100000 : (index) -> !fir.shape<1>
%2 = fir.declare %arg0(%1) dummy_scope %0 {uniq_name = "_QFcst_allocaEx"} : (!fir.ref<!fir.array<100000xf32>>, !fir.shape<1>, !fir.dscope) -> !fir.ref<!fir.array<100000xf32>>
%3 = fir.allocmem !fir.array<100000xf32> {bindc_name = ".tmp.array", uniq_name = ""}
%4 = fir.declare %3(%1) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<100000xf32>>, !fir.shape<1>) -> !fir.heap<!fir.array<100000xf32>>
fir.do_loop %arg1 = %c1 to %c100000 step %c1 unordered {
%9 = fir.array_coor %2(%1) %arg1 : (!fir.ref<!fir.array<100000xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
%10 = fir.load %9 : !fir.ref<f32>
%11 = arith.addf %10, %10 fastmath<contract> : f32
%12 = fir.array_coor %4(%1) %arg1 : (!fir.heap<!fir.array<100000xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
fir.store %11 to %12 : !fir.ref<f32>
}
%5 = fir.convert %4 : (!fir.heap<!fir.array<100000xf32>>) -> !fir.ref<!fir.array<100000xf32>>
fir.call @_QPbar(%5) fastmath<contract> : (!fir.ref<!fir.array<100000xf32>>) -> ()
fir.freemem %4 : !fir.heap<!fir.array<100000xf32>>
%6 = fir.allocmem !fir.array<100000xi32> {bindc_name = ".tmp.array", uniq_name = ""}
%7 = fir.declare %6(%1) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<100000xi32>>, !fir.shape<1>) -> !fir.heap<!fir.array<100000xi32>>
fir.do_loop %arg1 = %c1 to %c100000 step %c1 unordered {
%9 = fir.array_coor %2(%1) %arg1 : (!fir.ref<!fir.array<100000xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
%10 = fir.load %9 : !fir.ref<f32>
%11 = fir.convert %10 : (f32) -> i32
%12 = fir.array_coor %7(%1) %arg1 : (!fir.heap<!fir.array<100000xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
fir.store %11 to %12 : !fir.ref<i32>
}
%8 = fir.convert %7 : (!fir.heap<!fir.array<100000xi32>>) -> !fir.ref<!fir.array<100000xi32>>
fir.call @_QPibar(%8) fastmath<contract> : (!fir.ref<!fir.array<100000xi32>>) -> ()
fir.freemem %7 : !fir.heap<!fir.array<100000xi32>>
return
}
// CHECK-LABEL: func.func @_QPcst_alloca(
// CHECK-DAG: %[[VAL_0:.*]] = fir.alloca !fir.array<100000xf32> {bindc_name = ".tmp.array", fir.has_lifetime}
// CHECK-DAG: %[[VAL_2:.*]] = fir.alloca !fir.array<100000xi32> {bindc_name = ".tmp.array", fir.has_lifetime}
// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<100000xf32>>) -> !llvm.ptr
// CHECK: llvm.intr.lifetime.start 400000, %[[VAL_9]] : !llvm.ptr
// CHECK: fir.do_loop
// CHECK: fir.call @_QPbar(
// CHECK: llvm.intr.lifetime.end 400000, %[[VAL_9]] : !llvm.ptr
// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<100000xi32>>) -> !llvm.ptr
// CHECK: llvm.intr.lifetime.start 400000, %[[VAL_17]] : !llvm.ptr
// CHECK: fir.do_loop
// CHECK: fir.call @_QPibar(
// CHECK: llvm.intr.lifetime.end 400000, %[[VAL_17]] : !llvm.ptr


func.func @_QPdyn_alloca(%arg0: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "x"}, %arg1: !fir.ref<i64> {fir.bindc_name = "n"}) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QFdyn_allocaEn"} : (!fir.ref<i64>, !fir.dscope) -> !fir.ref<i64>
%2 = fir.load %1 : !fir.ref<i64>
%3 = fir.convert %2 : (i64) -> index
%4 = arith.cmpi sgt, %3, %c0 : index
%5 = arith.select %4, %3, %c0 : index
%6 = fir.shape %5 : (index) -> !fir.shape<1>
%7 = fir.declare %arg0(%6) dummy_scope %0 {uniq_name = "_QFdyn_allocaEx"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> !fir.ref<!fir.array<?xf32>>
%8 = fir.allocmem !fir.array<?xf32>, %5 {bindc_name = ".tmp.array", uniq_name = ""}
%9 = fir.declare %8(%6) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.heap<!fir.array<?xf32>>
fir.do_loop %arg2 = %c1 to %5 step %c1 unordered {
%14 = fir.array_coor %7(%6) %arg2 : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
%15 = fir.load %14 : !fir.ref<f32>
%16 = arith.addf %15, %15 fastmath<contract> : f32
%17 = fir.array_coor %9(%6) %arg2 : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
fir.store %16 to %17 : !fir.ref<f32>
}
%10 = fir.convert %9 : (!fir.heap<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
fir.call @_QPbar(%10) fastmath<contract> : (!fir.ref<!fir.array<?xf32>>) -> ()
fir.freemem %9 : !fir.heap<!fir.array<?xf32>>
%11 = fir.allocmem !fir.array<?xf32>, %5 {bindc_name = ".tmp.array", uniq_name = ""}
%12 = fir.declare %11(%6) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.heap<!fir.array<?xf32>>
fir.do_loop %arg2 = %c1 to %5 step %c1 unordered {
%14 = fir.array_coor %7(%6) %arg2 : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
%15 = fir.load %14 : !fir.ref<f32>
%16 = arith.mulf %15, %15 fastmath<contract> : f32
%17 = fir.array_coor %12(%6) %arg2 : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
fir.store %16 to %17 : !fir.ref<f32>
}
%13 = fir.convert %12 : (!fir.heap<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
fir.call @_QPbar(%13) fastmath<contract> : (!fir.ref<!fir.array<?xf32>>) -> ()
fir.freemem %12 : !fir.heap<!fir.array<?xf32>>
return
}
// CHECK-LABEL: func.func @_QPdyn_alloca(
// CHECK-NOT: llvm.intr.lifetime.start
// CHECK: return

func.func private @_QPbar(!fir.ref<!fir.array<100000xf32>>)
func.func private @_QPibar(!fir.ref<!fir.array<100000xi32>>)
}
Loading