Skip to content

Commit 580e862

Browse files
committed
[OpenMP][Flang] Emit default declare mappers implicitly for derived types
This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types. This supports nested derived types as well.
1 parent c2d8c55 commit 580e862

File tree

4 files changed

+176
-19
lines changed

4 files changed

+176
-19
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,23 +1102,30 @@ void ClauseProcessor::processMapObjects(
11021102

11031103
auto getDefaultMapperID = [&](const omp::Object &object,
11041104
std::string &mapperIdName) {
1105-
if (!mlir::isa<mlir::omp::DeclareMapperOp>(
1106-
firOpBuilder.getRegion().getParentOp())) {
1107-
const semantics::DerivedTypeSpec *typeSpec = nullptr;
1108-
1109-
if (object.sym()->owner().IsDerivedType())
1110-
typeSpec = object.sym()->owner().derivedTypeSpec();
1111-
else if (object.sym()->GetType() &&
1112-
object.sym()->GetType()->category() ==
1113-
semantics::DeclTypeSpec::TypeDerived)
1114-
typeSpec = &object.sym()->GetType()->derivedTypeSpec();
1115-
1116-
if (typeSpec) {
1117-
mapperIdName = typeSpec->name().ToString() + ".omp.default.mapper";
1118-
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
1119-
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
1120-
}
1105+
const semantics::DerivedTypeSpec *typeSpec = nullptr;
1106+
1107+
if (object.sym()->GetType() && object.sym()->GetType()->category() ==
1108+
semantics::DeclTypeSpec::TypeDerived)
1109+
typeSpec = &object.sym()->GetType()->derivedTypeSpec();
1110+
else if (object.sym()->owner().IsDerivedType())
1111+
typeSpec = object.sym()->owner().derivedTypeSpec();
1112+
1113+
if (typeSpec) {
1114+
mapperIdName = typeSpec->name().ToString() + ".omp.default.mapper";
1115+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
1116+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
1117+
else
1118+
mapperIdName =
1119+
converter.mangleName(mapperIdName, *typeSpec->GetScope());
11211120
}
1121+
1122+
// Make sure we don't return a mapper to self
1123+
llvm::StringRef parentOpName;
1124+
if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
1125+
firOpBuilder.getRegion().getParentOp()))
1126+
parentOpName = declMapOp.getSymName();
1127+
if (mapperIdName == parentOpName)
1128+
mapperIdName = "";
11221129
};
11231130

11241131
// Create the mapper symbol from its name, if specified.

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2348,6 +2348,124 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23482348
queue, item, clauseOps);
23492349
}
23502350

2351+
static mlir::FlatSymbolRefAttr
2352+
genImplicitDefaultDeclareMapper(lower::AbstractConverter &converter,
2353+
mlir::Location loc, fir::RecordType recordType,
2354+
llvm::StringRef mapperNameStr) {
2355+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2356+
lower::StatementContext stmtCtx;
2357+
2358+
// Save current insertion point before moving to the module scope to create
2359+
// the DeclareMapperOp
2360+
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
2361+
2362+
firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
2363+
auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>(
2364+
loc, mapperNameStr, recordType);
2365+
auto &region = declMapperOp.getRegion();
2366+
firOpBuilder.createBlock(&region);
2367+
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
2368+
2369+
auto declareOp =
2370+
firOpBuilder.create<hlfir::DeclareOp>(loc, mapperArg, /*uniq_name=*/"");
2371+
2372+
const auto genBoundsOps = [&](mlir::Value mapVal,
2373+
llvm::SmallVectorImpl<mlir::Value> &bounds) {
2374+
fir::ExtendedValue extVal =
2375+
hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
2376+
hlfir::Entity{mapVal},
2377+
/*contiguousHint=*/true)
2378+
.first;
2379+
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
2380+
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
2381+
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
2382+
mlir::omp::MapBoundsType>(
2383+
firOpBuilder, info, extVal,
2384+
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
2385+
};
2386+
2387+
// Return a reference to the contents of a derived type with one field.
2388+
// Also return the field type.
2389+
const auto getFieldRef =
2390+
[&](mlir::Value rec,
2391+
unsigned index) -> std::tuple<mlir::Value, mlir::Type> {
2392+
auto recType = mlir::dyn_cast<fir::RecordType>(
2393+
fir::unwrapPassByRefType(rec.getType()));
2394+
auto [fieldName, fieldTy] = recType.getTypeList()[index];
2395+
mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>(
2396+
loc, fir::FieldType::get(recType.getContext()), fieldName, recType,
2397+
fir::getTypeParams(rec));
2398+
return {firOpBuilder.create<fir::CoordinateOp>(
2399+
loc, firOpBuilder.getRefType(fieldTy), rec, field),
2400+
fieldTy};
2401+
};
2402+
2403+
mlir::omp::DeclareMapperInfoOperands clauseOps;
2404+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
2405+
llvm::SmallVector<mlir::Value> memberMapOps;
2406+
2407+
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2408+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
2409+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
2410+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
2411+
mlir::omp::VariableCaptureKind captureKind =
2412+
mlir::omp::VariableCaptureKind::ByRef;
2413+
int64_t index = 0;
2414+
2415+
// Populate the declareMapper region with the map information.
2416+
for (const auto &[memberName, memberType] :
2417+
mlir::dyn_cast<fir::RecordType>(recordType).getTypeList()) {
2418+
auto [ref, type] = getFieldRef(declareOp.getBase(), index);
2419+
mlir::FlatSymbolRefAttr mapperId;
2420+
if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
2421+
std::string mapperIdName =
2422+
recType.getName().str() + ".omp.default.mapper";
2423+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
2424+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2425+
else if (auto *sym = converter.getCurrentScope().FindSymbol(memberName))
2426+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2427+
2428+
if (converter.getModuleOp().lookupSymbol(mapperIdName))
2429+
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2430+
mapperIdName);
2431+
else
2432+
mapperId = genImplicitDefaultDeclareMapper(converter, loc, recType,
2433+
mapperIdName);
2434+
}
2435+
2436+
llvm::SmallVector<mlir::Value> bounds;
2437+
genBoundsOps(ref, bounds);
2438+
mlir::Value mapOp = createMapInfoOp(
2439+
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, "", bounds,
2440+
/*members=*/{},
2441+
/*membersIndex=*/mlir::ArrayAttr{},
2442+
static_cast<
2443+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2444+
mapFlag),
2445+
captureKind, ref.getType(), /*partialMap=*/false, mapperId);
2446+
memberMapOps.emplace_back(mapOp);
2447+
memberPlacementIndices.emplace_back(llvm::SmallVector<int64_t>{index++});
2448+
}
2449+
2450+
llvm::SmallVector<mlir::Value> bounds;
2451+
genBoundsOps(declareOp.getOriginalBase(), bounds);
2452+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
2453+
firOpBuilder, loc, declareOp.getOriginalBase(),
2454+
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
2455+
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices),
2456+
static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2457+
mapFlag),
2458+
captureKind, declareOp.getType(0),
2459+
/*partialMap=*/true);
2460+
2461+
clauseOps.mapVars.emplace_back(mapOp);
2462+
2463+
firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
2464+
// declMapperOp->dumpPretty();
2465+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2466+
mapperNameStr);
2467+
}
2468+
23512469
static mlir::omp::TargetOp
23522470
genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23532471
lower::StatementContext &stmtCtx,
@@ -2420,15 +2538,26 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24202538
name << sym.name().ToString();
24212539

24222540
mlir::FlatSymbolRefAttr mapperId;
2423-
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
2541+
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived &&
2542+
defaultMaps.empty()) {
24242543
auto &typeSpec = sym.GetType()->derivedTypeSpec();
24252544
std::string mapperIdName =
24262545
typeSpec.name().ToString() + ".omp.default.mapper";
24272546
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
24282547
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2548+
else
2549+
mapperIdName =
2550+
converter.mangleName(mapperIdName, *typeSpec.GetScope());
2551+
24292552
if (converter.getModuleOp().lookupSymbol(mapperIdName))
24302553
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
24312554
mapperIdName);
2555+
else
2556+
mapperId = genImplicitDefaultDeclareMapper(
2557+
converter, loc,
2558+
mlir::cast<fir::RecordType>(
2559+
converter.genType(sym.GetType()->derivedTypeSpec())),
2560+
mapperIdName);
24322561
}
24332562

24342563
fir::factory::AddrAndBoundsInfo info =
@@ -4039,6 +4168,7 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
40394168
ClauseProcessor cp(converter, semaCtx, clauses);
40404169
cp.processMap(loc, stmtCtx, clauseOps);
40414170
firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
4171+
// declMapperOp->dumpPretty();
40424172
}
40434173

40444174
static void

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ class MapInfoFinalizationPass
377377
getDescriptorMapType(mapType, target)),
378378
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
379379
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
380-
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
380+
op.getMapperIdAttr(), op.getNameAttr(),
381381
/*partial_map=*/builder.getBoolAttr(false));
382382
op.replaceAllUsesWith(newDescParentMapOp.getResult());
383383
op->erase();

flang/test/Lower/OpenMP/derived-type-map.f90

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
22

3+
!CHECK: omp.declare_mapper @[[MAPPER1:_QQFmaptype_derived_implicit_allocatablescalar_and_array.omp.default.mapper]] : !fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {
4+
!CHECK: omp.declare_mapper @[[MAPPER2:_QQFmaptype_derived_implicitscalar_and_array.omp.default.mapper]] : !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {
35

46
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicitEscalar_arr"}
57
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {uniq_name = "_QFmaptype_derived_implicitEscalar_arr"} : (!fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) -> (!fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>)
6-
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}
8+
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(implicit, tofrom) capture(ByRef) mapper(@[[MAPPER2]]) -> !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}
79
!CHECK: omp.target map_entries(%[[MAP]] -> %[[ARG0:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) {
810
subroutine mapType_derived_implicit
911
type :: scalar_and_array
@@ -18,6 +20,24 @@ subroutine mapType_derived_implicit
1820
!$omp end target
1921
end subroutine mapType_derived_implicit
2022

23+
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box<!fir.heap<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"}
24+
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"} : {{.*}}
25+
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : {{.*}}) map_clauses(implicit, to) capture(ByRef) mapper(@[[MAPPER1]])
26+
!CHECK: omp.target map_entries(%[[MAP]] -> %[[ARG0:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>>>, !fir.llvm_ptr<!fir.ref<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>>) {
27+
subroutine mapType_derived_implicit_allocatable
28+
type :: scalar_and_array
29+
real(4) :: real
30+
integer(4) :: array(10)
31+
integer(4) :: int
32+
end type scalar_and_array
33+
type(scalar_and_array), allocatable :: scalar_arr
34+
35+
allocate (scalar_arr)
36+
!$omp target
37+
scalar_arr%int = 1
38+
!$omp end target
39+
end subroutine mapType_derived_implicit_allocatable
40+
2141
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_explicitEscalar_arr"}
2242
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {uniq_name = "_QFmaptype_derived_explicitEscalar_arr"} : (!fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) -> (!fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>)
2343
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}

0 commit comments

Comments
 (0)