@@ -8,13 +8,12 @@ use std::task::{Context, Poll};
 use std::time::Duration;
 
 use arrow::array::RecordBatch;
-use arrow::compute::concat_batches;
 use arrow::datatypes::SchemaRef;
 use arrow::error::ArrowError;
 use arrow::ipc::convert::fb_to_schema;
 use arrow::ipc::reader::StreamReader;
 use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
-use arrow::ipc::{root_as_message, MetadataVersion};
+use arrow::ipc::{MetadataVersion, root_as_message};
 use arrow::pyarrow::*;
 use arrow::util::pretty;
 use arrow_flight::{FlightClient, FlightData, Ticket};
@@ -30,16 +29,16 @@ use datafusion::error::DataFusionError;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, SessionStateBuilder};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
-use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};
-use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable};
+use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use datafusion_python::utils::wait_for_future;
 use futures::{Stream, StreamExt};
 use log::debug;
+use object_store::ObjectStore;
 use object_store::aws::AmazonS3Builder;
 use object_store::gcp::GoogleCloudStorageBuilder;
 use object_store::http::HttpBuilder;
-use object_store::ObjectStore;
 use parking_lot::Mutex;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyList};
@@ -411,62 +410,77 @@ fn print_node(plan: &Arc<dyn ExecutionPlan>, indent: usize, output: &mut String)
     }
 }
 
-async fn exec_sql(
-    query: String,
-    tables: Vec<(String, String)>,
-) -> Result<RecordBatch, DataFusionError> {
-    let ctx = SessionContext::new();
-    for (name, path) in tables {
-        let opt =
-            ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(".parquet");
-        debug!("exec_sql: registering table {} at {}", name, path);
+#[pyclass]
+pub struct LocalValidator {
+    ctx: SessionContext,
+}
+
+#[pymethods]
+impl LocalValidator {
+    #[new]
+    fn new() -> Self {
+        let ctx = SessionContext::new();
+        Self { ctx }
+    }
+
+    pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
+        let options = ParquetReadOptions::default();
 
-        let url = ListingTableUrl::parse(&path)?;
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
 
-        maybe_register_object_store(&ctx, url.as_ref())?;
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_parquet: registering table {} at {}", name, path);
 
-        ctx.register_listing_table(&name, &path, opt, None, None)
-            .await?;
+        wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
+        Ok(())
     }
-    let df = ctx.sql(&query).await?;
-    let schema = df.schema().inner().clone();
-    let batches = df.collect().await?;
-    concat_batches(&schema, batches.iter()).map_err(|e| DataFusionError::ArrowError(e, None))
-}
 
-/// Executes a query on the specified tables using DataFusion without Ray.
-///
-/// Returns the query results as a RecordBatch that can be used to verify the
-/// correctness of DataFusion-Ray execution of the same query.
-///
-/// # Arguments
-///
-/// * `py`: the Python token
-/// * `query`: the SQL query string to execute
-/// * `tables`: a list of `(name, url)` tuples specifying the tables to query;
-///   the `url` identifies the parquet files for each listing table and see
-///   [`datafusion::datasource::listing::ListingTableUrl::parse`] for details
-///   of supported URL formats
-/// * `listing`: boolean indicating whether this is a listing table path or not
-#[pyfunction]
-#[pyo3(signature = (query, tables, listing=false))]
-pub fn exec_sql_on_tables(
-    py: Python,
-    query: String,
-    tables: Bound<'_, PyList>,
-    listing: bool,
-) -> PyResult<PyObject> {
-    let table_vec = {
-        let mut v = Vec::with_capacity(tables.len());
-        for entry in tables.iter() {
-            let (name, path) = entry.extract::<(String, String)>()?;
-            let path = if listing { format!("{path}/") } else { path };
-            v.push((name, path));
-        }
-        v
-    };
-    let batch = wait_for_future(py, exec_sql(query, table_vec))?;
-    batch.to_pyarrow(py)
+    #[pyo3(signature = (name, path, file_extension=".parquet"))]
+    pub fn register_listing_table(
+        &mut self,
+        py: Python,
+        name: &str,
+        path: &str,
+        file_extension: &str,
+    ) -> PyResult<()> {
+        let options =
+            ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(file_extension);
+
+        let path = format!("{path}/");
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+
+        debug!(
+            "register_listing_table: registering table {} at {}",
+            name, path
+        );
+        wait_for_future(
+            py,
+            self.ctx
+                .register_listing_table(name, path, options, None, None),
+        )
+        .to_py_err()
+    }
+
+    #[pyo3(signature = (query))]
+    fn collect_sql(&self, py: Python, query: String) -> PyResult<PyObject> {
+        let fut = async || {
+            let df = self.ctx.sql(&query).await?;
+            let batches = df.collect().await?;
+
+            Ok::<_, DataFusionError>(batches)
+        };
+
+        let batches = wait_for_future(py, fut())
+            .to_py_err()?
+            .iter()
+            .map(|batch| batch.to_pyarrow(py))
+            .collect::<PyResult<Vec<_>>>()?;
+
+        let pylist = PyList::new(py, batches)?;
+        Ok(pylist.into())
+    }
 }
 
 pub(crate) fn register_object_store_for_paths_in_plan(
@@ -570,62 +584,14 @@ mod test {
     use std::{sync::Arc, vec};
 
     use arrow::{
-        array::{Int32Array, StringArray},
+        array::Int32Array,
         datatypes::{DataType, Field, Schema},
     };
-    use datafusion::{
-        parquet::file::properties::WriterProperties, test_util::parquet::TestParquetFile,
-    };
+
     use futures::stream;
 
     use super::*;
 
-    #[tokio::test]
-    async fn test_exec_sql() {
-        let dir = tempfile::tempdir().unwrap();
-        let path = dir.path().join("people.parquet");
-
-        let batch = RecordBatch::try_new(
-            Arc::new(Schema::new(vec![
-                Field::new("age", DataType::Int32, false),
-                Field::new("name", DataType::Utf8, false),
-            ])),
-            vec![
-                Arc::new(Int32Array::from(vec![11, 12, 13])),
-                Arc::new(StringArray::from(vec!["alice", "bob", "cindy"])),
-            ],
-        )
-        .unwrap();
-        let props = WriterProperties::builder().build();
-        let file = TestParquetFile::try_new(path.clone(), props, Some(batch.clone())).unwrap();
-
-        // test with file
-        let tables = vec![(
-            "people".to_string(),
-            format!("file://{}", file.path().to_str().unwrap()),
-        )];
-        let query = "SELECT * FROM people ORDER BY age".to_string();
-        let res = exec_sql(query.clone(), tables).await.unwrap();
-        assert_eq!(
-            format!(
-                "{}",
-                pretty::pretty_format_batches(&[batch.clone()]).unwrap()
-            ),
-            format!("{}", pretty::pretty_format_batches(&[res]).unwrap()),
-        );
-
-        // test with dir
-        let tables = vec![(
-            "people".to_string(),
-            format!("file://{}/", dir.path().to_str().unwrap()),
-        )];
-        let res = exec_sql(query, tables).await.unwrap();
-        assert_eq!(
-            format!("{}", pretty::pretty_format_batches(&[batch]).unwrap()),
-            format!("{}", pretty::pretty_format_batches(&[res]).unwrap()),
-        );
-    }
-
     #[test]
     fn test_ipc_roundtrip() {
         let batch = RecordBatch::try_new(
@@ -641,10 +607,9 @@ mod test {
     #[tokio::test]
     async fn test_max_rows_stream() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
-        let batch = RecordBatch::try_new(
-            schema.clone(),
-            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
-        )
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![
+            1, 2, 3, 4, 5, 6, 7, 8,
+        ]))])
         .unwrap();
 
         // 24 total rows
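Note: the diff above replaces the `exec_sql_on_tables` pyfunction with the `LocalValidator` pyclass for computing reference results with plain DataFusion. A minimal sketch of how the new class might be driven from Python follows; the import path and table locations are assumptions, and only the constructor and the `register_parquet`, `register_listing_table`, and `collect_sql` methods are taken from the diff itself.

# Hypothetical usage sketch: import path and paths below are assumed,
# the method names come from the LocalValidator pyclass in the diff above.
from datafusion_ray import LocalValidator  # assumed export location

validator = LocalValidator()

# Register a single Parquet file under a table name.
validator.register_parquet("people", "/tmp/people.parquet")

# Or register a directory of Parquet files as a listing table.
validator.register_listing_table("events", "s3://bucket/events", file_extension=".parquet")

# collect_sql returns a list of pyarrow.RecordBatch objects.
batches = validator.collect_sql("SELECT * FROM people ORDER BY age")
for batch in batches:
    print(batch.num_rows)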