Skip to content

Commit 289b2ae

Browse files
hdikemanfacebook-github-bot
authored andcommitted
refactor(parser): Pull common AST traversal logic into DefaultTraversalVisitor (facebookincubator#807)
Summary: This refactor was originally done as part of facebookincubator#804, but I am pulling out into a separate PR here Extracting common AST traversal logic into a common parent class which can be overridden by implementations which want to traverse the entire AST but only handle a specific subset of nodes. Reviewed By: mbasmanova Differential Revision: D91607843
1 parent 59b3e2a commit 289b2ae

File tree

2 files changed

+888
-140
lines changed

2 files changed

+888
-140
lines changed

axiom/sql/presto/PrestoParser.cpp

Lines changed: 6 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "axiom/sql/presto/PrestoParseError.h"
2323
#include "axiom/sql/presto/ast/AstBuilder.h"
2424
#include "axiom/sql/presto/ast/AstPrinter.h"
25+
#include "axiom/sql/presto/ast/DefaultTraversalVisitor.h"
2526
#include "axiom/sql/presto/ast/UpperCaseInputStream.h"
2627
#include "axiom/sql/presto/grammar/PrestoSqlLexer.h"
2728
#include "axiom/sql/presto/grammar/PrestoSqlParser.h"
@@ -179,59 +180,20 @@ std::string canonicalizeIdentifier(const Identifier& identifier) {
179180
return canonicalizeName(identifier.value());
180181
}
181182

182-
// Analizes the expression to find out whether there are any aggregate function
183+
// Analyzes the expression to find out whether there are any aggregate function
183184
// calls and to verify that aggregate calls are not nested, e.g. sum(count(x))
184185
// is not allowed.
185-
class ExprAnalyzer : public AstVisitor {
186+
class ExprAnalyzer : public DefaultTraversalVisitor {
186187
public:
187188
bool hasAggregate() const {
188189
return numAggregates_ > 0;
189190
}
190191

191-
private:
192-
void defaultVisit(Node* node) override {
193-
if (dynamic_cast<Literal*>(node) != nullptr) {
194-
// Literals have no function calls.
195-
return;
196-
}
197-
198-
VELOX_NYI(
199-
"Not yet supported node type: {}", NodeTypeName::toName(node->type()));
200-
}
201-
202-
void visitArithmeticUnaryExpression(
203-
ArithmeticUnaryExpression* node) override {
204-
node->value()->accept(this);
205-
}
206-
207-
void visitArrayConstructor(ArrayConstructor* node) override {
208-
for (const auto& value : node->values()) {
209-
value->accept(this);
210-
}
211-
}
212-
213-
void visitBetweenPredicate(BetweenPredicate* node) override {
214-
node->value()->accept(this);
215-
node->min()->accept(this);
216-
node->max()->accept(this);
217-
}
218-
219-
void visitCast(Cast* node) override {
220-
node->expression()->accept(this);
221-
}
222-
223-
void visitDereferenceExpression(DereferenceExpression* node) override {
224-
node->base()->accept(this);
225-
}
226-
192+
protected:
227193
void visitExistsPredicate(ExistsPredicate* node) override {
228194
// Aggregate function calls within a subquery do not count.
229195
}
230196

231-
void visitExtract(Extract* node) override {
232-
node->expression()->accept(this);
233-
}
234-
235197
void visitFunctionCall(FunctionCall* node) override {
236198
const auto& name = node->name()->suffix();
237199
if (facebook::velox::exec::getAggregateFunctionEntry(name)) {
@@ -245,112 +207,16 @@ class ExprAnalyzer : public AstVisitor {
245207
++numAggregates_;
246208
}
247209

248-
for (const auto& arg : node->arguments()) {
249-
arg->accept(this);
250-
}
210+
DefaultTraversalVisitor::visitFunctionCall(node);
251211

252212
aggregateName_.reset();
253213
}
254214

255-
void visitInListExpression(InListExpression* node) override {
256-
for (const auto& value : node->values()) {
257-
value->accept(this);
258-
}
259-
}
260-
261-
void visitInPredicate(InPredicate* node) override {
262-
node->value()->accept(this);
263-
node->valueList()->accept(this);
264-
}
265-
266-
void visitIsNullPredicate(IsNullPredicate* node) override {
267-
node->value()->accept(this);
268-
}
269-
270-
void visitIsNotNullPredicate(IsNotNullPredicate* node) override {
271-
node->value()->accept(this);
272-
}
273-
274-
void visitLambdaExpression(LambdaExpression* node) override {
275-
node->body()->accept(this);
276-
}
277-
278-
void visitNotExpression(NotExpression* node) override {
279-
node->value()->accept(this);
280-
}
281-
282-
void visitArithmeticBinaryExpression(
283-
ArithmeticBinaryExpression* node) override {
284-
node->left()->accept(this);
285-
node->right()->accept(this);
286-
}
287-
288-
void visitLogicalBinaryExpression(LogicalBinaryExpression* node) override {
289-
node->left()->accept(this);
290-
node->right()->accept(this);
291-
}
292-
293-
void visitComparisonExpression(ComparisonExpression* node) override {
294-
node->left()->accept(this);
295-
node->right()->accept(this);
296-
}
297-
298-
void visitLikePredicate(LikePredicate* node) override {
299-
node->value()->accept(this);
300-
node->pattern()->accept(this);
301-
if (node->escape() != nullptr) {
302-
node->escape()->accept(this);
303-
}
304-
}
305-
306-
void visitSimpleCaseExpression(SimpleCaseExpression* node) override {
307-
node->operand()->accept(this);
308-
309-
for (const auto& clause : node->whenClauses()) {
310-
clause->operand()->accept(this);
311-
clause->result()->accept(this);
312-
}
313-
314-
if (node->defaultValue()) {
315-
node->defaultValue()->accept(this);
316-
}
317-
}
318-
319-
void visitSearchedCaseExpression(SearchedCaseExpression* node) override {
320-
for (const auto& clause : node->whenClauses()) {
321-
clause->operand()->accept(this);
322-
clause->result()->accept(this);
323-
}
324-
325-
if (node->defaultValue()) {
326-
node->defaultValue()->accept(this);
327-
}
328-
}
329-
330215
void visitSubqueryExpression(SubqueryExpression* node) override {
331216
// Aggregate function calls within a subquery do not count.
332217
}
333218

334-
void visitSubscriptExpression(SubscriptExpression* node) override {
335-
node->base()->accept(this);
336-
node->index()->accept(this);
337-
}
338-
339-
void visitIdentifier(Identifier* node) override {
340-
// No function calls.
341-
}
342-
343-
void visitRow(Row* node) override {
344-
for (const auto& item : node->items()) {
345-
item->accept(this);
346-
}
347-
}
348-
349-
void visitAtTimeZone(AtTimeZone* node) override {
350-
node->value()->accept(this);
351-
node->timeZone()->accept(this);
352-
}
353-
219+
private:
354220
size_t numAggregates_{0};
355221
std::optional<std::string> aggregateName_;
356222
};

0 commit comments

Comments
 (0)