Skip to content

Commit 54dc191

Browse files
committed
fix: keep Arrow device schemas aligned with exports
Route export_device_array_with_schema through the exporter so schema derivation sees the same rebuilt host layout that gets exported. This fixes host ListView fallback cases where the Arrow C schema could describe a different child layout than the emitted array. Also tolerate null FFI error_out pointers per the C API contract and cover both regressions with tests. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent a289c23 commit 54dc191

3 files changed

Lines changed: 218 additions & 14 deletions

File tree

vortex-cuda/src/arrow/canonical.rs

Lines changed: 150 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@ use cudarc::driver::result as cuda_driver;
1515
use futures::future::BoxFuture;
1616
use vortex::array::ArrayRef;
1717
use vortex::array::Canonical;
18+
use vortex::array::ExecutionCtx;
19+
use vortex::array::IntoArray;
1820
use vortex::array::arrays::DecimalArray;
1921
use vortex::array::arrays::Dict;
2022
use vortex::array::arrays::DictArray;
@@ -35,6 +37,7 @@ use vortex::array::arrays::extension::ExtensionArrayExt;
3537
use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt;
3638
use vortex::array::arrays::fixed_size_list::FixedSizeListDataParts;
3739
use vortex::array::arrays::list::ListDataParts;
40+
use vortex::array::arrays::listview::ListViewArrayExt;
3841
use vortex::array::arrays::listview::list_from_list_view;
3942
use vortex::array::arrays::primitive::PrimitiveDataParts;
4043
use vortex::array::arrays::struct_::StructDataParts;
@@ -63,10 +66,12 @@ use crate::CudaExecutionCtx;
6366
use crate::arrow::ARROW_DEVICE_CUDA;
6467
use crate::arrow::ArrowArray;
6568
use crate::arrow::ArrowDeviceArray;
69+
use crate::arrow::ArrowDeviceArrayWithSchema;
6670
use crate::arrow::ExportDeviceArray;
6771
use crate::arrow::PrivateData;
6872
use crate::arrow::SyncEvent;
6973
use crate::arrow::arrow_device_export_dictionary_codes_dtype;
74+
use crate::arrow::arrow_schema_for_array;
7075
use crate::arrow::cuda_decimal_value_type;
7176
use crate::arrow::list_view::export_device_list_view;
7277
use crate::cub::exclusive_sum_i32;
@@ -95,6 +100,92 @@ impl ExportDeviceArray for CanonicalDeviceArrayExport {
95100
reserved: Default::default(),
96101
})
97102
}
103+
104+
/// Export `array` as a device array together with an Arrow C schema describing it.
///
/// The input is first passed through `rebuild_array_for_export_schema`; both the schema
/// and the exported device array are then produced from that single rebuilt value, so
/// the two are derived from the same array.
///
/// # Errors
///
/// Propagates any error from the rebuild, from `arrow_schema_for_array`, or from
/// `export_device_array`.
async fn export_device_array_with_schema(
105+
&self,
106+
array: ArrayRef,
107+
ctx: &mut CudaExecutionCtx,
108+
) -> VortexResult<ArrowDeviceArrayWithSchema> {
109+
// Rebuild before deriving the schema so the schema sees the rebuilt array.
let array = rebuild_array_for_export_schema(array, ctx.execution_ctx())?;
110+
let schema = arrow_schema_for_array(&array, ctx)?;
111+
// Export the rebuilt array (not the caller's original) that the schema was derived from.
let array = self.export_device_array(array, ctx).await?;
112+
Ok(ArrowDeviceArrayWithSchema { schema, array })
113+
}
114+
}
115+
116+
/// Rebuild arrays whose exported layout differs from their original layout.
///
/// The array is tested against a fixed sequence of encodings and the first match wins:
///
/// - `Dict`: the values are rebuilt recursively; the codes are reused as they are.
/// - `Struct`: every field is rebuilt recursively.
/// - `List` / `FixedSizeList`: the elements are rebuilt recursively.
/// - `ListView`: when the array is on the host and its immediate `elements` child is a
///   `Dict`, it is converted with `list_from_list_view` and the result is rebuilt again.
///   Any other `ListView` is returned unchanged, without visiting its elements.
///
/// An array matching none of these is returned as-is.
///
/// # Errors
///
/// Propagates errors from the `try_new` constructors and from `list_from_list_view`.
117+
fn rebuild_array_for_export_schema(
118+
array: ArrayRef,
119+
ctx: &mut ExecutionCtx,
120+
) -> VortexResult<ArrayRef> {
121+
// Each `try_downcast` hands the array back in `Err` on a mismatch, so it can be
// re-bound and tried against the next encoding.
let array = match array.try_downcast::<Dict>() {
122+
Ok(dict) => {
123+
let parts = dict.into_parts();
124+
// Only the dictionary values are rebuilt; `parts.codes` is passed through.
let values = rebuild_array_for_export_schema(parts.values, ctx)?;
125+
return Ok(DictArray::try_new(parts.codes, values)?.into_array());
126+
}
127+
Err(array) => array,
128+
};
129+
let array = match array.try_downcast::<Struct>() {
130+
Ok(struct_array) => {
131+
// Read the length before `into_data_parts` consumes the array.
let len = struct_array.len();
132+
let StructDataParts {
133+
struct_fields,
134+
fields,
135+
validity,
136+
} = struct_array.into_data_parts();
137+
// Rebuild every field; the first failing field aborts the collection.
let fields = fields
138+
.iter()
139+
.map(|field| rebuild_array_for_export_schema(field.clone(), ctx))
140+
.collect::<VortexResult<Vec<_>>>()?;
141+
return Ok(
142+
StructArray::try_new(struct_fields.names().clone(), fields, len, validity)?
143+
.into_array(),
144+
);
145+
}
146+
Err(array) => array,
147+
};
148+
let array = match array.try_downcast::<List>() {
149+
Ok(list) => {
150+
let ListDataParts {
151+
elements,
152+
offsets,
153+
validity,
154+
..
155+
} = list.into_data_parts();
156+
// Offsets and validity are reused; only the elements are rebuilt.
let elements = rebuild_array_for_export_schema(elements, ctx)?;
157+
return Ok(ListArray::try_new(elements, offsets, validity)?.into_array());
158+
}
159+
Err(array) => array,
160+
};
161+
let array = match array.try_downcast::<FixedSizeList>() {
162+
Ok(fixed_size_list) => {
163+
// Read length and list size before `into_data_parts` consumes the array.
let len = fixed_size_list.len();
164+
let list_size = fixed_size_list.list_size();
165+
let FixedSizeListDataParts {
166+
elements, validity, ..
167+
} = fixed_size_list.into_data_parts();
168+
let elements = rebuild_array_for_export_schema(elements, ctx)?;
169+
return Ok(
170+
FixedSizeListArray::try_new(elements, list_size, validity, len)?.into_array(),
171+
);
172+
}
173+
Err(array) => array,
174+
};
175+
let array = match array.try_downcast::<ListView>() {
176+
// The guard inspects only the immediate `elements` child: host list views whose
// elements are directly a `Dict` are converted to a list and rebuilt again.
Ok(listview)
177+
if listview.as_ref().is_host() && listview.elements().as_opt::<Dict>().is_some() =>
178+
{
179+
return rebuild_array_for_export_schema(
180+
list_from_list_view(listview, ctx)?.into_array(),
181+
ctx,
182+
);
183+
}
184+
// Every other list view is returned unchanged; its elements are not visited.
Ok(listview) => return Ok(listview.into_array()),
185+
Err(array) => array,
186+
};
187+
188+
// None of the encodings above matched: return the array untouched.
Ok(array)
98189
}
99190

100191
/// Export arrays whose Arrow layout depends on their concrete children before CUDA
@@ -2139,7 +2230,7 @@ mod tests {
21392230
}
21402231

21412232
#[crate::test]
2142-
async fn test_export_host_non_contiguous_dictionary_list_view_preserves_dictionary_child()
2233+
async fn test_export_host_non_contiguous_dictionary_list_view_schema_matches_rebuilt_child()
21432234
-> VortexResult<()> {
21442235
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
21452236
.vortex_expect("failed to create execution context");
@@ -2165,7 +2256,13 @@ mod tests {
21652256
"",
21662257
Field::new(
21672258
Field::LIST_FIELD_DEFAULT_NAME,
2168-
DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Int32)),
2259+
DataType::Dictionary(
2260+
Box::new(DataType::Int64),
2261+
Box::new(DataType::Dictionary(
2262+
Box::new(DataType::Int16),
2263+
Box::new(DataType::Int32),
2264+
)),
2265+
),
21692266
true,
21702267
),
21712268
false,
@@ -2180,6 +2277,57 @@ mod tests {
21802277
assert!(!dict_child.dictionary.is_null());
21812278
assert_eq!(dict_child.length, 5);
21822279
assert_eq!(dict_child.n_buffers, 2);
2280+
let nested_dict = unsafe { &*dict_child.dictionary };
2281+
assert!(!nested_dict.dictionary.is_null());
2282+
2283+
unsafe { release_exported_array(&raw mut exported.array.array) };
2284+
Ok(())
2285+
}
2286+
2287+
// Regression test: with an average list size >= 128 the host list-view rebuild picks its
2288+
// list-by-list strategy, which may canonicalize Dict elements. The schema must describe the
2289+
// rebuilt child layout.
2290+
#[crate::test]
2291+
async fn test_export_host_large_lists_dictionary_list_view_schema_matches_rebuilt_child()
2292+
-> VortexResult<()> {
2293+
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
2294+
.vortex_expect("failed to create execution context");
2295+
2296+
let elements = DictArray::try_new(
2297+
PrimitiveArray::from_option_iter(
2298+
(0..256u32).map(|i| (i % 5 != 0).then_some((i % 3) as u8)),
2299+
)
2300+
.into_array(),
2301+
PrimitiveArray::from_iter([10i32, 20, 30]).into_array(),
2302+
)?
2303+
.into_array();
2304+
let array = ListViewArray::new(
2305+
elements,
2306+
PrimitiveArray::from_iter([128i32, 0]).into_array(),
2307+
PrimitiveArray::from_iter([128i32, 128]).into_array(),
2308+
Validity::NonNullable,
2309+
)
2310+
.into_array();
2311+
let mut exported = array.export_device_array_with_schema(&mut ctx).await?;
2312+
2313+
let field = Field::try_from(&exported.schema)?;
2314+
assert_eq!(
2315+
field,
2316+
Field::new_list(
2317+
"",
2318+
Field::new(Field::LIST_FIELD_DEFAULT_NAME, DataType::Int32, true),
2319+
false,
2320+
)
2321+
);
2322+
assert_eq!(
2323+
private_data_buffer_i32_values(&exported.array.array, 1)?,
2324+
[0, 128, 256]
2325+
);
2326+
let list_children = unsafe { std::slice::from_raw_parts(exported.array.array.children, 1) };
2327+
let child = unsafe { &*list_children[0] };
2328+
assert!(child.dictionary.is_null());
2329+
assert_eq!(child.length, 256);
2330+
assert_eq!(child.n_buffers, 2);
21832331

21842332
unsafe { release_exported_array(&raw mut exported.array.array) };
21852333
Ok(())

vortex-cuda/src/arrow/mod.rs

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -255,14 +255,13 @@ impl DeviceArrayExt for ArrayRef {
255255
self,
256256
ctx: &mut CudaExecutionCtx,
257257
) -> VortexResult<ArrowDeviceArrayWithSchema> {
258-
let schema = arrow_schema_for_array(&self, ctx)?;
259-
let array = self.export_device_array(ctx).await?;
260-
Ok(ArrowDeviceArrayWithSchema { schema, array })
258+
let exporter = Arc::clone(ctx.exporter());
259+
exporter.export_device_array_with_schema(self, ctx).await
261260
}
262261
}
263262

264263
/// Build the Arrow C schema that describes the exported device array.
265-
fn arrow_schema_for_array(
264+
pub(crate) fn arrow_schema_for_array(
266265
array: &ArrayRef,
267266
ctx: &mut CudaExecutionCtx,
268267
) -> VortexResult<FFI_ArrowSchema> {
@@ -479,4 +478,15 @@ pub trait ExportDeviceArray: Debug + Send + Sync + 'static {
479478
array: ArrayRef,
480479
ctx: &mut CudaExecutionCtx,
481480
) -> VortexResult<ArrowDeviceArray>;
481+
482+
/// Export a Vortex array as an [`ArrowDeviceArray`] with a matching Arrow C schema.
///
/// This default derives the schema from `array` exactly as it was passed in and then
/// exports that same array through `export_device_array`. The schema is therefore only
/// accurate if the export emits the layout of the input array unchanged.
/// NOTE(review): implementations whose export rebuilds the array presumably need to
/// override this method so the schema is derived from the rebuilt array — confirm.
///
/// # Errors
///
/// Propagates any error from `arrow_schema_for_array` or `export_device_array`.
483+
async fn export_device_array_with_schema(
484+
&self,
485+
array: ArrayRef,
486+
ctx: &mut CudaExecutionCtx,
487+
) -> VortexResult<ArrowDeviceArrayWithSchema> {
488+
let schema = arrow_schema_for_array(&array, ctx)?;
489+
let array = self.export_device_array(array, ctx).await?;
490+
Ok(ArrowDeviceArrayWithSchema { schema, array })
491+
}
482492
}

vortex-ffi/src/error.rs

Lines changed: 54 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,19 +27,30 @@ pub(crate) fn vx_error_new(message: &str) -> *mut vx_error {
2727
}
2828

2929
/// Write an error message to `error` which has not been populated before.
30+
/// A null `error` pointer discards the message.
///
/// When `error` is null the function returns before a `vx_error` is created. The slot is
/// written with `ptr::write`, so whatever `*error` held before is overwritten without
/// being read or freed.
3031
pub(crate) fn write_error(error: *mut *mut vx_error, message: &str) {
31-
assert!(!error.is_null());
32+
if error.is_null() {
33+
return;
34+
}
3235
// SAFETY: `error` is non-null (checked above). That it also points to writable,
// properly aligned storage for a `*mut vx_error` is assumed from the caller and is
// not verified here.
unsafe { error.write(vx_error_new(message)) };
3336
}
3437

38+
/// Clear `*error_out` to null unless `error_out` itself is null.
///
/// A null `error_out` makes this a no-op. The previous value of `*error_out` is
/// overwritten without being read or freed.
39+
fn clear_error(error_out: *mut *mut vx_error) {
40+
if error_out.is_null() {
41+
return;
42+
}
43+
// SAFETY: `error_out` is non-null (checked above). That it also points to writable,
// properly aligned storage for a `*mut vx_error` is assumed from the caller and is
// not verified here.
unsafe { error_out.write(ptr::null_mut()) };
44+
}
45+
3546
#[inline]
3647
pub fn try_or_default<T: Default>(
3748
error_out: *mut *mut vx_error,
3849
function: impl FnOnce() -> VortexResult<T>,
3950
) -> T {
4051
match function() {
4152
Ok(value) => {
42-
unsafe { error_out.write(ptr::null_mut()) };
53+
clear_error(error_out);
4354
value
4455
}
4556
Err(err) => {
@@ -51,19 +62,16 @@ pub fn try_or_default<T: Default>(
5162

5263
/// Run `function`, returning its value on success and `error_value` on failure.
5364
///
54-
/// On success `*error_out` is cleared to null; on failure the error is written to `*error_out`
55-
/// when it is non-null.
56-
// Writes through `error_out` but stays safe like the other error-out helpers here; the raw-pointer
57-
// contract is documented at the C boundary.
58-
#[allow(clippy::not_unsafe_ptr_arg_deref)]
65+
/// `error_out` may be null, in which case error details are discarded. When it is non-null,
66+
/// `*error_out` is cleared to null on success and set to an owned `vx_error` on failure.
5967
pub fn try_or<T>(
6068
error_out: *mut *mut vx_error,
6169
error_value: T,
6270
function: impl FnOnce() -> VortexResult<T>,
6371
) -> T {
6472
match function() {
6573
Ok(value) => {
66-
unsafe { error_out.write(ptr::null_mut()) };
74+
clear_error(error_out);
6775
value
6876
}
6977
Err(err) => {
@@ -81,3 +89,41 @@ pub fn try_or<T>(
8189
pub unsafe extern "C-unwind" fn vx_error_get_message(error: *const vx_error) -> *const vx_string {
8290
vx_string::new_ref(&vx_error::as_ref(error).message)
8391
}
92+
93+
// Unit tests for the error-out helpers: null `error_out` handling, and writing and
// clearing of a non-null `error_out`.
#[cfg(test)]
94+
mod tests {
95+
use std::ptr;
96+
97+
use vortex::error::vortex_err;
98+
99+
use super::*;
100+
use crate::error::vx_error_free;
101+
102+
#[test]
103+
fn test_try_or_null_error_out() {
104+
// A null error_out must be tolerated on both the success and failure paths.
105+
assert_eq!(try_or(ptr::null_mut(), -1, || Ok(42)), 42);
106+
assert_eq!(try_or(ptr::null_mut(), -1, || Err(vortex_err!("boom"))), -1);
107+
}
108+
109+
#[test]
110+
fn test_try_or_default_null_error_out() {
111+
// Same as above for `try_or_default`: success yields the value, failure yields
// `i32::default()`.
assert_eq!(try_or_default(ptr::null_mut(), || Ok(42)), 42);
112+
assert_eq!(
113+
try_or_default::<i32>(ptr::null_mut(), || Err(vortex_err!("boom"))),
114+
0
115+
);
116+
}
117+
118+
#[test]
119+
fn test_try_or_writes_and_clears_error_out() {
120+
let mut error: *mut vx_error = ptr::null_mut();
121+
122+
// Failure path: the fallback value is returned and `error` is populated.
assert_eq!(try_or(&raw mut error, -1, || Err(vortex_err!("boom"))), -1);
123+
assert!(!error.is_null());
124+
// SAFETY: `error` was just set by `try_or` and asserted non-null above; it is
// freed once here and not dereferenced afterwards.
unsafe { vx_error_free(error) };
125+
126+
// `error` still holds the freed pointer at this point; the success path must
// overwrite it with null.
assert_eq!(try_or(&raw mut error, -1, || Ok(42)), 42);
127+
assert!(error.is_null());
128+
}
129+
}

0 commit comments

Comments
 (0)