Skip to content

Commit a7dec68

Browse files
author
yanming
committed
[flang][fir] Add FIR structured control flow ops to SCF dialect pass.
1 parent c78e6bb commit a7dec68

File tree

6 files changed

+265
-0
lines changed

6 files changed

+265
-0
lines changed

flang/include/flang/Optimizer/Support/InitFIR.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
2626
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
2727
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
28+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
2829
#include "mlir/InitAllDialects.h"
2930
#include "mlir/Pass/Pass.h"
3031
#include "mlir/Pass/PassRegistry.h"
@@ -103,6 +104,7 @@ inline void registerMLIRPassesForFortranTools() {
103104
mlir::registerPrintOpStatsPass();
104105
mlir::registerInlinerPass();
105106
mlir::registerSCCPPass();
107+
mlir::registerSCFPasses();
106108
mlir::affine::registerAffineScalarReplacementPass();
107109
mlir::registerSymbolDCEPass();
108110
mlir::registerLocationSnapshotPass();

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ std::unique_ptr<mlir::Pass>
7272
createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
7373
std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
7474
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
75+
std::unique_ptr<mlir::Pass> createFIRToSCFPass();
7576
std::unique_ptr<mlir::Pass>
7677
createAddDebugInfoPass(fir::AddDebugInfoOptions options = {});
7778

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ def AffineDialectDemotion : Pass<"demote-affine", "::mlir::func::FuncOp"> {
7676
];
7777
}
7878

79+
// Declares the `fir-to-scf` pass. No operation type is given to `Pass<>`, and
// the pass object is obtained from `::fir::createFIRToSCFPass()`.
// NOTE(review): the implementation in FIRToSCF.cpp creates `arith` ops, but
// only the FIR and SCF dialects are listed as dependent dialects here --
// confirm the arith dialect is guaranteed to be loaded (e.g. as a dependency
// of the FIR dialect) or add it to the list.
def FIRToSCFPass : Pass<"fir-to-scf"> {
80+
let summary = "Convert FIR structured control flow ops to SCF dialect.";
81+
let description = [{
82+
Convert FIR structured control flow ops to SCF dialect.
83+
}];
84+
let constructor = "::fir::createFIRToSCFPass()";
85+
let dependentDialects = [
86+
"fir::FIROpsDialect", "mlir::scf::SCFDialect"
87+
];
88+
}
89+
7990
def AnnotateConstantOperands : Pass<"annotate-constant"> {
8091
let summary = "Annotate constant operands to all FIR operations";
8192
let description = [{

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_flang_library(FIRTransforms
1616
CUFComputeSharedMemoryOffsetsAndSize.cpp
1717
ArrayValueCopy.cpp
1818
ExternalNameConversion.cpp
19+
FIRToSCF.cpp
1920
MemoryUtils.cpp
2021
MemoryAllocation.cpp
2122
StackArrays.cpp
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//===-- FIRToSCF.cpp ------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Dialect/FIRDialect.h"
10+
#include "flang/Optimizer/Transforms/Passes.h"
11+
#include "mlir/Dialect/SCF/IR/SCF.h"
12+
#include "mlir/Transforms/DialectConversion.h"
13+
14+
namespace fir {
15+
#define GEN_PASS_DEF_FIRTOSCFPASS
16+
#include "flang/Optimizer/Transforms/Passes.h.inc"
17+
} // namespace fir
18+
19+
using namespace fir;
20+
using namespace mlir;
21+
22+
namespace {
23+
/// Pass class for `fir-to-scf`, derived from the `FIRToSCFPassBase` CRTP base
/// pulled in from Passes.h.inc. All of the work happens in runOnOperation(),
/// which is defined out of line further down in this file.
class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
24+
public:
25+
void runOnOperation() override;
26+
};
27+
} // namespace
28+
29+
/// Rewrite pattern that replaces a `fir.do_loop` with an `scf.for`.
///
/// The generated `scf.for` always runs from 0 to a computed trip count with
/// step 1. The original induction value is rebuilt at the top of the new body
/// as `low + i * step`, and the old body (minus its `fir.result` terminator)
/// is moved into the new loop.
///
/// NOTE(review): this struct is declared after the anonymous namespace above
/// has been closed, so it has external linkage -- consider moving it inside.
struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
30+
using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
31+
32+
/// Performs the rewrite described on the struct; always returns success().
LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp,
33+
PatternRewriter &rewriter) const override {
34+
auto loc = doLoopOp.getLoc();
35+
// True when the do loop has a final value; in that case an extra leading
// iter_arg, seeded with the lower bound, is added below.
bool hasFinalValue = doLoopOp.getFinalValue().has_value();
36+
37+
// Get the loop bounds, step and loop-carried operands from the DoLoopOp.
38+
auto low = doLoopOp.getLowerBound();
39+
auto high = doLoopOp.getUpperBound();
40+
assert(low && high && "must be a Value");
41+
auto step = doLoopOp.getStep();
42+
llvm::SmallVector<mlir::Value> iterArgs;
43+
if (hasFinalValue)
44+
iterArgs.push_back(low);
45+
iterArgs.append(doLoopOp.getIterOperands().begin(),
46+
doLoopOp.getIterOperands().end());
47+
48+
// Calculate the trip count as (high - low + step) / step using signed
// division. The `+ step` term presumably accounts for an inclusive upper
// bound -- confirm against the fir.do_loop semantics.
49+
auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
50+
auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
51+
auto tripCount = rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
52+
auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
53+
auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
54+
// The new loop counts 0, 1, ..., tripCount - 1.
auto scfForOp =
55+
rewriter.create<scf::ForOp>(loc, zero, tripCount, one, iterArgs);
56+
57+
auto &loopOps = doLoopOp.getBody()->getOperations();
58+
auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
59+
auto results = resultOp.getOperands();
60+
Block *loweredBody = scfForOp.getBody();
61+
62+
// Move every op of the old body except the last one (the fir.result
// terminator) to the front of the new loop body.
// NOTE(review): this splice and the induction-variable RAUW below mutate the
// IR directly instead of going through `rewriter` -- confirm this is
// acceptable under applyPartialConversion.
loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
63+
loopOps.begin(),
64+
std::prev(loopOps.end()));
65+
66+
// Rebuild the original induction value at the very top of the new body:
// iv = low + scfIV * step.
rewriter.setInsertionPointToStart(loweredBody);
67+
Value iv =
68+
rewriter.create<arith::MulIOp>(loc, scfForOp.getInductionVar(), step);
69+
iv = rewriter.create<arith::AddIOp>(loc, low, iv);
70+
71+
// When the old terminator carried operands, re-emit them as an scf.yield at
// the end of the new body.
if (!results.empty()) {
72+
rewriter.setInsertionPointToEnd(loweredBody);
73+
rewriter.create<scf::YieldOp>(resultOp->getLoc(), results);
74+
}
75+
// Redirect uses of the old induction variable and region iter_args to the
// new values. The leading scf iter_arg is skipped when it is the extra one
// added for the final value.
doLoopOp.getInductionVar().replaceAllUsesWith(iv);
76+
rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
77+
hasFinalValue
78+
? scfForOp.getRegionIterArgs().drop_front()
79+
: scfForOp.getRegionIterArgs());
80+
81+
// Copy the loop annotation, if present, from the do loop onto the scf.for op
// as its "loop_annotation" attribute.
82+
if (auto ann = doLoopOp.getLoopAnnotation())
83+
scfForOp->setAttr("loop_annotation", *ann);
84+
85+
rewriter.replaceOp(doLoopOp, scfForOp);
86+
return success();
87+
}
88+
};
89+
90+
/// Runs a partial dialect conversion over the current operation using
/// DoLoopConversion as the only pattern. `fir.do_loop` is marked illegal and
/// every other operation is reported as legal; the pass is marked as failed
/// if the conversion fails.
void FIRToSCFPass::runOnOperation() {
91+
RewritePatternSet patterns(&getContext());
92+
patterns.add<DoLoopConversion>(patterns.getContext());
93+
ConversionTarget target(getContext());
94+
target.addIllegalOp<fir::DoLoopOp>();
95+
// Anything not explicitly marked above is accepted as-is.
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
96+
if (failed(
97+
applyPartialConversion(getOperation(), target, std::move(patterns))))
98+
signalPassFailure();
99+
}
100+
101+
/// Factory function for the pass: returns a newly allocated FIRToSCFPass
/// owned by the returned unique_ptr.
std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() {
102+
return std::make_unique<FIRToSCFPass>();
103+
}

flang/test/Fir/FirToSCF/do-loop.fir

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// RUN: fir-opt %s --fir-to-scf | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @simple_loop(
4+
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
5+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
6+
// CHECK: %[[VAL_1:.*]] = arith.constant 100 : index
7+
// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
8+
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
9+
// CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] : index
10+
// CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
11+
// CHECK: %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_0]] : index
12+
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
13+
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
14+
// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] {
15+
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_0]] : index
16+
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_0]], %[[VAL_10]] : index
17+
// CHECK: %[[VAL_12:.*]] = fir.array_coor %[[ARG0]](%[[VAL_2]]) %[[VAL_11]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
18+
// CHECK: fir.store %[[VAL_3]] to %[[VAL_12]] : !fir.ref<i32>
19+
// CHECK: }
20+
// CHECK: return
21+
// CHECK: }
22+
// Test input: a fir.do_loop from 1 to 100 with step 1 and no iter_args. Each
// iteration stores the i32 constant 1 to the array element addressed by the
// induction variable.
func.func @simple_loop(%arg0: !fir.ref<!fir.array<100xi32>>) {
23+
%c1 = arith.constant 1 : index
24+
%c100 = arith.constant 100 : index
25+
%0 = fir.shape %c100 : (index) -> !fir.shape<1>
26+
%c1_i32 = arith.constant 1 : i32
27+
fir.do_loop %arg1 = %c1 to %c100 step %c1 {
28+
%1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
29+
fir.store %c1_i32 to %1 : !fir.ref<i32>
30+
}
31+
return
32+
}
33+
34+
// CHECK-LABEL: func.func @loop_with_negtive_step(
35+
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
36+
// CHECK: %[[VAL_0:.*]] = arith.constant 100 : index
37+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
38+
// CHECK: %[[VAL_2:.*]] = arith.constant -1 : index
39+
// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
40+
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
41+
// CHECK: %[[VAL_5:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] : index
42+
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_2]] : index
43+
// CHECK: %[[VAL_7:.*]] = arith.divsi %[[VAL_6]], %[[VAL_2]] : index
44+
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
45+
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
46+
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_9]] {
47+
// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_2]] : index
48+
// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_0]], %[[VAL_11]] : index
49+
// CHECK: %[[VAL_13:.*]] = fir.array_coor %[[ARG0]](%[[VAL_3]]) %[[VAL_12]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
50+
// CHECK: fir.store %[[VAL_4]] to %[[VAL_13]] : !fir.ref<i32>
51+
// CHECK: }
52+
// CHECK: return
53+
// CHECK: }
54+
// Test input: a fir.do_loop counting down from 100 to 1 with step -1 and no
// iter_args. Each iteration stores the i32 constant 1 to the array element
// addressed by the induction variable.
func.func @loop_with_negtive_step(%arg0: !fir.ref<!fir.array<100xi32>>) {
55+
%c100 = arith.constant 100 : index
56+
%c1 = arith.constant 1 : index
57+
%c-1 = arith.constant -1 : index
58+
%0 = fir.shape %c100 : (index) -> !fir.shape<1>
59+
%c1_i32 = arith.constant 1 : i32
60+
fir.do_loop %arg1 = %c100 to %c1 step %c-1 {
61+
%1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
62+
fir.store %c1_i32 to %1 : !fir.ref<i32>
63+
}
64+
return
65+
}
66+
67+
// CHECK-LABEL: func.func @loop_with_results(
68+
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
69+
// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<i32>) {
70+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
71+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32
72+
// CHECK: %[[VAL_2:.*]] = arith.constant 100 : index
73+
// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
74+
// CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_2]], %[[VAL_0]] : index
75+
// CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
76+
// CHECK: %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_0]] : index
77+
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
78+
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
79+
// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] iter_args(%[[VAL_11:.*]] = %[[VAL_1]]) -> (i32) {
80+
// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_0]] : index
81+
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_0]], %[[VAL_12]] : index
82+
// CHECK: %[[VAL_14:.*]] = fir.array_coor %[[ARG0]](%[[VAL_3]]) %[[VAL_13]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
83+
// CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_14]] : !fir.ref<i32>
84+
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_11]], %[[VAL_15]] : i32
85+
// CHECK: scf.yield %[[VAL_16]] : i32
86+
// CHECK: }
87+
// CHECK: fir.store %[[VAL_9]] to %[[ARG1]] : !fir.ref<i32>
88+
// CHECK: return
89+
// CHECK: }
90+
// Test input: a fir.do_loop from 1 to 100 with step 1 that carries one i32
// iter_arg, initialised to 0, to which each loaded array element is added.
// The loop result is stored to %arg1 after the loop.
func.func @loop_with_results(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
91+
%c1 = arith.constant 1 : index
92+
%c0_i32 = arith.constant 0 : i32
93+
%c100 = arith.constant 100 : index
94+
%0 = fir.shape %c100 : (index) -> !fir.shape<1>
95+
%1 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %c0_i32) -> (i32) {
96+
%2 = fir.array_coor %arg0(%0) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
97+
%3 = fir.load %2 : !fir.ref<i32>
98+
%4 = arith.addi %arg3, %3 : i32
99+
fir.result %4 : i32
100+
}
101+
fir.store %1 to %arg1 : !fir.ref<i32>
102+
return
103+
}
104+
105+
// Test case: a fir.do_loop with result types (index, i32). Its first result
// is the induction value advanced by one step (yielded as %arg2 + %c1) and
// its second is an i32 accumulator. The expectations below require the
// scf.for to carry both values and require the two stores after the loop to
// read results #0 and #1 of that same scf.for.
// CHECK-LABEL: func.func @loop_with_final_value(
106+
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
107+
// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<i32>) {
108+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
109+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32
110+
// CHECK: %[[VAL_2:.*]] = arith.constant 100 : index
111+
// CHECK: %[[VAL_3:.*]] = fir.alloca index
112+
// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
113+
// CHECK: %[[VAL_5:.*]] = arith.subi %[[VAL_2]], %[[VAL_0]] : index
114+
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_0]] : index
115+
// CHECK: %[[VAL_7:.*]] = arith.divsi %[[VAL_6]], %[[VAL_0]] : index
116+
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
117+
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
118+
// CHECK: %[[VAL_10:.*]]:2 = scf.for %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_1]]) -> (index, i32) {
119+
// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_11]], %[[VAL_0]] : index
120+
// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_0]], %[[VAL_14]] : index
121+
// CHECK: %[[VAL_16:.*]] = fir.array_coor %[[ARG0]](%[[VAL_4]]) %[[VAL_15]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
122+
// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
123+
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_0]] overflow<nsw> : index
124+
// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_13]], %[[VAL_17]] overflow<nsw> : i32
125+
// CHECK: scf.yield %[[VAL_18]], %[[VAL_19]] : index, i32
126+
// CHECK: }
127+
// CHECK: fir.store %[[VAL_10]]#0 to %[[VAL_3]] : !fir.ref<index>
128+
// CHECK: fir.store %[[VAL_10]]#1 to %[[ARG1]] : !fir.ref<i32>
129+
// CHECK: return
130+
// CHECK: }
131+
func.func @loop_with_final_value(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
132+
%c1 = arith.constant 1 : index
133+
%c0_i32 = arith.constant 0 : i32
134+
%c100 = arith.constant 100 : index
135+
%0 = fir.alloca index
136+
%1 = fir.shape %c100 : (index) -> !fir.shape<1>
137+
%2:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %c0_i32) -> (index, i32) {
138+
%3 = fir.array_coor %arg0(%1) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
139+
%4 = fir.load %3 : !fir.ref<i32>
140+
%5 = arith.addi %arg2, %c1 overflow<nsw> : index
141+
%6 = arith.addi %arg3, %4 overflow<nsw> : i32
142+
fir.result %5, %6 : index, i32
143+
}
144+
fir.store %2#0 to %0 : !fir.ref<index>
145+
fir.store %2#1 to %arg1 : !fir.ref<i32>
146+
return
147+
}

0 commit comments

Comments
 (0)