Skip to content

Commit 6fe0a3d

Browse files
committed
Fix ArrowVTab view rebind lifetime
`ArrowVTab` query parameters used to point at Arrow FFI structs that were consumed during `bind` with `Box::from_raw`. That made the parameters single-use: a stored view keeps the same numeric literals, so later DuckDB rebinds dereferenced stale memory and could segfault or double-free. Store a stable boxed `RecordBatch` pointer instead and clone the batch during `bind`. The box is intentionally leaked under the current raw-parameter design so views can rebind later. Also validate the opaque `UBIGINT` parameters and document the unsafe pointer contract.
1 parent 009281a commit 6fe0a3d

1 file changed

Lines changed: 198 additions & 52 deletions

File tree

crates/duckdb/src/vtab/arrow.rs

Lines changed: 198 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -43,25 +43,55 @@ pub struct ArrowInitData {
4343
/// The Arrow table function.
4444
pub struct ArrowVTab;
4545

46-
unsafe fn address_to_arrow_schema(address: usize) -> FFI_ArrowSchema {
47-
let ptr = address as *mut FFI_ArrowSchema;
48-
unsafe { *Box::from_raw(ptr) }
46+
const ARROW_QUERY_PARAMS_MARKER: usize = 0x4152_5257; // "ARRW"
47+
48+
fn register_arrow_record_batch(rb: RecordBatch) -> [usize; 2] {
49+
let ptr = Box::into_raw(Box::new(rb));
50+
[ptr as usize, ARROW_QUERY_PARAMS_MARKER]
4951
}
5052

51-
unsafe fn address_to_arrow_array(address: usize) -> FFI_ArrowArray {
52-
let ptr = address as *mut FFI_ArrowArray;
53-
unsafe { *Box::from_raw(ptr) }
53+
fn arrow_query_param_usize(bind: &BindInfo, index: u64, name: &str) -> Result<usize, Box<dyn std::error::Error>> {
54+
let value = bind.get_parameter(index);
55+
if value.is_null() {
56+
return Err(format!("ArrowVTab {name} parameter must not be NULL").into());
57+
}
58+
59+
let logical_type = value.logical_type_id();
60+
if logical_type != LogicalTypeId::UBigint {
61+
return Err(format!("ArrowVTab {name} parameter must be UBIGINT, got {logical_type:?}").into());
62+
}
63+
64+
usize::try_from(value.to_uint64()).map_err(|_| format!("ArrowVTab {name} parameter does not fit in usize").into())
5465
}
5566

56-
unsafe fn address_to_arrow_ffi(array: usize, schema: usize) -> (FFI_ArrowArray, FFI_ArrowSchema) {
57-
let array = unsafe { address_to_arrow_array(array) };
58-
let schema = unsafe { address_to_arrow_schema(schema) };
59-
(array, schema)
67+
/// Imports a record batch from the current opaque ArrowVTab query parameters.
68+
///
69+
/// # Safety
70+
///
71+
/// `address` must be a non-null pointer returned by
72+
/// [`arrow_recordbatch_to_query_params`], and `marker` must be the matching
73+
/// layout marker returned with it. The marker catches common misuse only and
74+
/// does not make stale or forged pointers safe to dereference.
75+
unsafe fn address_to_arrow_record_batch(
76+
address: usize,
77+
marker: usize,
78+
) -> Result<RecordBatch, Box<dyn std::error::Error>> {
79+
let ptr = address as *const RecordBatch;
80+
if ptr.is_null() {
81+
return Err("invalid ArrowVTab record batch address".into());
82+
}
83+
84+
if marker != ARROW_QUERY_PARAMS_MARKER {
85+
return Err("ArrowVTab query parameter marker mismatch; use arrow_recordbatch_to_query_params".into());
86+
}
87+
88+
// SAFETY: The caller guarantees that `ptr` is the leaked RecordBatch
89+
// allocation created by `register_arrow_record_batch`.
90+
Ok(unsafe { (*ptr).clone() })
6091
}
6192

62-
unsafe fn address_to_arrow_record_batch(array: usize, schema: usize) -> RecordBatch {
63-
let (array, schema) = unsafe { address_to_arrow_ffi(array, schema) };
64-
let array_data = unsafe { from_ffi(array, &schema) }.expect("ok");
93+
fn arrow_record_batch_from_ffi(array: FFI_ArrowArray, schema: FFI_ArrowSchema) -> RecordBatch {
94+
let array_data = unsafe { from_ffi(array, &schema) }.expect("failed to import Arrow FFI data");
6595
let struct_array = StructArray::from(array_data);
6696
RecordBatch::from(&struct_array)
6797
}
@@ -75,20 +105,22 @@ impl VTab for ArrowVTab {
75105
if param_count != 2 {
76106
return Err(format!("Bad param count: {param_count}, expected 2").into());
77107
}
78-
let array = bind.get_parameter(0).to_int64();
79-
let schema = bind.get_parameter(1).to_int64();
80-
81-
unsafe {
82-
let rb = address_to_arrow_record_batch(array as usize, schema as usize);
83-
for f in rb.schema().fields() {
84-
let name = f.name();
85-
let data_type = f.data_type();
86-
let logical_type = to_duckdb_logical_type(data_type)?;
87-
bind.add_result_column(name, logical_type);
88-
}
108+
let address = arrow_query_param_usize(bind, 0, "record batch address")?;
109+
let marker = arrow_query_param_usize(bind, 1, "marker")?;
89110

90-
Ok(ArrowBindData { rb: Mutex::new(rb) })
111+
// SAFETY: ArrowVTab's raw-parameter API relies on callers passing
112+
// values returned unchanged by `arrow_recordbatch_to_query_params`.
113+
// Validation above catches nulls, type mismatches, and layout marker
114+
// mismatches, but cannot validate forged addresses.
115+
let rb = unsafe { address_to_arrow_record_batch(address, marker)? };
116+
for f in rb.schema().fields() {
117+
let name = f.name();
118+
let data_type = f.data_type();
119+
let logical_type = to_duckdb_logical_type(data_type)?;
120+
bind.add_result_column(name, logical_type);
91121
}
122+
123+
Ok(ArrowBindData { rb: Mutex::new(rb) })
92124
}
93125

94126
fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>> {
@@ -114,13 +146,13 @@ impl VTab for ArrowVTab {
114146

115147
fn parameters() -> Option<Vec<LogicalTypeHandle>> {
116148
Some(vec![
117-
LogicalTypeHandle::from(LogicalTypeId::UBigint), // file path
118-
LogicalTypeHandle::from(LogicalTypeId::UBigint), // sheet name
149+
LogicalTypeHandle::from(LogicalTypeId::UBigint), // record batch address
150+
LogicalTypeHandle::from(LogicalTypeId::UBigint), // query parameter marker
119151
])
120152
}
121153
}
122154

123-
/// Convert arrow DataType to duckdb type id
155+
/// Convert arrow DataType to DuckDB type id
124156
pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn std::error::Error>> {
125157
use LogicalTypeId::*;
126158

@@ -181,7 +213,7 @@ impl TryFrom<DataType> for LogicalTypeId {
181213
}
182214
}
183215

184-
/// Convert arrow DataType to duckdb logical type
216+
/// Convert arrow DataType to DuckDB logical type
185217
pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalTypeHandle, Box<dyn std::error::Error>> {
186218
match data_type {
187219
DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type),
@@ -1040,37 +1072,34 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector<'_>) -> Re
10401072
Ok(())
10411073
}
10421074

1043-
/// Pass RecordBatch to duckdb.
1075+
/// Pass a RecordBatch to DuckDB.
10441076
///
1045-
/// # Safety
1046-
/// The caller must ensure that the pointer is valid
1047-
/// It's recommended to always use this function with arrow()
1077+
/// This returns opaque query parameters for [`ArrowVTab`].
1078+
///
1079+
/// Each call leaks one boxed [`RecordBatch`] for the lifetime of the process so
1080+
/// stored views can rebind the same parameters later. Do not call this per row
1081+
/// or per query. Create these parameters once per logical table or view.
10481082
pub fn arrow_recordbatch_to_query_params(rb: RecordBatch) -> [usize; 2] {
1049-
let data = ArrayData::from(StructArray::from(rb));
1050-
arrow_arraydata_to_query_params(data)
1083+
register_arrow_record_batch(rb)
10511084
}
10521085

1053-
/// Pass ArrayData to duckdb.
1086+
/// Pass ArrayData to DuckDB.
10541087
///
1055-
/// # Safety
1056-
/// The caller must ensure that the pointer is valid
1057-
/// It's recommended to always use this function with arrow()
1088+
/// This converts the [`ArrayData`] to a [`RecordBatch`] immediately. Like
1089+
/// [`arrow_recordbatch_to_query_params`], each call leaks one boxed
1090+
/// [`RecordBatch`] for the lifetime of the process.
10581091
pub fn arrow_arraydata_to_query_params(data: ArrayData) -> [usize; 2] {
1059-
let array = FFI_ArrowArray::new(&data);
1060-
let schema = FFI_ArrowSchema::try_from(data.data_type()).expect("Failed to convert schema");
1061-
arrow_ffi_to_query_params(array, schema)
1092+
let struct_array = StructArray::from(data);
1093+
arrow_recordbatch_to_query_params(RecordBatch::from(&struct_array))
10621094
}
10631095

1064-
/// Pass array and schema as a pointer to duckdb.
1096+
/// Pass array and schema to DuckDB.
10651097
///
1066-
/// # Safety
1067-
/// The caller must ensure that the pointer is valid
1068-
/// It's recommended to always use this function with arrow()
1098+
/// This imports the FFI values immediately. Like
1099+
/// [`arrow_recordbatch_to_query_params`], each call leaks one boxed
1100+
/// [`RecordBatch`] for the lifetime of the process.
10691101
pub fn arrow_ffi_to_query_params(array: FFI_ArrowArray, schema: FFI_ArrowSchema) -> [usize; 2] {
1070-
let arr = Box::into_raw(Box::new(array));
1071-
let sch = Box::into_raw(Box::new(schema));
1072-
1073-
[arr as *mut _ as usize, sch as *mut _ as usize]
1102+
arrow_recordbatch_to_query_params(arrow_record_batch_from_ffi(array, schema))
10741103
}
10751104

10761105
fn set_nulls_in_flat_vector(array: &dyn Array, out_vector: &mut FlatVector<'_>) {
@@ -1115,7 +1144,10 @@ fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector<'_>)
11151144

11161145
#[cfg(test)]
11171146
mod test {
1118-
use super::{ArrowVTab, arrow_recordbatch_to_query_params};
1147+
use super::{
1148+
ARROW_QUERY_PARAMS_MARKER, ArrowVTab, arrow_arraydata_to_query_params, arrow_ffi_to_query_params,
1149+
arrow_recordbatch_to_query_params,
1150+
};
11191151
use crate::{Connection, Result};
11201152
use arrow::{
11211153
array::{
@@ -1132,10 +1164,28 @@ mod test {
11321164
ArrowPrimitiveType, ByteArrayType, DataType, DurationSecondType, Field, IntervalDayTimeType,
11331165
IntervalMonthDayNanoType, IntervalYearMonthType, Schema, i256,
11341166
},
1167+
ffi::{FFI_ArrowArray, FFI_ArrowSchema},
11351168
record_batch::RecordBatch,
11361169
};
11371170
use std::{error::Error, sync::Arc};
11381171

1172+
fn example_record_batch() -> RecordBatch {
1173+
let schema = Schema::new(vec![
1174+
Field::new("id", DataType::Int32, true),
1175+
Field::new("name", DataType::Utf8, true),
1176+
Field::new("is_odd", DataType::Boolean, true),
1177+
]);
1178+
RecordBatch::try_new(
1179+
Arc::new(schema),
1180+
vec![
1181+
Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as ArrayRef,
1182+
Arc::new(StringArray::from(vec!["apple", "banana", "cherry", "date"])) as ArrayRef,
1183+
Arc::new(BooleanArray::from(vec![true, false, true, false])) as ArrayRef,
1184+
],
1185+
)
1186+
.expect("failed to create record batch")
1187+
}
1188+
11391189
#[test]
11401190
fn test_vtab_arrow() -> Result<(), Box<dyn Error>> {
11411191
let db = Connection::open_in_memory()?;
@@ -1162,7 +1212,7 @@ mod test {
11621212
db.register_table_function::<ArrowVTab>("arrow")?;
11631213

11641214
// This showcases how easy it is to build in-memory data
1165-
// and pass into duckdb
1215+
// and pass into DuckDB
11661216
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
11671217
let array = Int32Array::from(vec![1, 2, 3, 4, 5]);
11681218
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).expect("failed to create record batch");
@@ -1177,6 +1227,102 @@ mod test {
11771227
Ok(())
11781228
}
11791229

1230+
#[test]
1231+
fn test_vtab_arrow_view_can_rebind_record_batch() -> Result<(), Box<dyn Error>> {
1232+
let db = Connection::open_in_memory()?;
1233+
db.register_table_function::<ArrowVTab>("arrow")?;
1234+
1235+
let batch = example_record_batch();
1236+
let param = arrow_recordbatch_to_query_params(batch.clone());
1237+
db.execute(
1238+
&format!(
1239+
"CREATE VIEW arrow_view AS SELECT * FROM arrow({}::UBIGINT, {}::UBIGINT)",
1240+
param[0], param[1]
1241+
),
1242+
[],
1243+
)?;
1244+
1245+
for _ in 0..2 {
1246+
let rbs: Vec<RecordBatch> = db.prepare("SELECT * FROM arrow_view")?.query_arrow([])?.collect();
1247+
assert_eq!(vec![batch.clone()], rbs);
1248+
}
1249+
1250+
Ok(())
1251+
}
1252+
1253+
#[test]
1254+
fn test_vtab_arrow_arraydata_query_params() -> Result<(), Box<dyn Error>> {
1255+
let batch = example_record_batch();
1256+
let struct_array = StructArray::from(batch);
1257+
let param = arrow_arraydata_to_query_params(struct_array.to_data());
1258+
1259+
let db = Connection::open_in_memory()?;
1260+
db.register_table_function::<ArrowVTab>("arrow")?;
1261+
let mut stmt = db.prepare("select sum(id)::int32 from arrow(?, ?)")?;
1262+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
1263+
let column = rb.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
1264+
assert_eq!(column.value(0), 10);
1265+
Ok(())
1266+
}
1267+
1268+
#[test]
1269+
fn test_vtab_arrow_ffi_query_params() -> Result<(), Box<dyn Error>> {
1270+
let batch = example_record_batch();
1271+
let struct_array = StructArray::from(batch);
1272+
let array = FFI_ArrowArray::new(&struct_array.to_data());
1273+
let schema = FFI_ArrowSchema::try_from(struct_array.data_type())?;
1274+
let param = arrow_ffi_to_query_params(array, schema);
1275+
1276+
let db = Connection::open_in_memory()?;
1277+
db.register_table_function::<ArrowVTab>("arrow")?;
1278+
let mut stmt = db.prepare("select sum(id)::int32 from arrow(?, ?)")?;
1279+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
1280+
let column = rb.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
1281+
assert_eq!(column.value(0), 10);
1282+
Ok(())
1283+
}
1284+
1285+
#[test]
1286+
fn test_arrow_null_query_params_error() {
1287+
let db = Connection::open_in_memory().unwrap();
1288+
db.register_table_function::<ArrowVTab>("arrow").unwrap();
1289+
1290+
let err = db.prepare("SELECT * FROM arrow(NULL, NULL)").err().unwrap();
1291+
assert!(
1292+
err.to_string()
1293+
.contains("ArrowVTab record batch address parameter must not be NULL"),
1294+
"unexpected error: {err}"
1295+
);
1296+
}
1297+
1298+
#[test]
1299+
fn test_arrow_zero_address_query_params_error() {
1300+
let db = Connection::open_in_memory().unwrap();
1301+
db.register_table_function::<ArrowVTab>("arrow").unwrap();
1302+
1303+
let sql = format!(
1304+
"SELECT * FROM arrow(0::UBIGINT, {}::UBIGINT)",
1305+
ARROW_QUERY_PARAMS_MARKER
1306+
);
1307+
let err = db.prepare(&sql).err().unwrap();
1308+
assert!(
1309+
err.to_string().contains("invalid ArrowVTab record batch address"),
1310+
"unexpected error: {err}"
1311+
);
1312+
}
1313+
1314+
#[test]
1315+
fn test_arrow_marker_mismatch_query_params_error() {
1316+
let db = Connection::open_in_memory().unwrap();
1317+
db.register_table_function::<ArrowVTab>("arrow").unwrap();
1318+
1319+
let err = db.prepare("SELECT * FROM arrow(1::UBIGINT, 2::UBIGINT)").err().unwrap();
1320+
assert!(
1321+
err.to_string().contains("query parameter marker mismatch"),
1322+
"unexpected error: {err}"
1323+
);
1324+
}
1325+
11801326
#[test]
11811327
#[cfg(feature = "appender-arrow")]
11821328
fn test_append_struct() -> Result<(), Box<dyn Error>> {

0 commit comments

Comments
 (0)