@@ -47,9 +47,8 @@ impl ExecutionPlanVisitor for VortexMetricsFinder {
4747 type Error = std:: convert:: Infallible ;
4848 fn pre_visit ( & mut self , plan : & dyn ExecutionPlan ) -> Result < bool , Self :: Error > {
4949 if let Some ( exec) = plan. as_any ( ) . downcast_ref :: < DataSourceExec > ( ) {
50- if let Some ( metrics) = exec. metrics ( ) {
51- self . 0 . push ( metrics) ;
52- }
50+ // Start with exec metrics or create a new set
51+ let mut set = exec. metrics ( ) . unwrap_or_default ( ) ;
5352
5453 // Include our own metrics from VortexSource
5554 if let Some ( file_scan) = exec. data_source ( ) . as_any ( ) . downcast_ref :: < FileScanConfig > ( )
@@ -58,7 +57,6 @@ impl ExecutionPlanVisitor for VortexMetricsFinder {
5857 . as_any ( )
5958 . downcast_ref :: < VortexSource > ( )
6059 {
61- let mut set = MetricsSet :: new ( ) ;
6260 for metric in scan
6361 . metrics_registry ( )
6462 . snapshot ( )
@@ -67,12 +65,14 @@ impl ExecutionPlanVisitor for VortexMetricsFinder {
6765 {
6866 set. push ( Arc :: new ( metric) ) ;
6967 }
70-
71- self . 0 . push ( set) ;
7268 }
73- }
7469
75- Ok ( true )
70+ self . 0 . push ( set) ;
71+
72+ Ok ( false )
73+ } else {
74+ Ok ( true )
75+ }
7676 }
7777}
7878
@@ -193,3 +193,83 @@ fn f_to_u(f: f64) -> Option<usize> {
193193 // After the range check, truncation is guaranteed to keep the value in usize bounds.
194194 f. trunc ( ) as usize )
195195}
196+
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::SessionContext;
    use datafusion_datasource::source::DataSourceExec;
    use datafusion_physical_plan::ExecutionPlan;
    use datafusion_physical_plan::ExecutionPlanVisitor;
    use datafusion_physical_plan::accept;
    use tempfile::TempDir;

    use super::VortexMetricsFinder;
    use crate::VortexFormatFactory;
    use crate::persistent::register_vortex_format_factory;

    /// Counts the number of `DataSourceExec` nodes in a plan.
    ///
    /// The wrapped `usize` is the running count. The visitor returns
    /// `Ok(false)` when it sees a `DataSourceExec` and `Ok(true)` otherwise,
    /// the same return pattern `VortexMetricsFinder::pre_visit` uses, so both
    /// visitors treat the plan tree the same way.
    struct DataSourceExecCounter(usize);

    impl ExecutionPlanVisitor for DataSourceExecCounter {
        type Error = std::convert::Infallible;

        /// Increments the counter when `plan` is a `DataSourceExec`.
        ///
        /// Returns `Ok(false)` for a `DataSourceExec` node and `Ok(true)` for
        /// every other node; this visitor never produces an error.
        fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result<bool, Self::Error> {
            let is_data_source = plan.as_any().is::<DataSourceExec>();
            if is_data_source {
                self.0 += 1;
            }
            Ok(!is_data_source)
        }
    }

    /// Checks that `VortexMetricsFinder::find_all` yields exactly one metrics
    /// set for each `DataSourceExec` node of a physical plan scanning a Vortex
    /// table.
    ///
    /// The test creates an external Vortex table in a temporary directory,
    /// inserts two rows, plans `SELECT * FROM my_tbl`, and compares the number
    /// of metrics sets found with the number of `DataSourceExec` nodes counted
    /// by [`DataSourceExecCounter`].
    #[tokio::test]
    async fn metrics_finder_returns_one_set_per_data_source_exec() -> anyhow::Result<()> {
        let dir = TempDir::new()?;
        // Surface a non-UTF-8 temp path as a test error instead of a panic.
        let location = dir
            .path()
            .to_str()
            .context("temporary directory path is not valid UTF-8")?;

        let factory = VortexFormatFactory::new();
        let mut session_state_builder = SessionStateBuilder::new().with_default_features();
        register_vortex_format_factory(factory, &mut session_state_builder);
        let session = SessionContext::new_with_state(session_state_builder.build());

        session
            .sql(&format!(
                "CREATE EXTERNAL TABLE my_tbl \
                (c1 VARCHAR NOT NULL, c2 INT NOT NULL) \
                STORED AS vortex \
                LOCATION '{}'",
                location
            ))
            .await?;

        session
            .sql("INSERT INTO my_tbl VALUES ('a', 1), ('b', 2)")
            .await?
            .collect()
            .await?;

        let df = session.sql("SELECT * FROM my_tbl").await?;
        let (state, plan) = df.into_parts();
        let physical_plan = state.create_physical_plan(&plan).await?;

        // Count DataSourceExec nodes
        let mut counter = DataSourceExecCounter(0);
        accept(physical_plan.as_ref(), &mut counter)?;

        // Get metrics sets
        let metrics_sets = VortexMetricsFinder::find_all(physical_plan.as_ref());

        // Guard against a vacuous pass where both counts are zero.
        assert!(
            !metrics_sets.is_empty(),
            "Expected at least one MetricsSet for a plan scanning a Vortex table"
        );
        assert_eq!(
            metrics_sets.len(),
            counter.0,
            "Expected one MetricsSet per DataSourceExec, got {} sets for {} DataSourceExec nodes",
            metrics_sets.len(),
            counter.0
        );

        Ok(())
    }
}
0 commit comments