2222#include " axiom/sql/presto/PrestoParseError.h"
2323#include " axiom/sql/presto/ast/AstBuilder.h"
2424#include " axiom/sql/presto/ast/AstPrinter.h"
25+ #include " axiom/sql/presto/ast/DefaultTraversalVisitor.h"
2526#include " axiom/sql/presto/ast/UpperCaseInputStream.h"
2627#include " axiom/sql/presto/grammar/PrestoSqlLexer.h"
2728#include " axiom/sql/presto/grammar/PrestoSqlParser.h"
@@ -179,59 +180,20 @@ std::string canonicalizeIdentifier(const Identifier& identifier) {
179180 return canonicalizeName (identifier.value ());
180181}
181182
182- // Analizes the expression to find out whether there are any aggregate function
183+ // Analyzes the expression to find out whether there are any aggregate function
183184// calls and to verify that aggregate calls are not nested, e.g. sum(count(x))
184185// is not allowed.
185- class ExprAnalyzer : public AstVisitor {
186+ class ExprAnalyzer : public DefaultTraversalVisitor {
186187 public:
187188 bool hasAggregate () const {
188189 return numAggregates_ > 0 ;
189190 }
190191
191- private:
192- void defaultVisit (Node* node) override {
193- if (dynamic_cast <Literal*>(node) != nullptr ) {
194- // Literals have no function calls.
195- return ;
196- }
197-
198- VELOX_NYI (
199- " Not yet supported node type: {}" , NodeTypeName::toName (node->type ()));
200- }
201-
202- void visitArithmeticUnaryExpression (
203- ArithmeticUnaryExpression* node) override {
204- node->value ()->accept (this );
205- }
206-
207- void visitArrayConstructor (ArrayConstructor* node) override {
208- for (const auto & value : node->values ()) {
209- value->accept (this );
210- }
211- }
212-
213- void visitBetweenPredicate (BetweenPredicate* node) override {
214- node->value ()->accept (this );
215- node->min ()->accept (this );
216- node->max ()->accept (this );
217- }
218-
219- void visitCast (Cast* node) override {
220- node->expression ()->accept (this );
221- }
222-
223- void visitDereferenceExpression (DereferenceExpression* node) override {
224- node->base ()->accept (this );
225- }
226-
192+ protected:
227193 void visitExistsPredicate (ExistsPredicate* node) override {
228194 // Aggregate function calls within a subquery do not count.
229195 }
230196
231- void visitExtract (Extract* node) override {
232- node->expression ()->accept (this );
233- }
234-
235197 void visitFunctionCall (FunctionCall* node) override {
236198 const auto & name = node->name ()->suffix ();
237199 if (facebook::velox::exec::getAggregateFunctionEntry (name)) {
@@ -245,112 +207,16 @@ class ExprAnalyzer : public AstVisitor {
245207 ++numAggregates_;
246208 }
247209
248- for (const auto & arg : node->arguments ()) {
249- arg->accept (this );
250- }
210+ DefaultTraversalVisitor::visitFunctionCall (node);
251211
252212 aggregateName_.reset ();
253213 }
254214
255- void visitInListExpression (InListExpression* node) override {
256- for (const auto & value : node->values ()) {
257- value->accept (this );
258- }
259- }
260-
261- void visitInPredicate (InPredicate* node) override {
262- node->value ()->accept (this );
263- node->valueList ()->accept (this );
264- }
265-
266- void visitIsNullPredicate (IsNullPredicate* node) override {
267- node->value ()->accept (this );
268- }
269-
270- void visitIsNotNullPredicate (IsNotNullPredicate* node) override {
271- node->value ()->accept (this );
272- }
273-
274- void visitLambdaExpression (LambdaExpression* node) override {
275- node->body ()->accept (this );
276- }
277-
278- void visitNotExpression (NotExpression* node) override {
279- node->value ()->accept (this );
280- }
281-
282- void visitArithmeticBinaryExpression (
283- ArithmeticBinaryExpression* node) override {
284- node->left ()->accept (this );
285- node->right ()->accept (this );
286- }
287-
288- void visitLogicalBinaryExpression (LogicalBinaryExpression* node) override {
289- node->left ()->accept (this );
290- node->right ()->accept (this );
291- }
292-
293- void visitComparisonExpression (ComparisonExpression* node) override {
294- node->left ()->accept (this );
295- node->right ()->accept (this );
296- }
297-
298- void visitLikePredicate (LikePredicate* node) override {
299- node->value ()->accept (this );
300- node->pattern ()->accept (this );
301- if (node->escape () != nullptr ) {
302- node->escape ()->accept (this );
303- }
304- }
305-
306- void visitSimpleCaseExpression (SimpleCaseExpression* node) override {
307- node->operand ()->accept (this );
308-
309- for (const auto & clause : node->whenClauses ()) {
310- clause->operand ()->accept (this );
311- clause->result ()->accept (this );
312- }
313-
314- if (node->defaultValue ()) {
315- node->defaultValue ()->accept (this );
316- }
317- }
318-
319- void visitSearchedCaseExpression (SearchedCaseExpression* node) override {
320- for (const auto & clause : node->whenClauses ()) {
321- clause->operand ()->accept (this );
322- clause->result ()->accept (this );
323- }
324-
325- if (node->defaultValue ()) {
326- node->defaultValue ()->accept (this );
327- }
328- }
329-
330215 void visitSubqueryExpression (SubqueryExpression* node) override {
331216 // Aggregate function calls within a subquery do not count.
332217 }
333218
334- void visitSubscriptExpression (SubscriptExpression* node) override {
335- node->base ()->accept (this );
336- node->index ()->accept (this );
337- }
338-
339- void visitIdentifier (Identifier* node) override {
340- // No function calls.
341- }
342-
343- void visitRow (Row* node) override {
344- for (const auto & item : node->items ()) {
345- item->accept (this );
346- }
347- }
348-
349- void visitAtTimeZone (AtTimeZone* node) override {
350- node->value ()->accept (this );
351- node->timeZone ()->accept (this );
352- }
353-
219+ private:
354220 size_t numAggregates_{0 };
355221 std::optional<std::string> aggregateName_;
356222};
@@ -2335,4 +2201,169 @@ SqlStatementPtr PrestoParser::doParse(
23352201 return doPlan (query, defaultConnectorId_, defaultSchema_, parseSql);
23362202}
23372203
2204+ namespace {
2205+
2206+ // Analyzes an expression to extract the fully-qualified names of any
2207+ // input or output tables or views in the expression. Table accesses
2208+ // inside CTEs are included, even if the CTE is never read from.
2209+ class TableVisitor : public DefaultTraversalVisitor {
2210+ public:
2211+ TableVisitor (
2212+ const std::string& defaultConnectorId,
2213+ const std::optional<std::string>& defaultSchema)
2214+ : defaultConnectorId_(defaultConnectorId),
2215+ defaultSchema_ (defaultSchema) {}
2216+
2217+ const std::unordered_set<std::string>& inputTables () const {
2218+ return inputTables_;
2219+ }
2220+
2221+ const std::optional<std::string>& outputTable () const {
2222+ return outputTable_;
2223+ }
2224+
2225+ protected:
2226+ void visitWithQuery (WithQuery* node) override {
2227+ // To cover the case where a CTE aliases an underlying
2228+ // table, e.g. 'WITH t AS (SELECT * FROM t)', we need to
2229+ // traverse the inner query before tracking the CTE alias.
2230+ DefaultTraversalVisitor::visitWithQuery (node);
2231+ ctes_.insert (node->name ()->value ());
2232+ }
2233+
2234+ void visitTable (Table* node) override {
2235+ if (isCteReference (*node->name ())) {
2236+ return ;
2237+ }
2238+ inputTables_.insert (constructTableName (*node->name ()));
2239+ DefaultTraversalVisitor::visitTable (node);
2240+ }
2241+
2242+ void visitInsert (Insert* node) override {
2243+ parseOutputTable (*node->target ());
2244+ DefaultTraversalVisitor::visitInsert (node);
2245+ }
2246+
2247+ void visitCreateTableAsSelect (CreateTableAsSelect* node) override {
2248+ parseOutputTable (*node->name ());
2249+ DefaultTraversalVisitor::visitCreateTableAsSelect (node);
2250+ }
2251+
2252+ void visitUpdate (Update* node) override {
2253+ parseOutputTable (*node->table ());
2254+ DefaultTraversalVisitor::visitUpdate (node);
2255+ }
2256+
2257+ void visitDelete (Delete* node) override {
2258+ parseOutputTable (*node->table ());
2259+ DefaultTraversalVisitor::visitDelete (node);
2260+ }
2261+
2262+ void visitCreateTable (CreateTable* node) override {
2263+ parseOutputTable (*node->name ());
2264+ DefaultTraversalVisitor::visitCreateTable (node);
2265+ }
2266+
2267+ void visitCreateView (CreateView* node) override {
2268+ parseOutputTable (*node->name ());
2269+ DefaultTraversalVisitor::visitCreateView (node);
2270+ }
2271+
2272+ void visitCreateMaterializedView (CreateMaterializedView* node) override {
2273+ parseOutputTable (*node->name ());
2274+ DefaultTraversalVisitor::visitCreateMaterializedView (node);
2275+ }
2276+
2277+ void visitDropTable (DropTable* node) override {
2278+ parseOutputTable (*node->tableName ());
2279+ DefaultTraversalVisitor::visitDropTable (node);
2280+ }
2281+
2282+ void visitDropView (DropView* node) override {
2283+ parseOutputTable (*node->viewName ());
2284+ DefaultTraversalVisitor::visitDropView (node);
2285+ }
2286+
2287+ void visitDropMaterializedView (DropMaterializedView* node) override {
2288+ parseOutputTable (*node->viewName ());
2289+ DefaultTraversalVisitor::visitDropMaterializedView (node);
2290+ }
2291+
2292+ private:
2293+ std::string constructTableName (const QualifiedName& name) {
2294+ const auto & parts = name.parts ();
2295+ VELOX_CHECK (!parts.empty (), " Table name cannot be empty" );
2296+ switch (parts.size ()) {
2297+ case 1 :
2298+ if (defaultSchema_.has_value ()) {
2299+ return fmt::format (
2300+ " {}.{}.{}" ,
2301+ defaultConnectorId_,
2302+ defaultSchema_.value (),
2303+ parts[0 ]);
2304+ }
2305+ return fmt::format (" {}.{}" , defaultConnectorId_, parts[0 ]);
2306+ case 2 :
2307+ return fmt::format (" {}.{}.{}" , defaultConnectorId_, parts[0 ], parts[1 ]);
2308+ case 3 :
2309+ return fmt::format (" {}.{}.{}" , parts[0 ], parts[1 ], parts[2 ]);
2310+ default :
2311+ VELOX_FAIL (
2312+ " Table name must have 1-3 components, '{}' has {}" ,
2313+ name.fullyQualifiedName (),
2314+ parts.size ());
2315+ }
2316+ }
2317+
2318+ void parseOutputTable (const QualifiedName& name) {
2319+ auto outputTable = constructTableName (name);
2320+ VELOX_CHECK (
2321+ !outputTable_.has_value (),
2322+ " Cannot perform write against multiple tables (found {} and {})" ,
2323+ outputTable_.value (),
2324+ outputTable);
2325+ outputTable_ = std::move (outputTable);
2326+ }
2327+
2328+ bool isCteReference (const QualifiedName& name) {
2329+ const auto & parts = name.parts ();
2330+ return parts.size () == 1 && ctes_.count (parts[0 ]) > 0 ;
2331+ }
2332+
2333+ const std::string& defaultConnectorId_;
2334+ const std::optional<std::string>& defaultSchema_;
2335+ std::unordered_set<std::string> ctes_;
2336+ std::unordered_set<std::string> inputTables_;
2337+ std::optional<std::string> outputTable_;
2338+ };
2339+
2340+ } // namespace
2341+
2342+ std::unordered_set<std::string> PrestoParser::getInputTables (
2343+ std::string_view sql) {
2344+ ParserHelper helper (sql);
2345+ auto * context = helper.parse ();
2346+
2347+ AstBuilder astBuilder (false );
2348+ auto statement =
2349+ std::any_cast<std::shared_ptr<Statement>>(astBuilder.visit (context));
2350+
2351+ TableVisitor visitor (defaultConnectorId_, defaultSchema_);
2352+ visitor.process (statement.get ());
2353+ return visitor.inputTables ();
2354+ }
2355+
2356+ std::optional<std::string> PrestoParser::getOutputTable (std::string_view sql) {
2357+ ParserHelper helper (sql);
2358+ auto * context = helper.parse ();
2359+
2360+ AstBuilder astBuilder (false );
2361+ auto statement =
2362+ std::any_cast<std::shared_ptr<Statement>>(astBuilder.visit (context));
2363+
2364+ TableVisitor visitor (defaultConnectorId_, defaultSchema_);
2365+ visitor.process (statement.get ());
2366+ return visitor.outputTable ();
2367+ }
2368+
23382369} // namespace axiom::sql::presto
0 commit comments