@@ -43,25 +43,55 @@ pub struct ArrowInitData {
4343/// The Arrow table function.
4444pub struct ArrowVTab ;
4545
46- unsafe fn address_to_arrow_schema ( address : usize ) -> FFI_ArrowSchema {
47- let ptr = address as * mut FFI_ArrowSchema ;
48- unsafe { * Box :: from_raw ( ptr) }
46+ const ARROW_QUERY_PARAMS_MARKER : usize = 0x4152_5257 ; // "ARRW"
47+
48+ fn register_arrow_record_batch ( rb : RecordBatch ) -> [ usize ; 2 ] {
49+ let ptr = Box :: into_raw ( Box :: new ( rb) ) ;
50+ [ ptr as usize , ARROW_QUERY_PARAMS_MARKER ]
4951}
5052
51- unsafe fn address_to_arrow_array ( address : usize ) -> FFI_ArrowArray {
52- let ptr = address as * mut FFI_ArrowArray ;
53- unsafe { * Box :: from_raw ( ptr) }
53+ fn arrow_query_param_usize ( bind : & BindInfo , index : u64 , name : & str ) -> Result < usize , Box < dyn std:: error:: Error > > {
54+ let value = bind. get_parameter ( index) ;
55+ if value. is_null ( ) {
56+ return Err ( format ! ( "ArrowVTab {name} parameter must not be NULL" ) . into ( ) ) ;
57+ }
58+
59+ let logical_type = value. logical_type_id ( ) ;
60+ if logical_type != LogicalTypeId :: UBigint {
61+ return Err ( format ! ( "ArrowVTab {name} parameter must be UBIGINT, got {logical_type:?}" ) . into ( ) ) ;
62+ }
63+
64+ usize:: try_from ( value. to_uint64 ( ) ) . map_err ( |_| format ! ( "ArrowVTab {name} parameter does not fit in usize" ) . into ( ) )
5465}
5566
56- unsafe fn address_to_arrow_ffi ( array : usize , schema : usize ) -> ( FFI_ArrowArray , FFI_ArrowSchema ) {
57- let array = unsafe { address_to_arrow_array ( array) } ;
58- let schema = unsafe { address_to_arrow_schema ( schema) } ;
59- ( array, schema)
67+ /// Imports a record batch from the current opaque ArrowVTab query parameters.
68+ ///
69+ /// # Safety
70+ ///
71+ /// `address` must be a non-null pointer returned by
72+ /// [`arrow_recordbatch_to_query_params`], and `marker` must be the matching
73+ /// layout marker returned with it. The marker catches common misuse only and
74+ /// does not make stale or forged pointers safe to dereference.
75+ unsafe fn address_to_arrow_record_batch (
76+ address : usize ,
77+ marker : usize ,
78+ ) -> Result < RecordBatch , Box < dyn std:: error:: Error > > {
79+ let ptr = address as * const RecordBatch ;
80+ if ptr. is_null ( ) {
81+ return Err ( "invalid ArrowVTab record batch address" . into ( ) ) ;
82+ }
83+
84+ if marker != ARROW_QUERY_PARAMS_MARKER {
85+ return Err ( "ArrowVTab query parameter marker mismatch; use arrow_recordbatch_to_query_params" . into ( ) ) ;
86+ }
87+
88+ // SAFETY: The caller guarantees that `ptr` is the leaked RecordBatch
89+ // allocation created by `register_arrow_record_batch`.
90+ Ok ( unsafe { ( * ptr) . clone ( ) } )
6091}
6192
62- unsafe fn address_to_arrow_record_batch ( array : usize , schema : usize ) -> RecordBatch {
63- let ( array, schema) = unsafe { address_to_arrow_ffi ( array, schema) } ;
64- let array_data = unsafe { from_ffi ( array, & schema) } . expect ( "ok" ) ;
93+ fn arrow_record_batch_from_ffi ( array : FFI_ArrowArray , schema : FFI_ArrowSchema ) -> RecordBatch {
94+ let array_data = unsafe { from_ffi ( array, & schema) } . expect ( "failed to import Arrow FFI data" ) ;
6595 let struct_array = StructArray :: from ( array_data) ;
6696 RecordBatch :: from ( & struct_array)
6797}
@@ -75,20 +105,22 @@ impl VTab for ArrowVTab {
75105 if param_count != 2 {
76106 return Err ( format ! ( "Bad param count: {param_count}, expected 2" ) . into ( ) ) ;
77107 }
78- let array = bind. get_parameter ( 0 ) . to_int64 ( ) ;
79- let schema = bind. get_parameter ( 1 ) . to_int64 ( ) ;
80-
81- unsafe {
82- let rb = address_to_arrow_record_batch ( array as usize , schema as usize ) ;
83- for f in rb. schema ( ) . fields ( ) {
84- let name = f. name ( ) ;
85- let data_type = f. data_type ( ) ;
86- let logical_type = to_duckdb_logical_type ( data_type) ?;
87- bind. add_result_column ( name, logical_type) ;
88- }
108+ let address = arrow_query_param_usize ( bind, 0 , "record batch address" ) ?;
109+ let marker = arrow_query_param_usize ( bind, 1 , "marker" ) ?;
89110
90- Ok ( ArrowBindData { rb : Mutex :: new ( rb) } )
111+ // SAFETY: ArrowVTab's raw-parameter API relies on callers passing
112+ // values returned unchanged by `arrow_recordbatch_to_query_params`.
113+ // Validation above catches nulls, type mismatches, and layout marker
114+ // mismatches, but cannot validate forged addresses.
115+ let rb = unsafe { address_to_arrow_record_batch ( address, marker) ? } ;
116+ for f in rb. schema ( ) . fields ( ) {
117+ let name = f. name ( ) ;
118+ let data_type = f. data_type ( ) ;
119+ let logical_type = to_duckdb_logical_type ( data_type) ?;
120+ bind. add_result_column ( name, logical_type) ;
91121 }
122+
123+ Ok ( ArrowBindData { rb : Mutex :: new ( rb) } )
92124 }
93125
94126 fn init ( _: & InitInfo ) -> Result < Self :: InitData , Box < dyn std:: error:: Error > > {
@@ -114,13 +146,13 @@ impl VTab for ArrowVTab {
114146
115147 fn parameters ( ) -> Option < Vec < LogicalTypeHandle > > {
116148 Some ( vec ! [
117- LogicalTypeHandle :: from( LogicalTypeId :: UBigint ) , // file path
118- LogicalTypeHandle :: from( LogicalTypeId :: UBigint ) , // sheet name
149+ LogicalTypeHandle :: from( LogicalTypeId :: UBigint ) , // record batch address
150+ LogicalTypeHandle :: from( LogicalTypeId :: UBigint ) , // query parameter marker
119151 ] )
120152 }
121153}
122154
123- /// Convert arrow DataType to duckdb type id
155+ /// Convert arrow DataType to DuckDB type id
124156pub fn to_duckdb_type_id ( data_type : & DataType ) -> Result < LogicalTypeId , Box < dyn std:: error:: Error > > {
125157 use LogicalTypeId :: * ;
126158
@@ -181,7 +213,7 @@ impl TryFrom<DataType> for LogicalTypeId {
181213 }
182214}
183215
184- /// Convert arrow DataType to duckdb logical type
216+ /// Convert arrow DataType to DuckDB logical type
185217pub fn to_duckdb_logical_type ( data_type : & DataType ) -> Result < LogicalTypeHandle , Box < dyn std:: error:: Error > > {
186218 match data_type {
187219 DataType :: Dictionary ( _, value_type) => to_duckdb_logical_type ( value_type) ,
@@ -1040,37 +1072,34 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector<'_>) -> Re
10401072 Ok ( ( ) )
10411073}
10421074
1043- /// Pass RecordBatch to duckdb .
1075+ /// Pass a RecordBatch to DuckDB .
10441076///
1045- /// # Safety
1046- /// The caller must ensure that the pointer is valid
1047- /// It's recommended to always use this function with arrow()
1077+ /// This returns opaque query parameters for [`ArrowVTab`].
1078+ ///
1079+ /// Each call leaks one boxed [`RecordBatch`] for the lifetime of the process so
1080+ /// stored views can rebind the same parameters later. Do not call this per row
1081+ /// or per query. Create these parameters once per logical table or view.
10481082pub fn arrow_recordbatch_to_query_params ( rb : RecordBatch ) -> [ usize ; 2 ] {
1049- let data = ArrayData :: from ( StructArray :: from ( rb) ) ;
1050- arrow_arraydata_to_query_params ( data)
1083+ register_arrow_record_batch ( rb)
10511084}
10521085
1053- /// Pass ArrayData to duckdb .
1086+ /// Pass ArrayData to DuckDB .
10541087///
1055- /// # Safety
1056- /// The caller must ensure that the pointer is valid
1057- /// It's recommended to always use this function with arrow()
1088+ /// This converts the [`ArrayData`] to a [`RecordBatch`] immediately. Like
1089+ /// [`arrow_recordbatch_to_query_params`], each call leaks one boxed
1090+ /// [`RecordBatch`] for the lifetime of the process.
10581091pub fn arrow_arraydata_to_query_params ( data : ArrayData ) -> [ usize ; 2 ] {
1059- let array = FFI_ArrowArray :: new ( & data) ;
1060- let schema = FFI_ArrowSchema :: try_from ( data. data_type ( ) ) . expect ( "Failed to convert schema" ) ;
1061- arrow_ffi_to_query_params ( array, schema)
1092+ let struct_array = StructArray :: from ( data) ;
1093+ arrow_recordbatch_to_query_params ( RecordBatch :: from ( & struct_array) )
10621094}
10631095
1064- /// Pass array and schema as a pointer to duckdb .
1096+ /// Pass array and schema to DuckDB .
10651097///
1066- /// # Safety
1067- /// The caller must ensure that the pointer is valid
1068- /// It's recommended to always use this function with arrow()
1098+ /// This imports the FFI values immediately. Like
1099+ /// [`arrow_recordbatch_to_query_params`], each call leaks one boxed
1100+ /// [`RecordBatch`] for the lifetime of the process.
10691101pub fn arrow_ffi_to_query_params ( array : FFI_ArrowArray , schema : FFI_ArrowSchema ) -> [ usize ; 2 ] {
1070- let arr = Box :: into_raw ( Box :: new ( array) ) ;
1071- let sch = Box :: into_raw ( Box :: new ( schema) ) ;
1072-
1073- [ arr as * mut _ as usize , sch as * mut _ as usize ]
1102+ arrow_recordbatch_to_query_params ( arrow_record_batch_from_ffi ( array, schema) )
10741103}
10751104
10761105fn set_nulls_in_flat_vector ( array : & dyn Array , out_vector : & mut FlatVector < ' _ > ) {
@@ -1115,7 +1144,10 @@ fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector<'_>)
11151144
11161145#[ cfg( test) ]
11171146mod test {
1118- use super :: { ArrowVTab , arrow_recordbatch_to_query_params} ;
1147+ use super :: {
1148+ ARROW_QUERY_PARAMS_MARKER , ArrowVTab , arrow_arraydata_to_query_params, arrow_ffi_to_query_params,
1149+ arrow_recordbatch_to_query_params,
1150+ } ;
11191151 use crate :: { Connection , Result } ;
11201152 use arrow:: {
11211153 array:: {
@@ -1132,10 +1164,28 @@ mod test {
11321164 ArrowPrimitiveType , ByteArrayType , DataType , DurationSecondType , Field , IntervalDayTimeType ,
11331165 IntervalMonthDayNanoType , IntervalYearMonthType , Schema , i256,
11341166 } ,
1167+ ffi:: { FFI_ArrowArray , FFI_ArrowSchema } ,
11351168 record_batch:: RecordBatch ,
11361169 } ;
11371170 use std:: { error:: Error , sync:: Arc } ;
11381171
1172+ fn example_record_batch ( ) -> RecordBatch {
1173+ let schema = Schema :: new ( vec ! [
1174+ Field :: new( "id" , DataType :: Int32 , true ) ,
1175+ Field :: new( "name" , DataType :: Utf8 , true ) ,
1176+ Field :: new( "is_odd" , DataType :: Boolean , true ) ,
1177+ ] ) ;
1178+ RecordBatch :: try_new (
1179+ Arc :: new ( schema) ,
1180+ vec ! [
1181+ Arc :: new( Int32Array :: from( vec![ 1 , 2 , 3 , 4 ] ) ) as ArrayRef ,
1182+ Arc :: new( StringArray :: from( vec![ "apple" , "banana" , "cherry" , "date" ] ) ) as ArrayRef ,
1183+ Arc :: new( BooleanArray :: from( vec![ true , false , true , false ] ) ) as ArrayRef ,
1184+ ] ,
1185+ )
1186+ . expect ( "failed to create record batch" )
1187+ }
1188+
11391189 #[ test]
11401190 fn test_vtab_arrow ( ) -> Result < ( ) , Box < dyn Error > > {
11411191 let db = Connection :: open_in_memory ( ) ?;
@@ -1162,7 +1212,7 @@ mod test {
11621212 db. register_table_function :: < ArrowVTab > ( "arrow" ) ?;
11631213
11641214 // This is a show case that it's easy for you to build an in-memory data
1165- // and pass into duckdb
1215+ // and pass into DuckDB
11661216 let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ;
11671217 let array = Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ;
11681218 let rb = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( array) ] ) . expect ( "failed to create record batch" ) ;
@@ -1177,6 +1227,102 @@ mod test {
11771227 Ok ( ( ) )
11781228 }
11791229
1230+ #[ test]
1231+ fn test_vtab_arrow_view_can_rebind_record_batch ( ) -> Result < ( ) , Box < dyn Error > > {
1232+ let db = Connection :: open_in_memory ( ) ?;
1233+ db. register_table_function :: < ArrowVTab > ( "arrow" ) ?;
1234+
1235+ let batch = example_record_batch ( ) ;
1236+ let param = arrow_recordbatch_to_query_params ( batch. clone ( ) ) ;
1237+ db. execute (
1238+ & format ! (
1239+ "CREATE VIEW arrow_view AS SELECT * FROM arrow({}::UBIGINT, {}::UBIGINT)" ,
1240+ param[ 0 ] , param[ 1 ]
1241+ ) ,
1242+ [ ] ,
1243+ ) ?;
1244+
1245+ for _ in 0 ..2 {
1246+ let rbs: Vec < RecordBatch > = db. prepare ( "SELECT * FROM arrow_view" ) ?. query_arrow ( [ ] ) ?. collect ( ) ;
1247+ assert_eq ! ( vec![ batch. clone( ) ] , rbs) ;
1248+ }
1249+
1250+ Ok ( ( ) )
1251+ }
1252+
1253+ #[ test]
1254+ fn test_vtab_arrow_arraydata_query_params ( ) -> Result < ( ) , Box < dyn Error > > {
1255+ let batch = example_record_batch ( ) ;
1256+ let struct_array = StructArray :: from ( batch) ;
1257+ let param = arrow_arraydata_to_query_params ( struct_array. to_data ( ) ) ;
1258+
1259+ let db = Connection :: open_in_memory ( ) ?;
1260+ db. register_table_function :: < ArrowVTab > ( "arrow" ) ?;
1261+ let mut stmt = db. prepare ( "select sum(id)::int32 from arrow(?, ?)" ) ?;
1262+ let rb = stmt. query_arrow ( param) ?. next ( ) . expect ( "no record batch" ) ;
1263+ let column = rb. column ( 0 ) . as_any ( ) . downcast_ref :: < Int32Array > ( ) . unwrap ( ) ;
1264+ assert_eq ! ( column. value( 0 ) , 10 ) ;
1265+ Ok ( ( ) )
1266+ }
1267+
1268+ #[ test]
1269+ fn test_vtab_arrow_ffi_query_params ( ) -> Result < ( ) , Box < dyn Error > > {
1270+ let batch = example_record_batch ( ) ;
1271+ let struct_array = StructArray :: from ( batch) ;
1272+ let array = FFI_ArrowArray :: new ( & struct_array. to_data ( ) ) ;
1273+ let schema = FFI_ArrowSchema :: try_from ( struct_array. data_type ( ) ) ?;
1274+ let param = arrow_ffi_to_query_params ( array, schema) ;
1275+
1276+ let db = Connection :: open_in_memory ( ) ?;
1277+ db. register_table_function :: < ArrowVTab > ( "arrow" ) ?;
1278+ let mut stmt = db. prepare ( "select sum(id)::int32 from arrow(?, ?)" ) ?;
1279+ let rb = stmt. query_arrow ( param) ?. next ( ) . expect ( "no record batch" ) ;
1280+ let column = rb. column ( 0 ) . as_any ( ) . downcast_ref :: < Int32Array > ( ) . unwrap ( ) ;
1281+ assert_eq ! ( column. value( 0 ) , 10 ) ;
1282+ Ok ( ( ) )
1283+ }
1284+
1285+ #[ test]
1286+ fn test_arrow_null_query_params_error ( ) {
1287+ let db = Connection :: open_in_memory ( ) . unwrap ( ) ;
1288+ db. register_table_function :: < ArrowVTab > ( "arrow" ) . unwrap ( ) ;
1289+
1290+ let err = db. prepare ( "SELECT * FROM arrow(NULL, NULL)" ) . err ( ) . unwrap ( ) ;
1291+ assert ! (
1292+ err. to_string( )
1293+ . contains( "ArrowVTab record batch address parameter must not be NULL" ) ,
1294+ "unexpected error: {err}"
1295+ ) ;
1296+ }
1297+
1298+ #[ test]
1299+ fn test_arrow_zero_address_query_params_error ( ) {
1300+ let db = Connection :: open_in_memory ( ) . unwrap ( ) ;
1301+ db. register_table_function :: < ArrowVTab > ( "arrow" ) . unwrap ( ) ;
1302+
1303+ let sql = format ! (
1304+ "SELECT * FROM arrow(0::UBIGINT, {}::UBIGINT)" ,
1305+ ARROW_QUERY_PARAMS_MARKER
1306+ ) ;
1307+ let err = db. prepare ( & sql) . err ( ) . unwrap ( ) ;
1308+ assert ! (
1309+ err. to_string( ) . contains( "invalid ArrowVTab record batch address" ) ,
1310+ "unexpected error: {err}"
1311+ ) ;
1312+ }
1313+
1314+ #[ test]
1315+ fn test_arrow_marker_mismatch_query_params_error ( ) {
1316+ let db = Connection :: open_in_memory ( ) . unwrap ( ) ;
1317+ db. register_table_function :: < ArrowVTab > ( "arrow" ) . unwrap ( ) ;
1318+
1319+ let err = db. prepare ( "SELECT * FROM arrow(1::UBIGINT, 2::UBIGINT)" ) . err ( ) . unwrap ( ) ;
1320+ assert ! (
1321+ err. to_string( ) . contains( "query parameter marker mismatch" ) ,
1322+ "unexpected error: {err}"
1323+ ) ;
1324+ }
1325+
11801326 #[ test]
11811327 #[ cfg( feature = "appender-arrow" ) ]
11821328 fn test_append_struct ( ) -> Result < ( ) , Box < dyn Error > > {
0 commit comments