@@ -11,9 +11,17 @@ use datafusion::{
1111 optimizer:: { Analyzer , Optimizer , OptimizerContext , OptimizerRule } ,
1212 sql:: planner:: { ContextProvider , SqlToRel } ,
1313} ;
14- use indexmap:: IndexMap ;
15- use proof_of_sql:: sql:: proof_plans:: DynProofPlan ;
16- use sqlparser:: { dialect:: GenericDialect , parser:: Parser } ;
14+ use indexmap:: { IndexMap , IndexSet } ;
15+ use proof_of_sql:: {
16+ base:: database:: { ParseError , TableRef } ,
17+ sql:: proof_plans:: DynProofPlan ,
18+ } ;
19+ use sqlparser:: {
20+ ast:: { visit_relations, Statement } ,
21+ dialect:: GenericDialect ,
22+ parser:: Parser ,
23+ } ;
24+ use std:: ops:: ControlFlow ;
1725
1826/// Get [`Optimizer`]
1927///
@@ -111,3 +119,72 @@ pub fn sql_to_proof_plans_with_postprocessing<S: ContextProvider>(
111119 logical_plan_to_proof_plan_with_postprocessing,
112120 )
113121}
122+
123+ /// Given a `Statement` retrieves all unique tables in the query
124+ pub fn get_table_refs_from_statement (
125+ statement : & Statement ,
126+ ) -> Result < IndexSet < TableRef > , ParseError > {
127+ let mut table_refs: IndexSet < TableRef > = IndexSet :: < TableRef > :: new ( ) ;
128+ visit_relations ( statement, |object_name| {
129+ match object_name. to_string ( ) . as_str ( ) . try_into ( ) {
130+ Ok ( table_ref) => {
131+ table_refs. insert ( table_ref) ;
132+ ControlFlow :: Continue ( ( ) )
133+ }
134+ e => ControlFlow :: Break ( e) ,
135+ }
136+ } )
137+ . break_value ( )
138+ . transpose ( ) ?;
139+ Ok ( table_refs)
140+ }
141+
#[cfg(test)]
mod tests {
    use super::get_table_refs_from_statement;
    use indexmap::IndexSet;
    use proof_of_sql::base::database::TableRef;
    use sqlparser::{dialect::GenericDialect, parser::Parser};

    /// A query mixing joins and nested subqueries, with both bare and
    /// schema-qualified table names and repeated references to the same table,
    /// yields each referenced table exactly once.
    #[test]
    fn we_can_get_table_references() {
        let parsed_statements = Parser::parse_sql(
            &GenericDialect {},
            "SELECT e.employee_id, e.employee_name, d.department_name, p.project_name, s.salary
            FROM employees e
            JOIN departments d ON e.department_id = d.department_id
            JOIN management.projects p ON e.employee_id = p.employee_id
            JOIN internal.salaries s ON e.employee_id = s.employee_id
            WHERE e.department_id IN (
            SELECT department_id
            FROM departments
            WHERE department_name = 'Sales'
            )
            AND p.project_id IN (
            SELECT project_id
            FROM project_assignments
            WHERE employee_id = e.employee_id
            )
            AND s.salary > (
            SELECT AVG(salary)
            FROM internal.salaries
            WHERE department_id = e.department_id
            );
            ",
        )
        .unwrap();

        // Only the first parsed statement is inspected.
        let actual = get_table_refs_from_statement(&parsed_statements[0]).unwrap();

        // An empty schema string stands for an unqualified table name.
        let expected: IndexSet<TableRef> = [
            ("", "departments"),
            ("", "employees"),
            ("management", "projects"),
            ("", "project_assignments"),
            ("internal", "salaries"),
        ]
        .into_iter()
        .map(|(schema, table)| TableRef::new(schema, table))
        .collect();

        assert_eq!(actual, expected);
    }
}
0 commit comments