@@ -1341,10 +1341,10 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
1341
1341
1342
1342
// / Get the address of the type descriptor global variable that was created by
1343
1343
// / lowering for derived type \p recType.
1344
- mlir::Value getTypeDescriptor (mlir::ModuleOp mod,
1345
- mlir::ConversionPatternRewriter &rewriter,
1346
- mlir::Location loc ,
1347
- fir::RecordType recType) const {
1344
+ template < typename ModOpTy>
1345
+ mlir::Value
1346
+ getTypeDescriptor (ModOpTy mod, mlir::ConversionPatternRewriter &rewriter ,
1347
+ mlir::Location loc, fir::RecordType recType) const {
1348
1348
std::string name =
1349
1349
this ->options .typeDescriptorsRenamedForAssembly
1350
1350
? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
@@ -1369,7 +1369,8 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
1369
1369
return rewriter.create <mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
1370
1370
}
1371
1371
1372
- mlir::Value populateDescriptor (mlir::Location loc, mlir::ModuleOp mod,
1372
+ template <typename ModOpTy>
1373
+ mlir::Value populateDescriptor (mlir::Location loc, ModOpTy mod,
1373
1374
fir::BaseBoxType boxTy, mlir::Type inputType,
1374
1375
mlir::ConversionPatternRewriter &rewriter,
1375
1376
unsigned rank, mlir::Value eleSize,
@@ -1508,10 +1509,16 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
1508
1509
extraField =
1509
1510
this ->getExtraFromBox (loc, sourceBoxTyPair, sourceBox, rewriter);
1510
1511
}
1511
- auto mod = box->template getParentOfType <mlir::ModuleOp>();
1512
- mlir::Value descriptor =
1513
- populateDescriptor (loc, mod, boxTy, inputType, rewriter, rank, eleSize,
1514
- cfiTy, typeDesc, allocatorIdx, extraField);
1512
+
1513
+ mlir::Value descriptor;
1514
+ if (auto gpuMod = box->template getParentOfType <mlir::gpu::GPUModuleOp>())
1515
+ descriptor = populateDescriptor (loc, gpuMod, boxTy, inputType, rewriter,
1516
+ rank, eleSize, cfiTy, typeDesc,
1517
+ allocatorIdx, extraField);
1518
+ else if (auto mod = box->template getParentOfType <mlir::ModuleOp>())
1519
+ descriptor = populateDescriptor (loc, mod, boxTy, inputType, rewriter,
1520
+ rank, eleSize, cfiTy, typeDesc,
1521
+ allocatorIdx, extraField);
1515
1522
1516
1523
return {boxTy, descriptor, eleSize};
1517
1524
}
@@ -1554,11 +1561,17 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
1554
1561
mlir::Value extraField =
1555
1562
this ->getExtraFromBox (loc, inputBoxTyPair, loweredBox, rewriter);
1556
1563
1557
- auto mod = box->template getParentOfType <mlir::ModuleOp>();
1558
- mlir::Value descriptor =
1559
- populateDescriptor (loc, mod, boxTy, box.getBox ().getType (), rewriter,
1560
- rank, eleSize, cfiTy, typeDesc,
1561
- /* allocatorIdx=*/ kDefaultAllocator , extraField);
1564
+ mlir::Value descriptor;
1565
+ if (auto gpuMod = box->template getParentOfType <mlir::gpu::GPUModuleOp>())
1566
+ descriptor =
1567
+ populateDescriptor (loc, gpuMod, boxTy, box.getBox ().getType (),
1568
+ rewriter, rank, eleSize, cfiTy, typeDesc,
1569
+ /* allocatorIdx=*/ kDefaultAllocator , extraField);
1570
+ else if (auto mod = box->template getParentOfType <mlir::ModuleOp>())
1571
+ descriptor =
1572
+ populateDescriptor (loc, mod, boxTy, box.getBox ().getType (), rewriter,
1573
+ rank, eleSize, cfiTy, typeDesc,
1574
+ /* allocatorIdx=*/ kDefaultAllocator , extraField);
1562
1575
1563
1576
return {boxTy, descriptor, eleSize};
1564
1577
}
0 commit comments