Skip to content

Commit 076b97a

Browse files
feat: add get_table_refs_from_statement to planner
1 parent d81ecd6 commit 076b97a

File tree

3 files changed

+84
-4
lines changed

3 files changed

+84
-4
lines changed

crates/proof-of-sql-planner/src/conversion.rs

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,17 @@ use datafusion::{
1111
optimizer::{Analyzer, Optimizer, OptimizerContext, OptimizerRule},
1212
sql::planner::{ContextProvider, SqlToRel},
1313
};
14-
use indexmap::IndexMap;
15-
use proof_of_sql::sql::proof_plans::DynProofPlan;
16-
use sqlparser::{dialect::GenericDialect, parser::Parser};
14+
use indexmap::{IndexMap, IndexSet};
15+
use proof_of_sql::{
16+
base::database::{ParseError, TableRef},
17+
sql::proof_plans::DynProofPlan,
18+
};
19+
use sqlparser::{
20+
ast::{visit_relations, Statement},
21+
dialect::GenericDialect,
22+
parser::Parser,
23+
};
24+
use std::ops::ControlFlow;
1725

1826
/// Get [`Optimizer`]
1927
///
@@ -111,3 +119,72 @@ pub fn sql_to_proof_plans_with_postprocessing<S: ContextProvider>(
111119
logical_plan_to_proof_plan_with_postprocessing,
112120
)
113121
}
122+
123+
/// Given a `Statement` retrieves all unique tables in the query
124+
pub fn get_table_refs_from_statement(
125+
statement: &Statement,
126+
) -> Result<IndexSet<TableRef>, ParseError> {
127+
let mut table_refs: IndexSet<TableRef> = IndexSet::<TableRef>::new();
128+
visit_relations(statement, |object_name| {
129+
match object_name.to_string().as_str().try_into() {
130+
Ok(table_ref) => {
131+
table_refs.insert(table_ref);
132+
ControlFlow::Continue(())
133+
}
134+
e => ControlFlow::Break(e),
135+
}
136+
})
137+
.break_value()
138+
.transpose()?;
139+
Ok(table_refs)
140+
}
141+
142+
#[cfg(test)]
143+
mod tests {
144+
use super::get_table_refs_from_statement;
145+
use indexmap::IndexSet;
146+
use proof_of_sql::base::database::TableRef;
147+
use sqlparser::{dialect::GenericDialect, parser::Parser};
148+
149+
#[test]
150+
fn we_can_get_table_references() {
151+
let statement = Parser::parse_sql(
152+
&GenericDialect {},
153+
"SELECT e.employee_id, e.employee_name, d.department_name, p.project_name, s.salary
154+
FROM employees e
155+
JOIN departments d ON e.department_id = d.department_id
156+
JOIN management.projects p ON e.employee_id = p.employee_id
157+
JOIN internal.salaries s ON e.employee_id = s.employee_id
158+
WHERE e.department_id IN (
159+
SELECT department_id
160+
FROM departments
161+
WHERE department_name = 'Sales'
162+
)
163+
AND p.project_id IN (
164+
SELECT project_id
165+
FROM project_assignments
166+
WHERE employee_id = e.employee_id
167+
)
168+
AND s.salary > (
169+
SELECT AVG(salary)
170+
FROM internal.salaries
171+
WHERE department_id = e.department_id
172+
);
173+
",
174+
)
175+
.unwrap()[0]
176+
.clone();
177+
let table_refs = get_table_refs_from_statement(&statement).unwrap();
178+
let expected_table_refs: IndexSet<TableRef> = [
179+
("", "departments"),
180+
("", "employees"),
181+
("management", "projects"),
182+
("", "project_assignments"),
183+
("internal", "salaries"),
184+
]
185+
.map(|(s, t)| TableRef::new(s, t))
186+
.into_iter()
187+
.collect();
188+
assert_eq!(table_refs, expected_table_refs);
189+
}
190+
}

crates/proof-of-sql-planner/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ pub use context::PoSqlContextProvider;
88
#[cfg(test)]
99
pub(crate) use context::PoSqlTableSource;
1010
mod conversion;
11-
pub use conversion::{sql_to_proof_plans, sql_to_proof_plans_with_postprocessing};
11+
pub use conversion::{
12+
get_table_refs_from_statement, sql_to_proof_plans, sql_to_proof_plans_with_postprocessing,
13+
};
1214
#[cfg(test)]
1315
mod df_util;
1416
mod expr;

crates/proof-of-sql/src/base/database/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ mod literal_value;
4545
pub use literal_value::LiteralValue;
4646

4747
mod error;
48+
pub use error::ParseError;
4849

4950
mod table_ref;
5051
#[cfg(feature = "arrow")]

0 commit comments

Comments
 (0)