|
37 | 37 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
38 | 38 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
39 | 39 | #include "flang/Optimizer/Dialect/FIROps.h" |
| 40 | +#include "flang/Optimizer/Dialect/FIRType.h" |
40 | 41 | #include "flang/Optimizer/Dialect/MIF/MIFOps.h" |
41 | 42 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
42 | 43 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
|
46 | 47 | #include "flang/Runtime/allocator-registry-consts.h" |
47 | 48 | #include "flang/Semantics/runtime-type-info.h" |
48 | 49 | #include "flang/Semantics/tools.h" |
| 50 | +#include "flang/Semantics/type.h" |
49 | 51 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
| 52 | +#include "llvm/ADT/SmallVector.h" |
50 | 53 | #include "llvm/Support/CommandLine.h" |
51 | 54 | #include "llvm/Support/Debug.h" |
52 | 55 | #include <optional> |
@@ -830,6 +833,164 @@ mustBeDefaultInitializedAtRuntime(const Fortran::lower::pft::Variable &var) { |
830 | 833 | return Fortran::lower::hasDefaultInitialization(sym); |
831 | 834 | } |
832 | 835 |
|
| 836 | +/// Performs component-wise initialization for \p derivedSpec, |
| 837 | +/// selectively generating IR only for components that require it. |
| 838 | +static void genDerivedTypeComponentInit( |
| 839 | + Fortran::lower::AbstractConverter &converter, mlir::Location loc, |
| 840 | + const Fortran::semantics::DerivedTypeSpec &derivedSpec, |
| 841 | + mlir::Value baseAddr, fir::RecordType recTy) { |
| 842 | + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
| 843 | + // Flattened DFS traversal to visit only the leaf components. |
| 844 | + Fortran::semantics::UltimateComponentIterator ultimateIter{derivedSpec}; |
| 845 | + for (auto it = ultimateIter.begin(); it != ultimateIter.end(); ++it) { |
| 846 | + const Fortran::semantics::Symbol &comp = *it; |
| 847 | + const auto *objDetails = |
| 848 | + comp.detailsIf<Fortran::semantics::ObjectEntityDetails>(); |
| 849 | + const auto *procDetails = |
| 850 | + comp.detailsIf<Fortran::semantics::ProcEntityDetails>(); |
| 851 | + if ((!objDetails && !procDetails) || |
| 852 | + !Fortran::semantics::IsAllocatableOrPointer(comp)) |
| 853 | + continue; |
| 854 | + // Retrieve the nested symbol path from the root type to this ultimate |
| 855 | + // component. |
| 856 | + auto path = it.GetComponentPath(); |
| 857 | + mlir::Value currentAddr = baseAddr; |
| 858 | + fir::RecordType currentRecTy = recTy; |
| 859 | + // Traverse the path and calculate memory coordinates. |
| 860 | + for (const Fortran::semantics::Symbol &pathSym : path) { |
| 861 | + const Fortran::semantics::Symbol *symPtr = &pathSym; |
| 862 | + // Generate coordinate for this nesting level. |
| 863 | + std::string name = converter.getRecordTypeFieldName(*symPtr); |
| 864 | + mlir::Type compFirTy = currentRecTy.getType(name); |
| 865 | + assert(compFirTy && "Component field type not found in RecordType"); |
| 866 | + auto fieldIdx = fir::FieldIndexOp::create( |
| 867 | + builder, loc, fir::FieldType::get(currentRecTy.getContext()), name, |
| 868 | + currentRecTy, mlir::ValueRange{}); |
| 869 | + currentAddr = |
| 870 | + fir::CoordinateOp::create(builder, loc, builder.getRefType(compFirTy), |
| 871 | + currentAddr, mlir::ValueRange{fieldIdx}); |
| 872 | + currentRecTy = mlir::dyn_cast<fir::RecordType>(compFirTy); |
| 873 | + } |
| 874 | + mlir::Type finalCompFirTy = fir::unwrapPassByRefType(currentAddr.getType()); |
| 875 | + mlir::Value initVal; |
| 876 | + if (objDetails) { |
| 877 | + initVal = fir::factory::createUnallocatedBox(builder, loc, finalCompFirTy, |
| 878 | + mlir::ValueRange{}); |
| 879 | + } else { |
| 880 | + initVal = fir::factory::createNullBoxProc(builder, loc, finalCompFirTy); |
| 881 | + } |
| 882 | + fir::StoreOp::create(builder, loc, initVal, currentAddr); |
| 883 | + } |
| 884 | +} |
| 885 | + |
| 886 | +/// Initializes a derived type via a bulk memory copy (memcpy). |
| 887 | +/// This method generates a global constant containing the default |
| 888 | +/// initialized state of the type, and copies it directly into the |
| 889 | +/// target memory location. |
| 890 | +static void |
| 891 | +genInlinedInitWithMemcpy(Fortran::lower::AbstractConverter &converter, |
| 892 | + const Fortran::semantics::Symbol &sym, |
| 893 | + mlir::Type symTy, const fir::ExtendedValue &exv, |
| 894 | + const Fortran::semantics::ObjectEntityDetails *details, |
| 895 | + const Fortran::semantics::DeclTypeSpec *declTy) { |
| 896 | + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
| 897 | + mlir::Location symLoc = genLocation(converter, sym); |
| 898 | + std::string globalName = fir::NameUniquer::doGenerated( |
| 899 | + (converter.mangleName(*declTy->AsDerived()) + fir::kNameSeparator + |
| 900 | + fir::kDerivedTypeInitSuffix) |
| 901 | + .str()); |
| 902 | + mlir::StringAttr linkage = builder.createInternalLinkage(); |
| 903 | + fir::GlobalOp global = builder.getNamedGlobal(globalName); |
| 904 | + if (!global && details->init()) { |
| 905 | + global = builder.createGlobal(symLoc, symTy, globalName, linkage, |
| 906 | + mlir::Attribute{}, |
| 907 | + /*isConst=*/true, |
| 908 | + /*isTarget=*/false, |
| 909 | + /*dataAttr=*/{}); |
| 910 | + createGlobalInitialization( |
| 911 | + builder, global, [&](fir::FirOpBuilder &builder) { |
| 912 | + Fortran::lower::StatementContext stmtCtx( |
| 913 | + /*cleanupProhibited=*/true); |
| 914 | + fir::ExtendedValue initVal = genInitializerExprValue( |
| 915 | + converter, symLoc, details->init().value(), stmtCtx); |
| 916 | + mlir::Value castTo = |
| 917 | + builder.createConvert(symLoc, symTy, fir::getBase(initVal)); |
| 918 | + fir::HasValueOp::create(builder, symLoc, castTo); |
| 919 | + }); |
| 920 | + } else if (!global) { |
| 921 | + global = builder.createGlobal(symLoc, symTy, globalName, linkage, |
| 922 | + mlir::Attribute{}, |
| 923 | + /*isConst=*/true, |
| 924 | + /*isTarget=*/false, |
| 925 | + /*dataAttr=*/{}); |
| 926 | + createGlobalInitialization( |
| 927 | + builder, global, [&](fir::FirOpBuilder &builder) { |
| 928 | + Fortran::lower::StatementContext stmtCtx( |
| 929 | + /*cleanupProhibited=*/true); |
| 930 | + mlir::Value initVal = genDefaultInitializerValue(converter, symLoc, |
| 931 | + sym, symTy, stmtCtx); |
| 932 | + mlir::Value castTo = builder.createConvert(symLoc, symTy, initVal); |
| 933 | + fir::HasValueOp::create(builder, symLoc, castTo); |
| 934 | + }); |
| 935 | + } |
| 936 | + auto addrOf = fir::AddrOfOp::create(builder, symLoc, global.resultType(), |
| 937 | + global.getSymbol()); |
| 938 | + fir::CopyOp::create(builder, symLoc, addrOf, fir::getBase(exv), |
| 939 | + /*noOverlap=*/true); |
| 940 | +} |
| 941 | + |
| 942 | +/// Checks if a derived type is eligible for component-wise initialization. |
| 943 | +/// This is preferred over a bulk memcpy when the type contains at least |
| 944 | +/// one pointer/allocatable component, but no components with default |
| 945 | +/// initializers. |
| 946 | +static bool isEligibleForComponentWiseInit( |
| 947 | + const Fortran::semantics::DerivedTypeSpec &derivedSpec) { |
| 948 | + bool hasPtrOrAlloc = false; |
| 949 | + // Worklist for iterative traversal of the derived type tree. |
| 950 | + llvm::SmallVector<const Fortran::semantics::DerivedTypeSpec *> worklist; |
| 951 | + worklist.push_back(&derivedSpec); |
| 952 | + while (!worklist.empty()) { |
| 953 | + const Fortran::semantics::DerivedTypeSpec *currentSpec = |
| 954 | + worklist.pop_back_val(); |
| 955 | + const Fortran::semantics::Scope *scope = currentSpec->scope(); |
| 956 | + assert(scope && "derived type has no scope"); |
| 957 | + const auto &typeDetails = |
| 958 | + currentSpec->typeSymbol().get<Fortran::semantics::DerivedTypeDetails>(); |
| 959 | + for (const auto &compName : typeDetails.componentNames()) { |
| 960 | + auto scopeIter = scope->find(compName); |
| 961 | + assert(scopeIter != scope->cend() && |
| 962 | + "component name must exist in its scope"); |
| 963 | + const Fortran::semantics::Symbol &comp = scopeIter->second.get(); |
| 964 | + // Return false if any component has a default initializer. |
| 965 | + const auto *objDetails = |
| 966 | + comp.detailsIf<Fortran::semantics::ObjectEntityDetails>(); |
| 967 | + const auto *procDetails = |
| 968 | + comp.detailsIf<Fortran::semantics::ProcEntityDetails>(); |
| 969 | + if ((objDetails && objDetails->init()) || |
| 970 | + (procDetails && procDetails->init())) |
| 971 | + return false; |
| 972 | + if (Fortran::semantics::IsAllocatableOrPointer(comp)) { |
| 973 | + hasPtrOrAlloc = true; |
| 974 | + continue; |
| 975 | + } |
| 976 | + // Traverse nested derived types. |
| 977 | + if (const Fortran::semantics::DeclTypeSpec *declTy = comp.GetType()) { |
| 978 | + if (const Fortran::semantics::DerivedTypeSpec *nestedSpec = |
| 979 | + declTy->AsDerived()) { |
| 980 | + if (Fortran::lower::hasDefaultInitialization(comp)) { |
| 981 | + // Return false for arrays of derived types requiring |
| 982 | + // initialization. |
| 983 | + if (comp.Rank() > 0) |
| 984 | + return false; |
| 985 | + worklist.push_back(nestedSpec); |
| 986 | + } |
| 987 | + } |
| 988 | + } |
| 989 | + } |
| 990 | + } |
| 991 | + return hasPtrOrAlloc; |
| 992 | +} |
| 993 | + |
833 | 994 | /// Call default initialization runtime routine to initialize \p var. |
834 | 995 | void Fortran::lower::defaultInitializeAtRuntime( |
835 | 996 | Fortran::lower::AbstractConverter &converter, |
@@ -857,56 +1018,27 @@ void Fortran::lower::defaultInitializeAtRuntime( |
857 | 1018 | mlir::Type symTy = converter.genType(sym); |
858 | 1019 | const auto *details = |
859 | 1020 | sym.detailsIf<Fortran::semantics::ObjectEntityDetails>(); |
860 | | - if (details && !Fortran::semantics::IsPolymorphic(sym) && |
| 1021 | + bool isDerivedTypeScalar = |
| 1022 | + details && !Fortran::semantics::IsPolymorphic(sym) && |
861 | 1023 | declTy->category() == |
862 | 1024 | Fortran::semantics::DeclTypeSpec::Category::TypeDerived && |
863 | 1025 | !mlir::isa<fir::SequenceType>(symTy) && |
864 | 1026 | !sym.test(Fortran::semantics::Symbol::Flag::OmpPrivate) && |
865 | 1027 | !sym.test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) && |
866 | | - !Fortran::semantics::HasCUDAComponent(sym)) { |
867 | | - std::string globalName = fir::NameUniquer::doGenerated( |
868 | | - (converter.mangleName(*declTy->AsDerived()) + fir::kNameSeparator + |
869 | | - fir::kDerivedTypeInitSuffix) |
870 | | - .str()); |
871 | | - mlir::Location loc = genLocation(converter, sym); |
872 | | - mlir::StringAttr linkage = builder.createInternalLinkage(); |
873 | | - fir::GlobalOp global = builder.getNamedGlobal(globalName); |
874 | | - if (!global && details->init()) { |
875 | | - global = builder.createGlobal(loc, symTy, globalName, linkage, |
876 | | - mlir::Attribute{}, |
877 | | - /*isConst=*/true, |
878 | | - /*isTarget=*/false, |
879 | | - /*dataAttr=*/{}); |
880 | | - createGlobalInitialization( |
881 | | - builder, global, [&](fir::FirOpBuilder &builder) { |
882 | | - Fortran::lower::StatementContext stmtCtx( |
883 | | - /*cleanupProhibited=*/true); |
884 | | - fir::ExtendedValue initVal = genInitializerExprValue( |
885 | | - converter, loc, details->init().value(), stmtCtx); |
886 | | - mlir::Value castTo = |
887 | | - builder.createConvert(loc, symTy, fir::getBase(initVal)); |
888 | | - fir::HasValueOp::create(builder, loc, castTo); |
889 | | - }); |
890 | | - } else if (!global) { |
891 | | - global = builder.createGlobal(loc, symTy, globalName, linkage, |
892 | | - mlir::Attribute{}, |
893 | | - /*isConst=*/true, |
894 | | - /*isTarget=*/false, |
895 | | - /*dataAttr=*/{}); |
896 | | - createGlobalInitialization( |
897 | | - builder, global, [&](fir::FirOpBuilder &builder) { |
898 | | - Fortran::lower::StatementContext stmtCtx( |
899 | | - /*cleanupProhibited=*/true); |
900 | | - mlir::Value initVal = genDefaultInitializerValue( |
901 | | - converter, loc, sym, symTy, stmtCtx); |
902 | | - mlir::Value castTo = builder.createConvert(loc, symTy, initVal); |
903 | | - fir::HasValueOp::create(builder, loc, castTo); |
904 | | - }); |
| 1028 | + !Fortran::semantics::HasCUDAComponent(sym); |
| 1029 | + if (isDerivedTypeScalar) { |
| 1030 | + const auto &derivedSpec = *declTy->AsDerived(); |
| 1031 | + if (isEligibleForComponentWiseInit(derivedSpec)) { |
| 1032 | + // Component-wise initialization. |
| 1033 | + mlir::Value baseAddr = fir::getBase(exv); |
| 1034 | + auto recTy = mlir::cast<fir::RecordType>( |
| 1035 | + fir::unwrapPassByRefType(baseAddr.getType())); |
| 1036 | + genDerivedTypeComponentInit(converter, loc, derivedSpec, baseAddr, |
| 1037 | + recTy); |
| 1038 | + } else { |
| 1039 | + // Initialize via bulk memory copy from a global constant. |
| 1040 | + genInlinedInitWithMemcpy(converter, sym, symTy, exv, details, declTy); |
905 | 1041 | } |
906 | | - auto addrOf = fir::AddrOfOp::create(builder, loc, global.resultType(), |
907 | | - global.getSymbol()); |
908 | | - fir::CopyOp::create(builder, loc, addrOf, fir::getBase(exv), |
909 | | - /*noOverlap=*/true); |
910 | 1042 | } else { |
911 | 1043 | mlir::Value box = builder.createBox(loc, exv); |
912 | 1044 | fir::runtime::genDerivedTypeInitialize(builder, loc, box); |
|
0 commit comments