Skip to content

Commit c1c0d55

Browse files
authored
[flang] Non-type-bound defined IO lowering for an array of derived type (#134667)
Update Non-type-bound IO lowering to call OutputDerivedType for an array of derived type (rather than OutputDescriptor).
1 parent 23c27f3 commit c1c0d55

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

Diff for: flang/lib/Lower/IO.cpp

+15-4
Original file line numberDiff line numberDiff line change
@@ -609,11 +609,22 @@ static void genNamelistIO(Fortran::lower::AbstractConverter &converter,
609609
ok = builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
610610
}
611611

612+
/// Is \p type a derived type or an array of derived type?
613+
static bool containsDerivedType(mlir::Type type) {
614+
mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(type));
615+
if (mlir::isa<fir::RecordType>(argTy))
616+
return true;
617+
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(argTy))
618+
if (mlir::isa<fir::RecordType>(seqTy.getEleTy()))
619+
return true;
620+
return false;
621+
}
622+
612623
/// Get the output function to call for a value of the given type.
613624
static mlir::func::FuncOp getOutputFunc(mlir::Location loc,
614625
fir::FirOpBuilder &builder,
615626
mlir::Type type, bool isFormatted) {
616-
if (mlir::isa<fir::RecordType>(fir::unwrapPassByRefType(type)))
627+
if (containsDerivedType(type))
617628
return fir::runtime::getIORuntimeFunc<mkIOKey(OutputDerivedType)>(loc,
618629
builder);
619630
if (!isFormatted)
@@ -710,7 +721,7 @@ static void genOutputItemList(
710721
if (mlir::isa<fir::BoxType>(argType)) {
711722
mlir::Value box = fir::getBase(converter.genExprBox(loc, *expr, stmtCtx));
712723
outputFuncArgs.push_back(builder.createConvert(loc, argType, box));
713-
if (mlir::isa<fir::RecordType>(fir::unwrapPassByRefType(itemTy)))
724+
if (containsDerivedType(itemTy))
714725
outputFuncArgs.push_back(getNonTbpDefinedIoTableAddr(converter));
715726
} else if (helper.isCharacterScalar(itemTy)) {
716727
fir::ExtendedValue exv = converter.genExprAddr(loc, expr, stmtCtx);
@@ -745,7 +756,7 @@ static void genOutputItemList(
745756
static mlir::func::FuncOp getInputFunc(mlir::Location loc,
746757
fir::FirOpBuilder &builder,
747758
mlir::Type type, bool isFormatted) {
748-
if (mlir::isa<fir::RecordType>(fir::unwrapPassByRefType(type)))
759+
if (containsDerivedType(type))
749760
return fir::runtime::getIORuntimeFunc<mkIOKey(InputDerivedType)>(loc,
750761
builder);
751762
if (!isFormatted)
@@ -817,7 +828,7 @@ createIoRuntimeCallForItem(Fortran::lower::AbstractConverter &converter,
817828
auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(box.getType());
818829
assert(boxTy && "must be previously emboxed");
819830
inputFuncArgs.push_back(builder.createConvert(loc, argType, box));
820-
if (mlir::isa<fir::RecordType>(fir::unwrapPassByRefType(boxTy)))
831+
if (containsDerivedType(boxTy))
821832
inputFuncArgs.push_back(getNonTbpDefinedIoTableAddr(converter));
822833
} else {
823834
mlir::Value itemAddr = fir::getBase(item);

Diff for: flang/test/Lower/io-derived-type.f90

+11
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ program p
101101
use m
102102
character*3 ccc(4)
103103
namelist /nnn/ jjj, ccc
104+
type(t) :: y(5)
104105

105106
! CHECK: fir.call @_QMmPtest1
106107
call test1
@@ -115,6 +116,16 @@ program p
115116
! CHECK: %[[V_100:[0-9]+]] = fir.convert %[[V_99]] : (!fir.ref<tuple<i64, !fir.ref<!fir.array<1xtuple<!fir.ref<none>, !fir.ref<none>, i32, i1>>>, i1>>) -> !fir.ref<none>
116117
! CHECK: %[[V_101:[0-9]+]] = fir.call @_FortranAioOutputDerivedType(%{{.*}}, %[[V_98]], %[[V_100]]) fastmath<contract> : (!fir.ref<i8>, !fir.box<none>, !fir.ref<none>) -> i1
117118
print *, 'main, should call wft: ', t(4)
119+
120+
! CHECK: %[[V_33:[0-9]+]] = fir.shape %c2{{.*}} : (index) -> !fir.shape<1>
121+
! CHECK: %[[V_34:[0-9]+]] = hlfir.designate %7#0 (%c2{{.*}}:%c3{{.*}}:%c1{{.*}}) shape %[[V_33]] : (!fir.ref<!fir.array<5x!fir.type<_QMmTt{n:i32}>>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<2x!fir.type<_QMmTt{n:i32}>>>
122+
! CHECK: %[[V_35:[0-9]+]] = fir.shape %c2{{.*}} : (index) -> !fir.shape<1>
123+
! CHECK: %[[V_36:[0-9]+]] = fir.embox %[[V_34]](%[[V_35]]) : (!fir.ref<!fir.array<2x!fir.type<_QMmTt{n:i32}>>>, !fir.shape<1>) -> !fir.box<!fir.array<2x!fir.type<_QMmTt{n:i32}>>>
124+
! CHECK: %[[V_37:[0-9]+]] = fir.convert %[[V_36]] : (!fir.box<!fir.array<2x!fir.type<_QMmTt{n:i32}>>>) -> !fir.box<none>
125+
! CHECK: %[[V_38:[0-9]+]] = fir.address_of(@_QQF.nonTbpDefinedIoTable) : !fir.ref<tuple<i64, !fir.ref<!fir.array<1xtuple<!fir.ref<none>, !fir.ref<none>, i32, i1>>>, i1>>
126+
! CHECK: %[[V_39:[0-9]+]] = fir.convert %[[V_38]] : (!fir.ref<tuple<i64, !fir.ref<!fir.array<1xtuple<!fir.ref<none>, !fir.ref<none>, i32, i1>>>, i1>>) -> !fir.ref<none>
127+
! CHECK: %[[V_40:[0-9]+]] = fir.call @_FortranAioOutputDerivedType(%{{.*}}, %[[V_37]], %[[V_39]]) fastmath<contract> : (!fir.ref<i8>, !fir.box<none>, !fir.ref<none>) -> i1
128+
print *, y(2:3)
118129
end
119130

120131
! CHECK: fir.global linkonce @_QQMmFtest1.nonTbpDefinedIoTable.list constant : !fir.array<1xtuple<!fir.ref<none>, !fir.ref<none>, i32, i1>>

0 commit comments

Comments
 (0)