2222#include " axiom/sql/presto/PrestoParseError.h"
2323#include " axiom/sql/presto/ast/AstBuilder.h"
2424#include " axiom/sql/presto/ast/AstPrinter.h"
25+ #include " axiom/sql/presto/ast/DefaultTraversalVisitor.h"
2526#include " axiom/sql/presto/ast/UpperCaseInputStream.h"
2627#include " axiom/sql/presto/grammar/PrestoSqlLexer.h"
2728#include " axiom/sql/presto/grammar/PrestoSqlParser.h"
@@ -179,59 +180,20 @@ std::string canonicalizeIdentifier(const Identifier& identifier) {
179180 return canonicalizeName (identifier.value ());
180181}
181182
182- // Analizes the expression to find out whether there are any aggregate function
183+ // Analyzes the expression to find out whether there are any aggregate function
183184// calls and to verify that aggregate calls are not nested, e.g. sum(count(x))
184185// is not allowed.
185- class ExprAnalyzer : public AstVisitor {
186+ class ExprAnalyzer : public DefaultTraversalVisitor {
186187 public:
187188 bool hasAggregate () const {
188189 return numAggregates_ > 0 ;
189190 }
190191
191- private:
192- void defaultVisit (Node* node) override {
193- if (dynamic_cast <Literal*>(node) != nullptr ) {
194- // Literals have no function calls.
195- return ;
196- }
197-
198- VELOX_NYI (
199- " Not yet supported node type: {}" , NodeTypeName::toName (node->type ()));
200- }
201-
202- void visitArithmeticUnaryExpression (
203- ArithmeticUnaryExpression* node) override {
204- node->value ()->accept (this );
205- }
206-
207- void visitArrayConstructor (ArrayConstructor* node) override {
208- for (const auto & value : node->values ()) {
209- value->accept (this );
210- }
211- }
212-
213- void visitBetweenPredicate (BetweenPredicate* node) override {
214- node->value ()->accept (this );
215- node->min ()->accept (this );
216- node->max ()->accept (this );
217- }
218-
219- void visitCast (Cast* node) override {
220- node->expression ()->accept (this );
221- }
222-
223- void visitDereferenceExpression (DereferenceExpression* node) override {
224- node->base ()->accept (this );
225- }
226-
192+ protected:
227193 void visitExistsPredicate (ExistsPredicate* node) override {
228194 // Aggregate function calls within a subquery do not count.
229195 }
230196
231- void visitExtract (Extract* node) override {
232- node->expression ()->accept (this );
233- }
234-
235197 void visitFunctionCall (FunctionCall* node) override {
236198 const auto & name = node->name ()->suffix ();
237199 if (facebook::velox::exec::getAggregateFunctionEntry (name)) {
@@ -245,112 +207,16 @@ class ExprAnalyzer : public AstVisitor {
245207 ++numAggregates_;
246208 }
247209
248- for (const auto & arg : node->arguments ()) {
249- arg->accept (this );
250- }
210+ DefaultTraversalVisitor::visitFunctionCall (node);
251211
252212 aggregateName_.reset ();
253213 }
254214
255- void visitInListExpression (InListExpression* node) override {
256- for (const auto & value : node->values ()) {
257- value->accept (this );
258- }
259- }
260-
261- void visitInPredicate (InPredicate* node) override {
262- node->value ()->accept (this );
263- node->valueList ()->accept (this );
264- }
265-
266- void visitIsNullPredicate (IsNullPredicate* node) override {
267- node->value ()->accept (this );
268- }
269-
270- void visitIsNotNullPredicate (IsNotNullPredicate* node) override {
271- node->value ()->accept (this );
272- }
273-
274- void visitLambdaExpression (LambdaExpression* node) override {
275- node->body ()->accept (this );
276- }
277-
278- void visitNotExpression (NotExpression* node) override {
279- node->value ()->accept (this );
280- }
281-
282- void visitArithmeticBinaryExpression (
283- ArithmeticBinaryExpression* node) override {
284- node->left ()->accept (this );
285- node->right ()->accept (this );
286- }
287-
288- void visitLogicalBinaryExpression (LogicalBinaryExpression* node) override {
289- node->left ()->accept (this );
290- node->right ()->accept (this );
291- }
292-
293- void visitComparisonExpression (ComparisonExpression* node) override {
294- node->left ()->accept (this );
295- node->right ()->accept (this );
296- }
297-
298- void visitLikePredicate (LikePredicate* node) override {
299- node->value ()->accept (this );
300- node->pattern ()->accept (this );
301- if (node->escape () != nullptr ) {
302- node->escape ()->accept (this );
303- }
304- }
305-
306- void visitSimpleCaseExpression (SimpleCaseExpression* node) override {
307- node->operand ()->accept (this );
308-
309- for (const auto & clause : node->whenClauses ()) {
310- clause->operand ()->accept (this );
311- clause->result ()->accept (this );
312- }
313-
314- if (node->defaultValue ()) {
315- node->defaultValue ()->accept (this );
316- }
317- }
318-
319- void visitSearchedCaseExpression (SearchedCaseExpression* node) override {
320- for (const auto & clause : node->whenClauses ()) {
321- clause->operand ()->accept (this );
322- clause->result ()->accept (this );
323- }
324-
325- if (node->defaultValue ()) {
326- node->defaultValue ()->accept (this );
327- }
328- }
329-
330215 void visitSubqueryExpression (SubqueryExpression* node) override {
331216 // Aggregate function calls within a subquery do not count.
332217 }
333218
334- void visitSubscriptExpression (SubscriptExpression* node) override {
335- node->base ()->accept (this );
336- node->index ()->accept (this );
337- }
338-
339- void visitIdentifier (Identifier* node) override {
340- // No function calls.
341- }
342-
343- void visitRow (Row* node) override {
344- for (const auto & item : node->items ()) {
345- item->accept (this );
346- }
347- }
348-
349- void visitAtTimeZone (AtTimeZone* node) override {
350- node->value ()->accept (this );
351- node->timeZone ()->accept (this );
352- }
353-
219+ private:
354220 size_t numAggregates_{0 };
355221 std::optional<std::string> aggregateName_;
356222};
0 commit comments