Skip to content

Commit 031664a

Browse files
hdikemanfacebook-github-bot
authored andcommitted
feat(parser): Add input/output table extraction to PrestoParser
Summary: There are use cases in which callers may want to extract some information from a query without needing to resolve all the metadata details required to build a full logical plan. An example could be a client-side check that decides where to send a query based on the tables it accesses, or moving ACL checks earlier in query execution by determining accessed tables immediately. See #789 for the related issue. To enable this, I am adding two APIs to the PrestoParser: one which extracts accessed input tables, and one which extracts output tables, if any exist. There are two parts to this changeset: 1. On the recommendation of Masha, I defined a DefaultTraversalVisitor, which performs a DFS traversal over all nodes in the AST. I used this base class for the existing ExprAnalyzer and the new TableVisitor. I can pull this into a separate PR if desired. 2. I added the TableVisitor, which extracts input tables and the output table for the query, and linked it into two new PrestoParser APIs for input and output tables respectively. Some things I was unsure about and would like feedback on: 1. I exposed two APIs, but I could easily have exposed one (getInputAndOutputTables) that returns a struct containing the output of both APIs. 2. I implemented the handlers for query types not currently covered by the parser (materialized view statements, some view statements, pure CREATE TABLE), but these cannot be run yet. I can also remove them or leave more comments in PrestoParser.cpp. I am also looking for comments on structuring: PrestoParser.cpp is getting big; I can cut it up into a few source/header files in this diff or a follow-up if others agree (but did not want to do so without discussion). Differential Revision: D91525572
1 parent 59b3e2a commit 031664a

File tree

5 files changed

+1209
-141
lines changed

5 files changed

+1209
-141
lines changed

axiom/sql/presto/PrestoParser.cpp

Lines changed: 171 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "axiom/sql/presto/PrestoParseError.h"
2323
#include "axiom/sql/presto/ast/AstBuilder.h"
2424
#include "axiom/sql/presto/ast/AstPrinter.h"
25+
#include "axiom/sql/presto/ast/DefaultTraversalVisitor.h"
2526
#include "axiom/sql/presto/ast/UpperCaseInputStream.h"
2627
#include "axiom/sql/presto/grammar/PrestoSqlLexer.h"
2728
#include "axiom/sql/presto/grammar/PrestoSqlParser.h"
@@ -179,59 +180,20 @@ std::string canonicalizeIdentifier(const Identifier& identifier) {
179180
return canonicalizeName(identifier.value());
180181
}
181182

182-
// Analizes the expression to find out whether there are any aggregate function
183+
// Analyzes the expression to find out whether there are any aggregate function
183184
// calls and to verify that aggregate calls are not nested, e.g. sum(count(x))
184185
// is not allowed.
185-
class ExprAnalyzer : public AstVisitor {
186+
class ExprAnalyzer : public DefaultTraversalVisitor {
186187
public:
187188
bool hasAggregate() const {
188189
return numAggregates_ > 0;
189190
}
190191

191-
private:
192-
void defaultVisit(Node* node) override {
193-
if (dynamic_cast<Literal*>(node) != nullptr) {
194-
// Literals have no function calls.
195-
return;
196-
}
197-
198-
VELOX_NYI(
199-
"Not yet supported node type: {}", NodeTypeName::toName(node->type()));
200-
}
201-
202-
void visitArithmeticUnaryExpression(
203-
ArithmeticUnaryExpression* node) override {
204-
node->value()->accept(this);
205-
}
206-
207-
void visitArrayConstructor(ArrayConstructor* node) override {
208-
for (const auto& value : node->values()) {
209-
value->accept(this);
210-
}
211-
}
212-
213-
void visitBetweenPredicate(BetweenPredicate* node) override {
214-
node->value()->accept(this);
215-
node->min()->accept(this);
216-
node->max()->accept(this);
217-
}
218-
219-
void visitCast(Cast* node) override {
220-
node->expression()->accept(this);
221-
}
222-
223-
void visitDereferenceExpression(DereferenceExpression* node) override {
224-
node->base()->accept(this);
225-
}
226-
192+
protected:
227193
void visitExistsPredicate(ExistsPredicate* node) override {
228194
// Aggregate function calls within a subquery do not count.
229195
}
230196

231-
void visitExtract(Extract* node) override {
232-
node->expression()->accept(this);
233-
}
234-
235197
void visitFunctionCall(FunctionCall* node) override {
236198
const auto& name = node->name()->suffix();
237199
if (facebook::velox::exec::getAggregateFunctionEntry(name)) {
@@ -245,112 +207,16 @@ class ExprAnalyzer : public AstVisitor {
245207
++numAggregates_;
246208
}
247209

248-
for (const auto& arg : node->arguments()) {
249-
arg->accept(this);
250-
}
210+
DefaultTraversalVisitor::visitFunctionCall(node);
251211

252212
aggregateName_.reset();
253213
}
254214

255-
void visitInListExpression(InListExpression* node) override {
256-
for (const auto& value : node->values()) {
257-
value->accept(this);
258-
}
259-
}
260-
261-
void visitInPredicate(InPredicate* node) override {
262-
node->value()->accept(this);
263-
node->valueList()->accept(this);
264-
}
265-
266-
void visitIsNullPredicate(IsNullPredicate* node) override {
267-
node->value()->accept(this);
268-
}
269-
270-
void visitIsNotNullPredicate(IsNotNullPredicate* node) override {
271-
node->value()->accept(this);
272-
}
273-
274-
void visitLambdaExpression(LambdaExpression* node) override {
275-
node->body()->accept(this);
276-
}
277-
278-
void visitNotExpression(NotExpression* node) override {
279-
node->value()->accept(this);
280-
}
281-
282-
void visitArithmeticBinaryExpression(
283-
ArithmeticBinaryExpression* node) override {
284-
node->left()->accept(this);
285-
node->right()->accept(this);
286-
}
287-
288-
void visitLogicalBinaryExpression(LogicalBinaryExpression* node) override {
289-
node->left()->accept(this);
290-
node->right()->accept(this);
291-
}
292-
293-
void visitComparisonExpression(ComparisonExpression* node) override {
294-
node->left()->accept(this);
295-
node->right()->accept(this);
296-
}
297-
298-
void visitLikePredicate(LikePredicate* node) override {
299-
node->value()->accept(this);
300-
node->pattern()->accept(this);
301-
if (node->escape() != nullptr) {
302-
node->escape()->accept(this);
303-
}
304-
}
305-
306-
void visitSimpleCaseExpression(SimpleCaseExpression* node) override {
307-
node->operand()->accept(this);
308-
309-
for (const auto& clause : node->whenClauses()) {
310-
clause->operand()->accept(this);
311-
clause->result()->accept(this);
312-
}
313-
314-
if (node->defaultValue()) {
315-
node->defaultValue()->accept(this);
316-
}
317-
}
318-
319-
void visitSearchedCaseExpression(SearchedCaseExpression* node) override {
320-
for (const auto& clause : node->whenClauses()) {
321-
clause->operand()->accept(this);
322-
clause->result()->accept(this);
323-
}
324-
325-
if (node->defaultValue()) {
326-
node->defaultValue()->accept(this);
327-
}
328-
}
329-
330215
void visitSubqueryExpression(SubqueryExpression* node) override {
331216
// Aggregate function calls within a subquery do not count.
332217
}
333218

334-
void visitSubscriptExpression(SubscriptExpression* node) override {
335-
node->base()->accept(this);
336-
node->index()->accept(this);
337-
}
338-
339-
void visitIdentifier(Identifier* node) override {
340-
// No function calls.
341-
}
342-
343-
void visitRow(Row* node) override {
344-
for (const auto& item : node->items()) {
345-
item->accept(this);
346-
}
347-
}
348-
349-
void visitAtTimeZone(AtTimeZone* node) override {
350-
node->value()->accept(this);
351-
node->timeZone()->accept(this);
352-
}
353-
219+
private:
354220
size_t numAggregates_{0};
355221
std::optional<std::string> aggregateName_;
356222
};
@@ -2335,4 +2201,169 @@ SqlStatementPtr PrestoParser::doParse(
23352201
return doPlan(query, defaultConnectorId_, defaultSchema_, parseSql);
23362202
}
23372203

2204+
namespace {
2205+
2206+
// Analyzes an expression to extract the fully-qualified names of any
2207+
// input or output tables or views in the expression. Table accesses
2208+
// inside CTEs are included, even if the CTE is never read from.
2209+
class TableVisitor : public DefaultTraversalVisitor {
2210+
public:
2211+
TableVisitor(
2212+
const std::string& defaultConnectorId,
2213+
const std::optional<std::string>& defaultSchema)
2214+
: defaultConnectorId_(defaultConnectorId),
2215+
defaultSchema_(defaultSchema) {}
2216+
2217+
const std::unordered_set<std::string>& inputTables() const {
2218+
return inputTables_;
2219+
}
2220+
2221+
const std::optional<std::string>& outputTable() const {
2222+
return outputTable_;
2223+
}
2224+
2225+
protected:
2226+
void visitWithQuery(WithQuery* node) override {
2227+
// To cover the case where a CTE aliases an underlying
2228+
// table, e.g. 'WITH t AS (SELECT * FROM t)', we need to
2229+
// traverse the inner query before tracking the CTE alias.
2230+
DefaultTraversalVisitor::visitWithQuery(node);
2231+
ctes_.insert(node->name()->value());
2232+
}
2233+
2234+
void visitTable(Table* node) override {
2235+
if (isCteReference(*node->name())) {
2236+
return;
2237+
}
2238+
inputTables_.insert(constructTableName(*node->name()));
2239+
DefaultTraversalVisitor::visitTable(node);
2240+
}
2241+
2242+
void visitInsert(Insert* node) override {
2243+
parseOutputTable(*node->target());
2244+
DefaultTraversalVisitor::visitInsert(node);
2245+
}
2246+
2247+
void visitCreateTableAsSelect(CreateTableAsSelect* node) override {
2248+
parseOutputTable(*node->name());
2249+
DefaultTraversalVisitor::visitCreateTableAsSelect(node);
2250+
}
2251+
2252+
void visitUpdate(Update* node) override {
2253+
parseOutputTable(*node->table());
2254+
DefaultTraversalVisitor::visitUpdate(node);
2255+
}
2256+
2257+
void visitDelete(Delete* node) override {
2258+
parseOutputTable(*node->table());
2259+
DefaultTraversalVisitor::visitDelete(node);
2260+
}
2261+
2262+
void visitCreateTable(CreateTable* node) override {
2263+
parseOutputTable(*node->name());
2264+
DefaultTraversalVisitor::visitCreateTable(node);
2265+
}
2266+
2267+
void visitCreateView(CreateView* node) override {
2268+
parseOutputTable(*node->name());
2269+
DefaultTraversalVisitor::visitCreateView(node);
2270+
}
2271+
2272+
void visitCreateMaterializedView(CreateMaterializedView* node) override {
2273+
parseOutputTable(*node->name());
2274+
DefaultTraversalVisitor::visitCreateMaterializedView(node);
2275+
}
2276+
2277+
void visitDropTable(DropTable* node) override {
2278+
parseOutputTable(*node->tableName());
2279+
DefaultTraversalVisitor::visitDropTable(node);
2280+
}
2281+
2282+
void visitDropView(DropView* node) override {
2283+
parseOutputTable(*node->viewName());
2284+
DefaultTraversalVisitor::visitDropView(node);
2285+
}
2286+
2287+
void visitDropMaterializedView(DropMaterializedView* node) override {
2288+
parseOutputTable(*node->viewName());
2289+
DefaultTraversalVisitor::visitDropMaterializedView(node);
2290+
}
2291+
2292+
private:
2293+
std::string constructTableName(const QualifiedName& name) {
2294+
const auto& parts = name.parts();
2295+
VELOX_CHECK(!parts.empty(), "Table name cannot be empty");
2296+
switch (parts.size()) {
2297+
case 1:
2298+
if (defaultSchema_.has_value()) {
2299+
return fmt::format(
2300+
"{}.{}.{}",
2301+
defaultConnectorId_,
2302+
defaultSchema_.value(),
2303+
parts[0]);
2304+
}
2305+
return fmt::format("{}.{}", defaultConnectorId_, parts[0]);
2306+
case 2:
2307+
return fmt::format("{}.{}.{}", defaultConnectorId_, parts[0], parts[1]);
2308+
case 3:
2309+
return fmt::format("{}.{}.{}", parts[0], parts[1], parts[2]);
2310+
default:
2311+
VELOX_FAIL(
2312+
"Table name must have 1-3 components, '{}' has {}",
2313+
name.fullyQualifiedName(),
2314+
parts.size());
2315+
}
2316+
}
2317+
2318+
void parseOutputTable(const QualifiedName& name) {
2319+
auto outputTable = constructTableName(name);
2320+
VELOX_CHECK(
2321+
!outputTable_.has_value(),
2322+
"Cannot perform write against multiple tables (found {} and {})",
2323+
outputTable_.value(),
2324+
outputTable);
2325+
outputTable_ = std::move(outputTable);
2326+
}
2327+
2328+
bool isCteReference(const QualifiedName& name) {
2329+
const auto& parts = name.parts();
2330+
return parts.size() == 1 && ctes_.count(parts[0]) > 0;
2331+
}
2332+
2333+
const std::string& defaultConnectorId_;
2334+
const std::optional<std::string>& defaultSchema_;
2335+
std::unordered_set<std::string> ctes_;
2336+
std::unordered_set<std::string> inputTables_;
2337+
std::optional<std::string> outputTable_;
2338+
};
2339+
2340+
} // namespace
2341+
2342+
std::unordered_set<std::string> PrestoParser::getInputTables(
2343+
std::string_view sql) {
2344+
ParserHelper helper(sql);
2345+
auto* context = helper.parse();
2346+
2347+
AstBuilder astBuilder(false);
2348+
auto statement =
2349+
std::any_cast<std::shared_ptr<Statement>>(astBuilder.visit(context));
2350+
2351+
TableVisitor visitor(defaultConnectorId_, defaultSchema_);
2352+
visitor.process(statement.get());
2353+
return visitor.inputTables();
2354+
}
2355+
2356+
std::optional<std::string> PrestoParser::getOutputTable(std::string_view sql) {
2357+
ParserHelper helper(sql);
2358+
auto* context = helper.parse();
2359+
2360+
AstBuilder astBuilder(false);
2361+
auto statement =
2362+
std::any_cast<std::shared_ptr<Statement>>(astBuilder.visit(context));
2363+
2364+
TableVisitor visitor(defaultConnectorId_, defaultSchema_);
2365+
visitor.process(statement.get());
2366+
return visitor.outputTable();
2367+
}
2368+
23382369
} // namespace axiom::sql::presto

axiom/sql/presto/PrestoParser.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ class PrestoParser {
4747
std::string_view sql,
4848
bool enableTracing = false);
4949

50+
/// Extracts input tables from a SQL query, if any exist.
51+
/// @param sql SQL query statement
52+
/// @return set of fully-qualified input table names in `catalog.schema.table`
53+
/// format. If the query accesses no tables, an empty set is returned.
54+
std::unordered_set<std::string> getInputTables(std::string_view sql);
55+
56+
/// Extracts the output table from a SQL query, if one exists.
57+
/// @param sql SQL query statement.
58+
/// @return fully-qualified output table name in `catalog.schema.table`
59+
/// format, or std::nullopt if the query does not modify a table.
60+
std::optional<std::string> getOutputTable(std::string_view sql);
61+
5062
private:
5163
SqlStatementPtr doParse(std::string_view sql, bool enableTracing);
5264

0 commit comments

Comments
 (0)