Skip to content

Commit 67313f1

Browse files
committed
Fix issues with some hlfir.declare shape arguments
1 parent d598958 commit 67313f1

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/BuiltinOps.h"
2323
#include "llvm/ADT/SetVector.h"
2424
#include "llvm/ADT/SmallVector.h"
25+
#include "llvm/ADT/TypeSwitch.h"
2526

2627
namespace flangomp {
2728
#define GEN_PASS_DEF_FUNCTIONFILTERINGPASS
@@ -443,16 +444,38 @@ class FunctionFilteringPass
443444
builder.getI64IntegerAttr(0));
444445

445446
if (auto shape = declareOp.getShape()) {
447+
// The pre-cg rewrite pass requires the shape to be defined by one of
448+
// fir.shape, fir.shapeshift or fir.shift, so we need to make sure it's
449+
// still defined by one of these after this pass.
446450
Operation *shapeOp = shape.getDefiningOp();
447-
unsigned numArgs = shapeOp->getNumOperands();
448-
if (isa<fir::ShapeShiftOp>(shapeOp))
449-
numArgs /= 2;
450-
451-
// Since the pre-cg rewrite pass requires the shape to be defined by one
452-
// of fir.shape, fir.shapeshift or fir.shift, we need to create one of
453-
// these.
454-
llvm::SmallVector<Value> extents(numArgs, zero);
455-
auto newShape = builder.create<fir::ShapeOp>(shape.getLoc(), extents);
451+
llvm::SmallVector<Value> extents(shapeOp->getNumOperands(), zero);
452+
Value newShape =
453+
llvm::TypeSwitch<Operation *, Value>(shapeOp)
454+
.Case([&](fir::ShapeOp op) {
455+
return builder.create<fir::ShapeOp>(op.getLoc(), extents);
456+
})
457+
.Case([&](fir::ShapeShiftOp op) {
458+
auto type = fir::ShapeShiftType::get(op.getContext(),
459+
extents.size() / 2);
460+
return builder.create<fir::ShapeShiftOp>(op.getLoc(), type,
461+
extents);
462+
})
463+
.Case([&](fir::ShiftOp op) {
464+
auto type =
465+
fir::ShiftType::get(op.getContext(), extents.size());
466+
return builder.create<fir::ShiftOp>(op.getLoc(), type,
467+
extents);
468+
})
469+
.Default([](Operation *op) {
470+
op->emitOpError()
471+
<< "hlfir.declare shape expected to be one of: "
472+
"fir.shape, fir.shapeshift or fir.shift";
473+
return nullptr;
474+
});
475+
476+
if (!newShape)
477+
return failure();
478+
456479
declareOp.getShapeMutable().assign(newShape);
457480
}
458481

flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ module attributes {omp.is_target_device = true} {
163163
// CHECK-SAME: (%[[X:.*]]: [[X_TYPE:[^)]*]])
164164
func.func @box_ptr(%x: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
165165
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i64
166-
// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[ZERO]] : (i64) -> !fir.shape<1>
166+
// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape_shift %[[ZERO]], %[[ZERO]] : (i64, i64) -> !fir.shapeshift<1>
167167
// CHECK-NEXT: %[[PLACEHOLDER_X:.*]] = fir.alloca i1
168168
// CHECK-NEXT: %[[ALLOCA_X:.*]] = fir.convert %[[PLACEHOLDER_X]] : (!fir.ref<i1>) -> [[X_TYPE]]
169169
%0 = fir.alloca !fir.box<!fir.ptr<!fir.array<?xi32>>>
@@ -187,7 +187,7 @@ module attributes {omp.is_target_device = true} {
187187
%9:3 = fir.box_dims %3, %c0_2 : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, index) -> (index, index, index)
188188
%10 = fir.shape_shift %9#0, %9#1 : (index, index) -> !fir.shapeshift<1>
189189

190-
// CHECK-NEXT: %[[Y_DECL:.*]]:2 = hlfir.declare %[[ALLOCA_Y]](%[[SHAPE]]) {fortran_attrs = #fir.var_attrs<target>, uniq_name = "y"} : ([[Y_TYPE]], !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, [[Y_TYPE]])
190+
// CHECK-NEXT: %[[Y_DECL:.*]]:2 = hlfir.declare %[[ALLOCA_Y]](%[[SHAPE]]) {fortran_attrs = #fir.var_attrs<target>, uniq_name = "y"} : ([[Y_TYPE]], !fir.shapeshift<1>) -> (!fir.box<!fir.array<?xi32>>, [[Y_TYPE]])
191191
%11:2 = hlfir.declare %8(%10) {fortran_attrs = #fir.var_attrs<target>, uniq_name = "y"} : (!fir.ptr<!fir.array<?xi32>>, !fir.shapeshift<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.ptr<!fir.array<?xi32>>)
192192
%c1_3 = arith.constant 1 : index
193193
%c0_4 = arith.constant 0 : index

0 commit comments

Comments
 (0)