@@ -301,17 +301,35 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
301
301
ReductionProcessor::ReductionIdentifier redId,
302
302
fir::BaseBoxType boxTy, mlir::Value lhs,
303
303
mlir::Value rhs) {
304
- fir::SequenceType seqTy =
305
- mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy ());
306
- // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
307
- if (!seqTy || seqTy.hasUnknownShape ())
304
+ fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
305
+ fir::unwrapRefType (boxTy.getEleTy ()));
306
+ fir::HeapType heapTy =
307
+ mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy ());
308
+ if ((!seqTy || seqTy.hasUnknownShape ()) && !heapTy)
308
309
TODO (loc, " Unsupported boxed type in OpenMP reduction" );
309
310
310
311
// load fir.ref<fir.box<...>>
311
312
mlir::Value lhsAddr = lhs;
312
313
lhs = builder.create <fir::LoadOp>(loc, lhs);
313
314
rhs = builder.create <fir::LoadOp>(loc, rhs);
314
315
316
+ if (heapTy && !seqTy) {
317
+ // get box contents (heap pointers)
318
+ lhs = builder.create <fir::BoxAddrOp>(loc, lhs);
319
+ rhs = builder.create <fir::BoxAddrOp>(loc, rhs);
320
+ mlir::Value lhsValAddr = lhs;
321
+
322
+ // load heap pointers
323
+ lhs = builder.create <fir::LoadOp>(loc, lhs);
324
+ rhs = builder.create <fir::LoadOp>(loc, rhs);
325
+
326
+ mlir::Value result = ReductionProcessor::createScalarCombiner (
327
+ builder, loc, redId, heapTy.getEleTy (), lhs, rhs);
328
+ builder.create <fir::StoreOp>(loc, result, lhsValAddr);
329
+ builder.create <mlir::omp::YieldOp>(loc, lhsAddr);
330
+ return ;
331
+ }
332
+
315
333
const unsigned rank = seqTy.getDimension ();
316
334
llvm::SmallVector<mlir::Value> extents;
317
335
extents.reserve (rank);
@@ -338,6 +356,10 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
338
356
339
357
// Iterate over array elements, applying the equivalent scalar reduction:
340
358
359
+ // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
360
+ // and so no null check is needed here before indexing into the (possibly
361
+ // allocatable) arrays.
362
+
341
363
// A hlfir::elemental here gets inlined with a temporary so create the
342
364
// loop nest directly.
343
365
// This function already controls all of the code in this region so we
@@ -412,9 +434,11 @@ createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
412
434
413
435
mlir::Type valTy = fir::unwrapRefType (redTy);
414
436
if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
415
- mlir::Type innerTy = fir::extractSequenceType (boxTy);
416
- if (!mlir::isa<fir::SequenceType>(innerTy))
417
- typeError ();
437
+ if (!mlir::isa<fir::HeapType>(boxTy.getEleTy ())) {
438
+ mlir::Type innerTy = fir::extractSequenceType (boxTy);
439
+ if (!mlir::isa<fir::SequenceType>(innerTy))
440
+ typeError ();
441
+ }
418
442
419
443
mlir::Value arg = block->getArgument (0 );
420
444
arg = builder.loadIfRef (loc, arg);
@@ -443,14 +467,27 @@ createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
443
467
typeError ();
444
468
}
445
469
470
+ // like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes
471
+ static mlir::Type unwrapSeqOrBoxedType (mlir::Type ty) {
472
+ if (auto seqTy = ty.dyn_cast <fir::SequenceType>())
473
+ return seqTy.getEleTy ();
474
+ if (auto boxTy = ty.dyn_cast <fir::BaseBoxType>()) {
475
+ auto eleTy = fir::unwrapRefType (boxTy.getEleTy ());
476
+ if (auto seqTy = eleTy.dyn_cast <fir::SequenceType>())
477
+ return seqTy.getEleTy ();
478
+ return eleTy;
479
+ }
480
+ return ty;
481
+ }
482
+
446
483
static mlir::Value
447
484
createReductionInitRegion (fir::FirOpBuilder &builder, mlir::Location loc,
448
485
mlir::omp::DeclareReductionOp &reductionDecl,
449
486
const ReductionProcessor::ReductionIdentifier redId,
450
487
mlir::Type type, bool isByRef) {
451
488
mlir::Type ty = fir::unwrapRefType (type);
452
489
mlir::Value initValue = ReductionProcessor::getReductionInitValue (
453
- loc, fir::unwrapSeqOrBoxedSeqType (ty), redId, builder);
490
+ loc, unwrapSeqOrBoxedType (ty), redId, builder);
454
491
455
492
if (fir::isa_trivial (ty)) {
456
493
if (isByRef) {
@@ -462,39 +499,99 @@ createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
462
499
return initValue;
463
500
}
464
501
502
+ // check if an allocatable box is unallocated. If so, initialize the boxAlloca
503
+ // to be unallocated e.g.
504
+ // %box_alloca = fir.alloca !fir.box<!fir.heap<...>>
505
+ // %addr = fir.box_addr %box
506
+ // if (%addr == 0) {
507
+ // %nullbox = fir.embox %addr
508
+ // fir.store %nullbox to %box_alloca
509
+ // } else {
510
+ // // ...
511
+ // fir.store %something to %box_alloca
512
+ // }
513
+ // omp.yield %box_alloca
514
+ mlir::Value blockArg =
515
+ builder.loadIfRef (loc, builder.getBlock ()->getArgument (0 ));
516
+ auto handleNullAllocatable = [&](mlir::Value boxAlloca) -> fir::IfOp {
517
+ mlir::Value addr = builder.create <fir::BoxAddrOp>(loc, blockArg);
518
+ mlir::Value isNotAllocated = builder.genIsNullAddr (loc, addr);
519
+ fir::IfOp ifOp = builder.create <fir::IfOp>(loc, isNotAllocated,
520
+ /* withElseRegion=*/ true );
521
+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
522
+ // just embox the null address and return
523
+ mlir::Value nullBox = builder.create <fir::EmboxOp>(loc, ty, addr);
524
+ builder.create <fir::StoreOp>(loc, nullBox, boxAlloca);
525
+ return ifOp;
526
+ };
527
+
465
528
// all arrays are boxed
466
529
if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
467
- assert (isByRef && " passing arrays by value is unsupported" );
468
- // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
469
- mlir::Type innerTy = fir::extractSequenceType (boxTy);
530
+ assert (isByRef && " passing boxes by value is unsupported" );
531
+ bool isAllocatable = mlir::isa<fir::HeapType>(boxTy.getEleTy ());
532
+ mlir::Value boxAlloca = builder.create <fir::AllocaOp>(loc, ty);
533
+ mlir::Type innerTy = fir::unwrapRefType (boxTy.getEleTy ());
534
+ if (fir::isa_trivial (innerTy)) {
535
+ // boxed non-sequence value e.g. !fir.box<!fir.heap<i32>>
536
+ if (!isAllocatable)
537
+ TODO (loc, " Reduction of non-allocatable trivial typed box" );
538
+
539
+ fir::IfOp ifUnallocated = handleNullAllocatable (boxAlloca);
540
+
541
+ builder.setInsertionPointToStart (&ifUnallocated.getElseRegion ().front ());
542
+ mlir::Value valAlloc = builder.create <fir::AllocMemOp>(loc, innerTy);
543
+ builder.createStoreWithConvert (loc, initValue, valAlloc);
544
+ mlir::Value box = builder.create <fir::EmboxOp>(loc, ty, valAlloc);
545
+ builder.create <fir::StoreOp>(loc, box, boxAlloca);
546
+
547
+ auto insPt = builder.saveInsertionPoint ();
548
+ createReductionCleanupRegion (builder, loc, reductionDecl);
549
+ builder.restoreInsertionPoint (insPt);
550
+ builder.setInsertionPointAfter (ifUnallocated);
551
+ return boxAlloca;
552
+ }
553
+ innerTy = fir::extractSequenceType (boxTy);
470
554
if (!mlir::isa<fir::SequenceType>(innerTy))
471
555
TODO (loc, " Unsupported boxed type for reduction" );
556
+
557
+ fir::IfOp ifUnallocated{nullptr };
558
+ if (isAllocatable) {
559
+ ifUnallocated = handleNullAllocatable (boxAlloca);
560
+ builder.setInsertionPointToStart (&ifUnallocated.getElseRegion ().front ());
561
+ }
562
+
472
563
// Create the private copy from the initial fir.box:
473
- hlfir::Entity source = hlfir::Entity{builder. getBlock ()-> getArgument ( 0 ) };
564
+ hlfir::Entity source = hlfir::Entity{blockArg };
474
565
475
566
// Allocating on the heap in case the whole reduction is nested inside of a
476
567
// loop
477
568
// TODO: compare performance here to using allocas - this could be made to
478
569
// work by inserting stacksave/stackrestore around the reduction in
479
570
// openmpirbuilder
480
571
auto [temp, needsDealloc] = createTempFromMold (loc, builder, source);
481
- // if needsDealloc isn't statically false, add cleanup region. TODO: always
572
+ // if needsDealloc isn't statically false, add cleanup region. Always
482
573
// do this for allocatable boxes because they might have been re-allocated
483
574
// in the body of the loop/parallel region
575
+
484
576
std::optional<int64_t > cstNeedsDealloc =
485
577
fir::getIntIfConstant (needsDealloc);
486
578
assert (cstNeedsDealloc.has_value () &&
487
579
" createTempFromMold decides this statically" );
488
580
if (cstNeedsDealloc.has_value () && *cstNeedsDealloc != false ) {
489
581
mlir::OpBuilder::InsertionGuard guard (builder);
490
582
createReductionCleanupRegion (builder, loc, reductionDecl);
583
+ } else {
584
+ assert (!isAllocatable && " Allocatable arrays must be heap allocated" );
491
585
}
492
586
493
587
// Put the temporary inside of a box:
494
588
hlfir::Entity box = hlfir::genVariableBox (loc, builder, temp);
495
- builder.create <hlfir::AssignOp>(loc, initValue, box);
496
- mlir::Value boxAlloca = builder.create <fir::AllocaOp>(loc, ty);
497
- builder.create <fir::StoreOp>(loc, box, boxAlloca);
589
+ // hlfir::genVariableBox removes fir.heap<> around the element type
590
+ mlir::Value convertedBox = builder.createConvert (loc, ty, box.getBase ());
591
+ builder.create <hlfir::AssignOp>(loc, initValue, convertedBox);
592
+ builder.create <fir::StoreOp>(loc, convertedBox, boxAlloca);
593
+ if (ifUnallocated)
594
+ builder.setInsertionPointAfter (ifUnallocated);
498
595
return boxAlloca;
499
596
}
500
597
0 commit comments