Skip to content

Commit 747dc89

Browse files
committed
pr: Refine switch cases
1 parent 4b16b7b commit 747dc89

File tree

3 files changed

+58
-123
lines changed

3 files changed

+58
-123
lines changed

lib/vast/Conversion/Parser/CleanUp.cpp

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,56 @@ namespace vast::conv {
5151
}
5252
};
5353

54-
struct RefineParsingSwitch : operation_conversion_pattern< hl::SwitchOp >
54+
// struct RefineParsingSwitch : operation_conversion_pattern< hl::SwitchOp >
55+
// {
56+
// using op_t = hl::SwitchOp;
57+
// using base = operation_conversion_pattern< op_t >;
58+
// using base::base;
59+
60+
// using adaptor_t = typename op_t::Adaptor;
61+
62+
// logical_result matchAndRewrite(
63+
// op_t op, adaptor_t adaptor, conversion_rewriter &rewriter
64+
// ) const override {
65+
// auto &cond = op.getCondRegion().front();
66+
// auto yield = terminator< hl::ValueYieldOp >::get(cond);
67+
68+
// // // rewriter.inlineBlockBefore(&cond, op);
69+
// // rewriter.create< pr::Sink >(
70+
// // op.getLoc(), pr::NoDataType::get(op.getContext()),
71+
// // yield.op()->getOperands()
72+
// // );
73+
// // rewriter.eraseOp(yield.op());
74+
// rewriter.eraseOp(op);
75+
// return mlir::success();
76+
// }
77+
78+
// static bool has_only_nonparse_cases(hl::SwitchOp op) {
79+
// for (auto &case_region : op.getCases()) {
80+
// if (!pr::is_noparse_region(&case_region)) {
81+
// return false;
82+
// }
83+
// }
84+
// return true;
85+
// }
86+
87+
// static void legalize(base_conversion_config &cfg) {
88+
// cfg.target.addDynamicallyLegalOp< op_t >([](op_t op) {
89+
// auto &cond = op.getCondRegion();
90+
// auto yield = terminator< hl::ValueYieldOp >::get(cond.front());
91+
// auto yielded = yield.op()->getOperand(0);
92+
// if (pr::is_maybedata(yielded) || pr::is_data(yielded)) {
93+
// return !has_only_nonparse_cases(op);
94+
// }
95+
96+
// return true;
97+
// });
98+
// }
99+
// };
100+
101+
template< typename op_t >
102+
struct RefineCase : operation_conversion_pattern< op_t >
55103
{
56-
using op_t = hl::SwitchOp;
57104
using base = operation_conversion_pattern< op_t >;
58105
using base::base;
59106

@@ -62,43 +109,22 @@ namespace vast::conv {
62109
logical_result matchAndRewrite(
63110
op_t op, adaptor_t adaptor, conversion_rewriter &rewriter
64111
) const override {
65-
auto &cond = op.getCondRegion().front();
66-
auto yield = terminator< hl::ValueYieldOp >::get(cond);
67-
// rewriter.inlineBlockBefore(&cond, op);
68-
rewriter.create< pr::Sink >(
69-
op.getLoc(), pr::NoDataType::get(op.getContext()), yield.op()->getOperands()
70-
);
71-
rewriter.eraseOp(yield.op());
72112
rewriter.eraseOp(op);
73113
return mlir::success();
74114
}
75115

76-
static bool has_only_nonparse_cases(hl::SwitchOp op) {
77-
for (auto &case_region : op.getCases()) {
78-
if (!pr::is_noparse_region(&case_region)) {
79-
return false;
80-
}
81-
}
82-
return true;
83-
}
84-
85116
static void legalize(base_conversion_config &cfg) {
86117
cfg.target.addDynamicallyLegalOp< op_t >([](op_t op) {
87-
auto &cond = op.getCondRegion();
88-
auto yield = terminator< hl::ValueYieldOp >::get(cond.front());
89-
auto yielded = yield.op()->getOperand(0);
90-
if (pr::is_maybedata(yielded) || pr::is_data(yielded)) {
91-
return !has_only_nonparse_cases(op);
92-
}
93-
94-
return true;
118+
return !pr::is_noparse_op(op);
95119
});
96120
}
97121
};
98122

99123
// clang-format off
100124
using refines = util::type_list<
101-
EmptyDefaultOpElimination
125+
EmptyDefaultOpElimination,
126+
RefineCase< hl::CaseOp >,
127+
RefineCase< hl::DefaultOp >
102128
// RefineParsingSwitch
103129
>;
104130
// clang-format on

lib/vast/Conversion/Parser/Refine.cpp

Lines changed: 3 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -23,97 +23,6 @@ namespace vast::conv {
2323

2424
namespace pattern {
2525

26-
static bool is_nodata(mlir_type type) { return mlir::isa< pr::NoDataType >(type); }
27-
28-
static bool is_nodata(mlir_value value) { return is_nodata(value.getType()); }
29-
30-
static bool is_nodata(mlir::ValueRange values) {
31-
for (auto value : values) {
32-
if (!is_nodata(value)) {
33-
return false;
34-
}
35-
}
36-
return true;
37-
}
38-
39-
static bool is_data(mlir_type type) { return mlir::isa< pr::DataType >(type); }
40-
41-
static bool is_data(mlir_value value) { return is_data(value.getType()); }
42-
43-
static bool is_maybedata(mlir_type type) {
44-
return mlir::isa< pr::MaybeDataType >(type);
45-
}
46-
47-
static bool is_maybedata(mlir_value value) { return is_maybedata(value.getType()); }
48-
49-
static bool is_noparse_region(mlir::Region *region);
50-
51-
static bool is_noparse_op(mlir::Operation &op) {
52-
if (mlir::isa< pr::NoParse >(op)) {
53-
return true;
54-
}
55-
56-
if (mlir::isa< hl::NullStmt >(op)) {
57-
return true;
58-
}
59-
60-
if (mlir::isa< hl::BreakOp >(op)) {
61-
return true;
62-
}
63-
64-
if (mlir::isa< hl::ContinueOp >(op)) {
65-
return true;
66-
}
67-
68-
if (auto yield = mlir::dyn_cast< hl::CondYieldOp >(op)) {
69-
if (is_nodata(yield.getResult())) {
70-
return true;
71-
}
72-
}
73-
74-
if (auto yield = mlir::dyn_cast< hl::ValueYieldOp >(op)) {
75-
if (is_nodata(yield.getResult())) {
76-
return true;
77-
}
78-
}
79-
80-
if (auto ret = mlir::dyn_cast< hl::ReturnOp >(op)) {
81-
if (is_nodata(ret.getResult())) {
82-
return true;
83-
}
84-
}
85-
86-
if (auto call = mlir::dyn_cast< hl::CallOp >(op)) {
87-
return is_nodata(call.getArgOperands()) && is_nodata(call.getResults());
88-
}
89-
90-
if (auto d = mlir::dyn_cast< hl::DefaultOp >(op)) {
91-
return is_noparse_region(&d.getBody());
92-
}
93-
94-
if (auto c = mlir::dyn_cast< hl::CaseOp >(op)) {
95-
return is_noparse_region(&c.getBody()) && is_noparse_region(&c.getLhs());
96-
}
97-
98-
return false;
99-
}
100-
101-
static bool is_noparse_region(mlir::Region *region) {
102-
if (region->empty()) {
103-
return true;
104-
}
105-
106-
for (auto &block : *region) {
107-
for (auto &op : block) {
108-
if (!is_noparse_op(op)) {
109-
return false;
110-
}
111-
}
112-
}
113-
114-
return true;
115-
}
116-
11726
template< typename op_t >
11827
struct DefinitionElimination : erase_pattern< op_t >
11928
{
@@ -148,7 +57,7 @@ namespace vast::conv {
14857
static void legalize(base_conversion_config &cfg) {
14958
cfg.target.addDynamicallyLegalOp< op_t >([](op_t op) {
15059
for (auto region : op.getRegions()) {
151-
if (!is_noparse_region(region)) {
60+
if (!pr::is_noparse_region(region)) {
15261
return true;
15362
}
15463
}
@@ -184,10 +93,10 @@ namespace vast::conv {
18493

18594
static void legalize(base_conversion_config &cfg) {
18695
cfg.target.addDynamicallyLegalOp< hl::ReturnOp >([](hl::ReturnOp op) {
187-
if (is_maybedata(op->getOperand(0))) {
96+
if (pr::is_maybedata(op->getOperand(0))) {
18897
auto result = op->getOperand(0).getDefiningOp();
18998
if (auto cast = mlir::dyn_cast< pr::Cast >(result)) {
190-
if (is_nodata(cast.getOperand())) {
99+
if (pr::is_nodata(cast.getOperand())) {
191100
return false;
192101
}
193102
}

lib/vast/Conversion/Parser/Utils.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace vast::pr {
4646

4747
static bool is_noparse_region(mlir::Region *region);
4848

49-
static bool is_noparse_op(mlir::Operation &op) {
49+
static bool is_noparse_op(mlir::Operation *op) {
5050
if (mlir::isa< pr::NoParse >(op)) {
5151
return true;
5252
}
@@ -103,7 +103,7 @@ namespace vast::pr {
103103

104104
for (auto &block : *region) {
105105
for (auto &op : block) {
106-
if (!is_noparse_op(op)) {
106+
if (!is_noparse_op(&op)) {
107107
return false;
108108
}
109109
}

0 commit comments

Comments
 (0)