Skip to content

Commit de35293

Browse files
refusenoict-qlbuggfg
authored
[flang][lowering] Implement component-wise initialization for derived types (llvm#187465)
Currently, the compiler defaults to a full `memcpy` when initializing derived types. This patch introduces component-wise initialization for pointer / allocatable components, avoiding unnecessary initialization data generation and redundant copies. Ineligible cases continue to use the existing `memcpy` initialization path. RFC: https://discourse.llvm.org/t/rfc-automatic-static-promotion-of-large-local-variables-in-flang/89539 Key changes: - In `flang/lib/Lower/ConvertVariable.cpp`: - Add `genDerivedTypeComponentInit` for component-wise derived type initialization. - Add `isEligibleForComponentWiseInit` to guard the new initialization path. - Add `genInlinedInitWithMemcpy` to factor out the existing full `memcpy` initialization logic. - Update `defaultInitializeAtRuntime` to select the appropriate initialization path. - Add and update regression tests. --------- Co-authored-by: ict-ql <168183727+ict-ql@users.noreply.github.com> Co-authored-by: buggfg <wangyingying@bosc.ac.cn>
1 parent ceb18ff commit de35293

10 files changed

Lines changed: 399 additions & 71 deletions

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 176 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "flang/Optimizer/Dialect/FIRAttr.h"
3838
#include "flang/Optimizer/Dialect/FIRDialect.h"
3939
#include "flang/Optimizer/Dialect/FIROps.h"
40+
#include "flang/Optimizer/Dialect/FIRType.h"
4041
#include "flang/Optimizer/Dialect/MIF/MIFOps.h"
4142
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
4243
#include "flang/Optimizer/HLFIR/HLFIROps.h"
@@ -46,7 +47,9 @@
4647
#include "flang/Runtime/allocator-registry-consts.h"
4748
#include "flang/Semantics/runtime-type-info.h"
4849
#include "flang/Semantics/tools.h"
50+
#include "flang/Semantics/type.h"
4951
#include "mlir/Dialect/OpenACC/OpenACC.h"
52+
#include "llvm/ADT/SmallVector.h"
5053
#include "llvm/Support/CommandLine.h"
5154
#include "llvm/Support/Debug.h"
5255
#include <optional>
@@ -830,6 +833,164 @@ mustBeDefaultInitializedAtRuntime(const Fortran::lower::pft::Variable &var) {
830833
return Fortran::lower::hasDefaultInitialization(sym);
831834
}
832835

836+
/// Performs component-wise initialization for \p derivedSpec,
837+
/// selectively generating IR only for components that require it.
838+
static void genDerivedTypeComponentInit(
839+
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
840+
const Fortran::semantics::DerivedTypeSpec &derivedSpec,
841+
mlir::Value baseAddr, fir::RecordType recTy) {
842+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
843+
// Flattened DFS traversal to visit only the leaf components.
844+
Fortran::semantics::UltimateComponentIterator ultimateIter{derivedSpec};
845+
for (auto it = ultimateIter.begin(); it != ultimateIter.end(); ++it) {
846+
const Fortran::semantics::Symbol &comp = *it;
847+
const auto *objDetails =
848+
comp.detailsIf<Fortran::semantics::ObjectEntityDetails>();
849+
const auto *procDetails =
850+
comp.detailsIf<Fortran::semantics::ProcEntityDetails>();
851+
if ((!objDetails && !procDetails) ||
852+
!Fortran::semantics::IsAllocatableOrPointer(comp))
853+
continue;
854+
// Retrieve the nested symbol path from the root type to this ultimate
855+
// component.
856+
auto path = it.GetComponentPath();
857+
mlir::Value currentAddr = baseAddr;
858+
fir::RecordType currentRecTy = recTy;
859+
// Traverse the path and calculate memory coordinates.
860+
for (const Fortran::semantics::Symbol &pathSym : path) {
861+
const Fortran::semantics::Symbol *symPtr = &pathSym;
862+
// Generate coordinate for this nesting level.
863+
std::string name = converter.getRecordTypeFieldName(*symPtr);
864+
mlir::Type compFirTy = currentRecTy.getType(name);
865+
assert(compFirTy && "Component field type not found in RecordType");
866+
auto fieldIdx = fir::FieldIndexOp::create(
867+
builder, loc, fir::FieldType::get(currentRecTy.getContext()), name,
868+
currentRecTy, mlir::ValueRange{});
869+
currentAddr =
870+
fir::CoordinateOp::create(builder, loc, builder.getRefType(compFirTy),
871+
currentAddr, mlir::ValueRange{fieldIdx});
872+
currentRecTy = mlir::dyn_cast<fir::RecordType>(compFirTy);
873+
}
874+
mlir::Type finalCompFirTy = fir::unwrapPassByRefType(currentAddr.getType());
875+
mlir::Value initVal;
876+
if (objDetails) {
877+
initVal = fir::factory::createUnallocatedBox(builder, loc, finalCompFirTy,
878+
mlir::ValueRange{});
879+
} else {
880+
initVal = fir::factory::createNullBoxProc(builder, loc, finalCompFirTy);
881+
}
882+
fir::StoreOp::create(builder, loc, initVal, currentAddr);
883+
}
884+
}
885+
886+
/// Initializes a derived type via a bulk memory copy (memcpy).
887+
/// This method generates a global constant containing the default
888+
/// initialized state of the type, and copies it directly into the
889+
/// target memory location.
890+
static void
891+
genInlinedInitWithMemcpy(Fortran::lower::AbstractConverter &converter,
892+
const Fortran::semantics::Symbol &sym,
893+
mlir::Type symTy, const fir::ExtendedValue &exv,
894+
const Fortran::semantics::ObjectEntityDetails *details,
895+
const Fortran::semantics::DeclTypeSpec *declTy) {
896+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
897+
mlir::Location symLoc = genLocation(converter, sym);
898+
std::string globalName = fir::NameUniquer::doGenerated(
899+
(converter.mangleName(*declTy->AsDerived()) + fir::kNameSeparator +
900+
fir::kDerivedTypeInitSuffix)
901+
.str());
902+
mlir::StringAttr linkage = builder.createInternalLinkage();
903+
fir::GlobalOp global = builder.getNamedGlobal(globalName);
904+
if (!global && details->init()) {
905+
global = builder.createGlobal(symLoc, symTy, globalName, linkage,
906+
mlir::Attribute{},
907+
/*isConst=*/true,
908+
/*isTarget=*/false,
909+
/*dataAttr=*/{});
910+
createGlobalInitialization(
911+
builder, global, [&](fir::FirOpBuilder &builder) {
912+
Fortran::lower::StatementContext stmtCtx(
913+
/*cleanupProhibited=*/true);
914+
fir::ExtendedValue initVal = genInitializerExprValue(
915+
converter, symLoc, details->init().value(), stmtCtx);
916+
mlir::Value castTo =
917+
builder.createConvert(symLoc, symTy, fir::getBase(initVal));
918+
fir::HasValueOp::create(builder, symLoc, castTo);
919+
});
920+
} else if (!global) {
921+
global = builder.createGlobal(symLoc, symTy, globalName, linkage,
922+
mlir::Attribute{},
923+
/*isConst=*/true,
924+
/*isTarget=*/false,
925+
/*dataAttr=*/{});
926+
createGlobalInitialization(
927+
builder, global, [&](fir::FirOpBuilder &builder) {
928+
Fortran::lower::StatementContext stmtCtx(
929+
/*cleanupProhibited=*/true);
930+
mlir::Value initVal = genDefaultInitializerValue(converter, symLoc,
931+
sym, symTy, stmtCtx);
932+
mlir::Value castTo = builder.createConvert(symLoc, symTy, initVal);
933+
fir::HasValueOp::create(builder, symLoc, castTo);
934+
});
935+
}
936+
auto addrOf = fir::AddrOfOp::create(builder, symLoc, global.resultType(),
937+
global.getSymbol());
938+
fir::CopyOp::create(builder, symLoc, addrOf, fir::getBase(exv),
939+
/*noOverlap=*/true);
940+
}
941+
942+
/// Checks if a derived type is eligible for component-wise initialization.
943+
/// This is preferred over a bulk memcpy when the type contains at least
944+
/// one pointer/allocatable component, but no components with default
945+
/// initializers.
946+
static bool isEligibleForComponentWiseInit(
947+
const Fortran::semantics::DerivedTypeSpec &derivedSpec) {
948+
bool hasPtrOrAlloc = false;
949+
// Worklist for iterative traversal of the derived type tree.
950+
llvm::SmallVector<const Fortran::semantics::DerivedTypeSpec *> worklist;
951+
worklist.push_back(&derivedSpec);
952+
while (!worklist.empty()) {
953+
const Fortran::semantics::DerivedTypeSpec *currentSpec =
954+
worklist.pop_back_val();
955+
const Fortran::semantics::Scope *scope = currentSpec->scope();
956+
assert(scope && "derived type has no scope");
957+
const auto &typeDetails =
958+
currentSpec->typeSymbol().get<Fortran::semantics::DerivedTypeDetails>();
959+
for (const auto &compName : typeDetails.componentNames()) {
960+
auto scopeIter = scope->find(compName);
961+
assert(scopeIter != scope->cend() &&
962+
"component name must exist in its scope");
963+
const Fortran::semantics::Symbol &comp = scopeIter->second.get();
964+
// Return false if any component has a default initializer.
965+
const auto *objDetails =
966+
comp.detailsIf<Fortran::semantics::ObjectEntityDetails>();
967+
const auto *procDetails =
968+
comp.detailsIf<Fortran::semantics::ProcEntityDetails>();
969+
if ((objDetails && objDetails->init()) ||
970+
(procDetails && procDetails->init()))
971+
return false;
972+
if (Fortran::semantics::IsAllocatableOrPointer(comp)) {
973+
hasPtrOrAlloc = true;
974+
continue;
975+
}
976+
// Traverse nested derived types.
977+
if (const Fortran::semantics::DeclTypeSpec *declTy = comp.GetType()) {
978+
if (const Fortran::semantics::DerivedTypeSpec *nestedSpec =
979+
declTy->AsDerived()) {
980+
if (Fortran::lower::hasDefaultInitialization(comp)) {
981+
// Return false for arrays of derived types requiring
982+
// initialization.
983+
if (comp.Rank() > 0)
984+
return false;
985+
worklist.push_back(nestedSpec);
986+
}
987+
}
988+
}
989+
}
990+
}
991+
return hasPtrOrAlloc;
992+
}
993+
833994
/// Call default initialization runtime routine to initialize \p var.
834995
void Fortran::lower::defaultInitializeAtRuntime(
835996
Fortran::lower::AbstractConverter &converter,
@@ -857,56 +1018,27 @@ void Fortran::lower::defaultInitializeAtRuntime(
8571018
mlir::Type symTy = converter.genType(sym);
8581019
const auto *details =
8591020
sym.detailsIf<Fortran::semantics::ObjectEntityDetails>();
860-
if (details && !Fortran::semantics::IsPolymorphic(sym) &&
1021+
bool isDerivedTypeScalar =
1022+
details && !Fortran::semantics::IsPolymorphic(sym) &&
8611023
declTy->category() ==
8621024
Fortran::semantics::DeclTypeSpec::Category::TypeDerived &&
8631025
!mlir::isa<fir::SequenceType>(symTy) &&
8641026
!sym.test(Fortran::semantics::Symbol::Flag::OmpPrivate) &&
8651027
!sym.test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) &&
866-
!Fortran::semantics::HasCUDAComponent(sym)) {
867-
std::string globalName = fir::NameUniquer::doGenerated(
868-
(converter.mangleName(*declTy->AsDerived()) + fir::kNameSeparator +
869-
fir::kDerivedTypeInitSuffix)
870-
.str());
871-
mlir::Location loc = genLocation(converter, sym);
872-
mlir::StringAttr linkage = builder.createInternalLinkage();
873-
fir::GlobalOp global = builder.getNamedGlobal(globalName);
874-
if (!global && details->init()) {
875-
global = builder.createGlobal(loc, symTy, globalName, linkage,
876-
mlir::Attribute{},
877-
/*isConst=*/true,
878-
/*isTarget=*/false,
879-
/*dataAttr=*/{});
880-
createGlobalInitialization(
881-
builder, global, [&](fir::FirOpBuilder &builder) {
882-
Fortran::lower::StatementContext stmtCtx(
883-
/*cleanupProhibited=*/true);
884-
fir::ExtendedValue initVal = genInitializerExprValue(
885-
converter, loc, details->init().value(), stmtCtx);
886-
mlir::Value castTo =
887-
builder.createConvert(loc, symTy, fir::getBase(initVal));
888-
fir::HasValueOp::create(builder, loc, castTo);
889-
});
890-
} else if (!global) {
891-
global = builder.createGlobal(loc, symTy, globalName, linkage,
892-
mlir::Attribute{},
893-
/*isConst=*/true,
894-
/*isTarget=*/false,
895-
/*dataAttr=*/{});
896-
createGlobalInitialization(
897-
builder, global, [&](fir::FirOpBuilder &builder) {
898-
Fortran::lower::StatementContext stmtCtx(
899-
/*cleanupProhibited=*/true);
900-
mlir::Value initVal = genDefaultInitializerValue(
901-
converter, loc, sym, symTy, stmtCtx);
902-
mlir::Value castTo = builder.createConvert(loc, symTy, initVal);
903-
fir::HasValueOp::create(builder, loc, castTo);
904-
});
1028+
!Fortran::semantics::HasCUDAComponent(sym);
1029+
if (isDerivedTypeScalar) {
1030+
const auto &derivedSpec = *declTy->AsDerived();
1031+
if (isEligibleForComponentWiseInit(derivedSpec)) {
1032+
// Component-wise initialization.
1033+
mlir::Value baseAddr = fir::getBase(exv);
1034+
auto recTy = mlir::cast<fir::RecordType>(
1035+
fir::unwrapPassByRefType(baseAddr.getType()));
1036+
genDerivedTypeComponentInit(converter, loc, derivedSpec, baseAddr,
1037+
recTy);
1038+
} else {
1039+
// Initialize via bulk memory copy from a global constant.
1040+
genInlinedInitWithMemcpy(converter, sym, symTy, exv, details, declTy);
9051041
}
906-
auto addrOf = fir::AddrOfOp::create(builder, loc, global.resultType(),
907-
global.getSymbol());
908-
fir::CopyOp::create(builder, loc, addrOf, fir::getBase(exv),
909-
/*noOverlap=*/true);
9101042
} else {
9111043
mlir::Value box = builder.createBox(loc, exv);
9121044
fir::runtime::genDerivedTypeInitialize(builder, loc, box);

0 commit comments

Comments
 (0)