Skip to content

Commit dc79e2a

Browse files
authored
[flang] avoid introducing iteration dependencies in WHERE and FORALL temporaries (llvm#195053)
This patch improves the addressing of the temporaries that are created when needed for a simple FORALL or WHERE, such as the one below, so that it does not introduce iteration dependencies.

```
subroutine foo(p1, p2, mask)
  real, pointer :: p1(:), p2(:)
  logical :: mask(:)
  where (mask) p1 = p2
end subroutine
```

Instead of using a stack-like temporary that relies on a counter to push and fetch elements, the loop induction variables (IVs) are used directly to address the temporaries. This makes it easier to later vectorize or parallelize those loops.

This is only done when:
- This is not a FORALL with array expressions.
- The dynamic type is the same at each iteration.
- The WHERE and FORALL do not create loops of depth more than 15.
- If there are FORALLs, their strides are the constants 1 or -1.

Note that only the addressing is impacted: the stack-like approach already allocated a temporary big enough for all the iterations regardless of the masking, so the temporary size remains the same.

Assisted by: Claude
1 parent 94ca490 commit dc79e2a

9 files changed

Lines changed: 620 additions & 218 deletions

File tree

flang/include/flang/Optimizer/Builder/TemporaryStorage.h

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#ifndef FORTRAN_OPTIMIZER_BUILDER_TEMPORARYSTORAGE_H
2020
#define FORTRAN_OPTIMIZER_BUILDER_TEMPORARYSTORAGE_H
2121

22+
#include "flang/Common/idioms.h"
2223
#include "flang/Optimizer/HLFIR/HLFIROps.h"
2324

2425
namespace fir {
@@ -98,6 +99,34 @@ class HomogeneousScalarStack {
9899
mlir::Value temp;
99100
};
100101

102+
/// Multidimensional temporary indexed directly by the enclosing loop induction
103+
/// variables (innermost loop is the first dimension). The indices passed to
104+
/// pushValue/fetch are interpreted in the array's domain, which is described
105+
/// by a fir.shape_shift built from the loop extents and lower bounds. This
106+
/// avoids the loop-carried counter used by HomogeneousScalarStack, keeping
107+
/// loop iterations independent. Limited to Fortran::common::maxRank dimensions.
108+
class ArrayTemp {
109+
public:
110+
ArrayTemp(mlir::Location loc, fir::FirOpBuilder &builder,
111+
fir::SequenceType declaredType, llvm::ArrayRef<mlir::Value> extents,
112+
llvm::ArrayRef<mlir::Value> lowerBounds,
113+
llvm::ArrayRef<mlir::Value> lengths, bool allocateOnHeap,
114+
llvm::StringRef name);
115+
116+
void pushValue(mlir::Location loc, fir::FirOpBuilder &builder,
117+
mlir::Value value, mlir::ValueRange indices);
118+
void resetFetchPosition(mlir::Location loc, fir::FirOpBuilder &builder) {}
119+
mlir::Value fetch(mlir::Location loc, fir::FirOpBuilder &builder,
120+
mlir::ValueRange indices);
121+
void destroy(mlir::Location loc, fir::FirOpBuilder &builder);
122+
bool canBeFetchedAfterPush() const { return true; }
123+
124+
private:
125+
const bool allocateOnHeap;
126+
mlir::Value temp;
127+
llvm::SmallVector<mlir::Value> typeParams;
128+
};
129+
101130
/// Structure to hold the value of a single entity.
102131
class SimpleCopy {
103132
public:
@@ -255,16 +284,26 @@ class TemporaryStorage {
255284
TemporaryStorage(T &&impl) : impl{std::forward<T>(impl)} {}
256285

257286
void pushValue(mlir::Location loc, fir::FirOpBuilder &builder,
258-
mlir::Value value) {
259-
std::visit([&](auto &temp) { temp.pushValue(loc, builder, value); }, impl);
287+
mlir::Value value, mlir::ValueRange indices = {}) {
288+
// Only ArrayTemp uses the loop indices; other temps don't take them.
289+
std::visit(Fortran::common::visitors{
290+
[&](ArrayTemp &temp) {
291+
temp.pushValue(loc, builder, value, indices);
292+
},
293+
[&](auto &temp) { temp.pushValue(loc, builder, value); }},
294+
impl);
260295
}
261296
void resetFetchPosition(mlir::Location loc, fir::FirOpBuilder &builder) {
262297
std::visit([&](auto &temp) { temp.resetFetchPosition(loc, builder); },
263298
impl);
264299
}
265-
mlir::Value fetch(mlir::Location loc, fir::FirOpBuilder &builder) {
266-
return std::visit([&](auto &temp) { return temp.fetch(loc, builder); },
267-
impl);
300+
mlir::Value fetch(mlir::Location loc, fir::FirOpBuilder &builder,
301+
mlir::ValueRange indices = {}) {
302+
return std::visit(
303+
Fortran::common::visitors{
304+
[&](ArrayTemp &temp) { return temp.fetch(loc, builder, indices); },
305+
[&](auto &temp) { return temp.fetch(loc, builder); }},
306+
impl);
268307
}
269308
void destroy(mlir::Location loc, fir::FirOpBuilder &builder) {
270309
std::visit([&](auto &temp) { temp.destroy(loc, builder); }, impl);
@@ -282,8 +321,9 @@ class TemporaryStorage {
282321
}
283322

284323
private:
285-
std::variant<HomogeneousScalarStack, SimpleCopy, SSARegister, AnyValueStack,
286-
AnyVariableStack, AnyVectorSubscriptStack, AnyAddressStack>
324+
std::variant<HomogeneousScalarStack, ArrayTemp, SimpleCopy, SSARegister,
325+
AnyValueStack, AnyVariableStack, AnyVectorSubscriptStack,
326+
AnyAddressStack>
287327
impl;
288328
};
289329
} // namespace fir::factory

flang/lib/Optimizer/Builder/TemporaryStorage.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,88 @@ hlfir::Entity fir::factory::HomogeneousScalarStack::moveStackAsArrayExpr(
134134
return hlfir::Entity{hlfirExpr};
135135
}
136136

137+
//===----------------------------------------------------------------------===//
138+
// fir::factory::ArrayTemp implementation.
139+
//===----------------------------------------------------------------------===//
140+
141+
fir::factory::ArrayTemp::ArrayTemp(mlir::Location loc,
142+
fir::FirOpBuilder &builder,
143+
fir::SequenceType declaredType,
144+
llvm::ArrayRef<mlir::Value> extents,
145+
llvm::ArrayRef<mlir::Value> lowerBounds,
146+
llvm::ArrayRef<mlir::Value> lengths,
147+
bool allocateOnHeap, llvm::StringRef name)
148+
: allocateOnHeap{allocateOnHeap},
149+
typeParams{lengths.begin(), lengths.end()} {
150+
assert(extents.size() == lowerBounds.size() &&
151+
"extents and lowerBounds must have the same size");
152+
assert(extents.size() == declaredType.getDimension() &&
153+
"declared type rank must match the number of extents");
154+
mlir::Value tempStorage;
155+
if (allocateOnHeap)
156+
tempStorage =
157+
builder.createHeapTemporary(loc, declaredType, name, extents, lengths);
158+
else
159+
tempStorage =
160+
builder.createTemporary(loc, declaredType, name, extents, lengths);
161+
// Use a fir.shape_shift so the temp's lower bounds match the loop bounds:
162+
// the indices passed to pushValue/fetch can then index it directly.
163+
mlir::Value shape = builder.genShape(loc, lowerBounds, extents);
164+
temp =
165+
hlfir::DeclareOp::create(builder, loc, tempStorage, name, shape, lengths)
166+
.getBase();
167+
}
168+
169+
/// Generate an hlfir.designate on \p temp for the element at \p indices. The
170+
/// indices are interpreted in the temp's array domain (matching its lower
171+
/// bounds, which were set from the enclosing loop bounds).
172+
static mlir::Value genArrayTempElementAddr(mlir::Location loc,
173+
fir::FirOpBuilder &builder,
174+
mlir::Value temp,
175+
mlir::ValueRange indices,
176+
mlir::ValueRange typeParams) {
177+
hlfir::Entity entity{temp};
178+
mlir::Type refTy = fir::ReferenceType::get(entity.getFortranElementType());
179+
mlir::Type idxTy = builder.getIndexType();
180+
llvm::SmallVector<mlir::Value> idxs;
181+
idxs.reserve(indices.size());
182+
for (mlir::Value idx : indices)
183+
idxs.push_back(builder.createConvert(loc, idxTy, idx));
184+
return hlfir::DesignateOp::create(builder, loc, refTy, temp, idxs,
185+
typeParams);
186+
}
187+
188+
void fir::factory::ArrayTemp::pushValue(mlir::Location loc,
189+
fir::FirOpBuilder &builder,
190+
mlir::Value value,
191+
mlir::ValueRange indices) {
192+
hlfir::Entity entity{value};
193+
assert(entity.isScalar() && "cannot use ArrayTemp with array");
194+
// Match HomogeneousScalarStack: derived types go through the runtime path.
195+
if (!entity.hasIntrinsicType())
196+
TODO(loc, "creating ArrayTemp for derived types");
197+
mlir::Value addr =
198+
genArrayTempElementAddr(loc, builder, temp, indices, typeParams);
199+
hlfir::AssignOp::create(builder, loc, value, addr);
200+
}
201+
202+
mlir::Value fir::factory::ArrayTemp::fetch(mlir::Location loc,
203+
fir::FirOpBuilder &builder,
204+
mlir::ValueRange indices) {
205+
mlir::Value addr =
206+
genArrayTempElementAddr(loc, builder, temp, indices, typeParams);
207+
return hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{addr});
208+
}
209+
210+
void fir::factory::ArrayTemp::destroy(mlir::Location loc,
211+
fir::FirOpBuilder &builder) {
212+
if (allocateOnHeap) {
213+
auto declare = temp.getDefiningOp<hlfir::DeclareOp>();
214+
assert(declare && "temp must have been declared");
215+
fir::FreeMemOp::create(builder, loc, declare.getMemref());
216+
}
217+
}
218+
137219
//===----------------------------------------------------------------------===//
138220
// fir::factory::SimpleCopy implementation.
139221
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//===----------------------------------------------------------------------===//
1919

2020
#include "ScheduleOrderedAssignments.h"
21+
#include "flang/Common/Fortran-consts.h"
2122
#include "flang/Optimizer/Builder/FIRBuilder.h"
2223
#include "flang/Optimizer/Builder/HLFIRTools.h"
2324
#include "flang/Optimizer/Builder/TemporaryStorage.h"
@@ -257,6 +258,11 @@ class OrderedAssignmentRewriter {
257258
bool currentLoopNestIterationNumberCanBeComputed(
258259
llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest);
259260

261+
/// Return the induction variables of the enclosing fir.do_loop nest at the
262+
/// current insertion point, innermost first (same order as
263+
/// currentLoopNestIterationNumberCanBeComputed).
264+
llvm::SmallVector<mlir::Value> getLoopIndices();
265+
260266
template <typename T>
261267
fir::factory::TemporaryStorage *insertSavedEntity(mlir::Region &region,
262268
T &&temp) {
@@ -669,7 +675,8 @@ OrderedAssignmentRewriter::getIfSaved(mlir::Region &region) {
669675
// If the region was saved in a previous run, fetch the saved value.
670676
if (auto temp = savedEntities.find(&region); temp != savedEntities.end()) {
671677
doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
672-
return ValueAndCleanUp{temp->second.fetch(loc, builder), std::nullopt};
678+
return ValueAndCleanUp{temp->second.fetch(loc, builder, getLoopIndices()),
679+
std::nullopt};
673680
}
674681
return std::nullopt;
675682
}
@@ -1109,6 +1116,61 @@ computeLoopNestIterationNumber(mlir::Location loc, fir::FirOpBuilder &builder,
11091116
return loopExtent;
11101117
}
11111118

1119+
/// If \p value is a compile-time integer constant (possibly hidden behind
1120+
/// fir.convert ops), return its value. Otherwise return std::nullopt.
1121+
static std::optional<int64_t> unwrapConstantInt(mlir::Value value) {
1122+
while (auto convert = value.getDefiningOp<fir::ConvertOp>())
1123+
value = convert.getValue();
1124+
return fir::getIntIfConstant(value);
1125+
}
1126+
1127+
/// Compute the extents and lower bounds of \p loopNest, in the same order as
1128+
/// \p loopNest (innermost first). The lower bound of each dimension is the
1129+
/// smallest induction variable value, so that the loop induction variable
1130+
/// can directly index the temp via fir.shape_shift. This only works when
1131+
/// every loop has a unit step: for step +1 the smallest iv is the loop's
1132+
/// lower bound; for step -1 it is the loop's upper bound. Returns false
1133+
/// (with \p extents and \p lowerBounds left in an unspecified state) when
1134+
/// any loop has a non-unit or non-constant step, signalling that the caller
1135+
/// should fall back to a counter-based temp.
1136+
static bool computeLoopNestExtentsAndLowerBounds(
1137+
mlir::Location loc, fir::FirOpBuilder &builder,
1138+
llvm::ArrayRef<fir::DoLoopOp> loopNest,
1139+
llvm::SmallVectorImpl<mlir::Value> &extents,
1140+
llvm::SmallVectorImpl<mlir::Value> &lowerBounds) {
1141+
extents.reserve(loopNest.size());
1142+
lowerBounds.reserve(loopNest.size());
1143+
for (fir::DoLoopOp doLoop : loopNest) {
1144+
auto step = unwrapConstantInt(doLoop.getStep());
1145+
if (!step || std::abs(*step) != 1)
1146+
return false;
1147+
mlir::Value extent = builder.genExtentFromTriplet(
1148+
loc, doLoop.getLowerBound(), doLoop.getUpperBound(), doLoop.getStep(),
1149+
builder.getIndexType());
1150+
extents.push_back(extent);
1151+
lowerBounds.push_back(*step == 1 ? doLoop.getLowerBound()
1152+
: doLoop.getUpperBound());
1153+
}
1154+
return true;
1155+
}
1156+
1157+
llvm::SmallVector<mlir::Value> OrderedAssignmentRewriter::getLoopIndices() {
1158+
llvm::SmallVector<mlir::Value> indices;
1159+
if (constructStack.empty())
1160+
return indices;
1161+
mlir::Operation *outerLoop = constructStack[0];
1162+
mlir::Operation *currentConstruct = constructStack.back();
1163+
while (currentConstruct) {
1164+
if (auto doLoop = mlir::dyn_cast<fir::DoLoopOp>(currentConstruct))
1165+
indices.push_back(doLoop.getInductionVar());
1166+
if (currentConstruct == outerLoop)
1167+
currentConstruct = nullptr;
1168+
else
1169+
currentConstruct = currentConstruct->getParentOp();
1170+
}
1171+
return indices;
1172+
}
1173+
11121174
/// Return a name for temporary storage that indicates in which context
11131175
/// the temporary storage was created.
11141176
static llvm::StringRef
@@ -1160,11 +1222,27 @@ void OrderedAssignmentRewriter::generateSaveEntity(
11601222
bool loopShapeCanBePreComputed =
11611223
currentLoopNestIterationNumberCanBeComputed(loopNest);
11621224
doBeforeLoopNest([&] {
1163-
/// For simple scalars inside loops whose total iteration number can be
1164-
/// pre-computed, create a rank-1 array outside of the loops. It will be
1165-
/// assigned/fetched inside the loops like a normal Fortran array given
1166-
/// the iteration count.
1167-
if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) {
1225+
// For simple scalars in a precomputable loop nest, prefer the
1226+
// multidimensional ArrayTemp (indexed by loop induction variables) so
1227+
// there is no loop-carried counter. Fall back to the 1D counter-based
1228+
// HomogeneousScalarStack when the nest is deeper than the maximum
1229+
// fir.array rank or when any loop has a non-unit/non-constant step
1230+
// (in which case the loop induction variable cannot index the temp
1231+
// directly).
1232+
llvm::SmallVector<mlir::Value> tempExtents;
1233+
llvm::SmallVector<mlir::Value> tempLowerBounds;
1234+
if (loopShapeCanBePreComputed && fir::isa_trivial(entityType) &&
1235+
loopNest.size() <= static_cast<size_t>(Fortran::common::maxRank) &&
1236+
computeLoopNestExtentsAndLowerBounds(loc, builder, loopNest,
1237+
tempExtents, tempLowerBounds)) {
1238+
auto sequenceType = mlir::cast<fir::SequenceType>(
1239+
builder.getVarLenSeqTy(entityType, /*rank=*/loopNest.size()));
1240+
temp = insertSavedEntity(
1241+
region,
1242+
fir::factory::ArrayTemp{loc, builder, sequenceType, tempExtents,
1243+
tempLowerBounds,
1244+
/*lengths=*/{}, allocateOnHeap, tempName});
1245+
} else if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) {
11681246
mlir::Value loopExtent =
11691247
computeLoopNestIterationNumber(loc, builder, loopNest);
11701248
auto sequenceType =
@@ -1174,7 +1252,6 @@ void OrderedAssignmentRewriter::generateSaveEntity(
11741252
loc, builder, sequenceType, loopExtent,
11751253
/*lenParams=*/{}, allocateOnHeap,
11761254
/*stackThroughLoops=*/true, tempName});
1177-
11781255
} else {
11791256
// If the number of iteration is not known, or if the values at each
11801257
// iterations are values that may have different shape, type parameters
@@ -1185,8 +1262,8 @@ void OrderedAssignmentRewriter::generateSaveEntity(
11851262
}
11861263
});
11871264
// Inside the loop nest (and any fir.if if there are active masks), copy
1188-
// the value to the temp and do clean-ups for the value if any.
1189-
temp->pushValue(loc, builder, entity);
1265+
// the value to the temp and do clean-ups of the value if any.
1266+
temp->pushValue(loc, builder, entity, getLoopIndices());
11901267
}
11911268

11921269
// Delay the clean-up if the entity will be used in the same run (i.e., the
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
! Test that the lower-hlfir-ordered-assignments pass falls back to the
2+
! 1D HomogeneousScalarStack temporary (counter-based) when the FORALL loop
3+
! nest is deeper than Fortran::common::maxRank (15), because fir.array can
4+
! only hold up to maxRank dimensions.
5+
!
6+
! Up to maxRank loop levels, the new ArrayTemp is used and there is no counter; here we
7+
! verify the opposite: the counter (a fir.alloca index, fir.load/addi/store
8+
! pattern) is restored when the loop nest has 16 levels.
9+
!
10+
! The test uses a rank-8 array of derived type with a rank-8 array component
11+
! to spread 16 indexable dimensions across the FORALL header.
12+
!
13+
! RUN: bbc -emit-hlfir -o - %s | fir-opt --lower-hlfir-ordered-assignments | FileCheck %s
14+
15+
module many_forall_mod
16+
type :: t
17+
real :: c(2,2,2,2,2,2,2,2)
18+
end type
19+
contains
20+
subroutine more_than_15_forall(a)
21+
type(t), intent(inout) :: a(2,2,2,2,2,2,2,2)
22+
forall (i1=1:2, i2=1:2, i3=1:2, i4=1:2, i5=1:2, i6=1:2, i7=1:2, i8=1:2, &
23+
j1=1:2, j2=1:2, j3=1:2, j4=1:2, j5=1:2, j6=1:2, j7=1:2, j8=1:2)
24+
a(i1,i2,i3,i4,i5,i6,i7,i8)%c(j1,j2,j3,j4,j5,j6,j7,j8) = &
25+
a(3-i1,3-i2,3-i3,3-i4,3-i5,3-i6,3-i7,3-i8)%c(3-j1,3-j2,3-j3,3-j4,3-j5,3-j6,3-j7,3-j8)
26+
end forall
27+
end subroutine
28+
end module
29+
! With 16 nested loops, the temporary must be the 1D counter-based form
30+
! (HomogeneousScalarStack) instead of a 16D ArrayTemp, since fir.array is
31+
! limited to Fortran::common::maxRank dimensions.
32+
!
33+
! CHECK-LABEL: func.func @_QMmany_forall_modPmore_than_15_forall(
34+
! There must be a counter in memory (fir.alloca index).
35+
! CHECK: %[[CTR:.*]] = fir.alloca index
36+
! The temporary is a 1D fir.array<?xf32>.
37+
! CHECK: %[[ALLOC:.*]] = fir.allocmem !fir.array<?xf32>, %{{.*}} {bindc_name = ".tmp.forall", uniq_name = ""}
38+
! Plain fir.shape (no shift), since the temp is indexed by the counter.
39+
! CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
40+
! CHECK: hlfir.declare %[[ALLOC]](%[[SHAPE]]) {uniq_name = ".tmp.forall"} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.heap<!fir.array<?xf32>>)
41+
! Inside the loop nest the counter is incremented and the temp is indexed
42+
! through the counter (not directly through the loop induction variables).
43+
! CHECK: fir.load %[[CTR]] : !fir.ref<index>
44+
! CHECK: arith.addi %{{.*}}, %{{.*}} : index
45+
! CHECK: fir.store %{{.*}} to %[[CTR]] : !fir.ref<index>

0 commit comments

Comments
 (0)