Skip to content

Commit 8acef3a

Browse files
authored
fix: keep Arrow device schemas aligned with exports (#8360)
Route export_device_array_with_schema through the exporter so schema derivation sees the same rebuilt host layout that gets exported. This fixes host ListView fallback cases where the Arrow C schema could describe a different child layout than the emitted array. Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 05871c3 commit 8acef3a

3 files changed

Lines changed: 218 additions & 14 deletions

File tree

vortex-cuda/src/arrow/canonical.rs

Lines changed: 150 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use futures::future::BoxFuture;
1616
use futures::future::join;
1717
use vortex::array::ArrayRef;
1818
use vortex::array::Canonical;
19+
use vortex::array::ExecutionCtx;
20+
use vortex::array::IntoArray;
1921
use vortex::array::arrays::DecimalArray;
2022
use vortex::array::arrays::Dict;
2123
use vortex::array::arrays::DictArray;
@@ -36,6 +38,7 @@ use vortex::array::arrays::extension::ExtensionArrayExt;
3638
use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt;
3739
use vortex::array::arrays::fixed_size_list::FixedSizeListDataParts;
3840
use vortex::array::arrays::list::ListDataParts;
41+
use vortex::array::arrays::listview::ListViewArrayExt;
3942
use vortex::array::arrays::listview::list_from_list_view;
4043
use vortex::array::arrays::primitive::PrimitiveDataParts;
4144
use vortex::array::arrays::struct_::StructDataParts;
@@ -64,10 +67,12 @@ use crate::CudaExecutionCtx;
6467
use crate::arrow::ARROW_DEVICE_CUDA;
6568
use crate::arrow::ArrowArray;
6669
use crate::arrow::ArrowDeviceArray;
70+
use crate::arrow::ArrowDeviceArrayWithSchema;
6771
use crate::arrow::ExportDeviceArray;
6872
use crate::arrow::PrivateData;
6973
use crate::arrow::SyncEvent;
7074
use crate::arrow::arrow_device_export_dictionary_codes_dtype;
75+
use crate::arrow::arrow_schema_for_array;
7176
use crate::arrow::cuda_decimal_value_type;
7277
use crate::arrow::list_view::export_device_list_view;
7378
use crate::cub::exclusive_sum_i32;
@@ -96,6 +101,92 @@ impl ExportDeviceArray for CanonicalDeviceArrayExport {
96101
reserved: Default::default(),
97102
})
98103
}
104+
105+
async fn export_device_array_with_schema(
106+
&self,
107+
array: ArrayRef,
108+
ctx: &mut CudaExecutionCtx,
109+
) -> VortexResult<ArrowDeviceArrayWithSchema> {
110+
let array = rebuild_array_for_export_schema(array, ctx.execution_ctx())?;
111+
let schema = arrow_schema_for_array(&array, ctx)?;
112+
let array = self.export_device_array(array, ctx).await?;
113+
Ok(ArrowDeviceArrayWithSchema { schema, array })
114+
}
115+
}
116+
117+
/// Rebuild arrays whose exported layout differs from their original layout.
118+
fn rebuild_array_for_export_schema(
119+
array: ArrayRef,
120+
ctx: &mut ExecutionCtx,
121+
) -> VortexResult<ArrayRef> {
122+
let array = match array.try_downcast::<Dict>() {
123+
Ok(dict) => {
124+
let parts = dict.into_parts();
125+
let values = rebuild_array_for_export_schema(parts.values, ctx)?;
126+
return Ok(DictArray::try_new(parts.codes, values)?.into_array());
127+
}
128+
Err(array) => array,
129+
};
130+
let array = match array.try_downcast::<Struct>() {
131+
Ok(struct_array) => {
132+
let len = struct_array.len();
133+
let StructDataParts {
134+
struct_fields,
135+
fields,
136+
validity,
137+
} = struct_array.into_data_parts();
138+
let fields = fields
139+
.iter()
140+
.map(|field| rebuild_array_for_export_schema(field.clone(), ctx))
141+
.collect::<VortexResult<Vec<_>>>()?;
142+
return Ok(
143+
StructArray::try_new(struct_fields.names().clone(), fields, len, validity)?
144+
.into_array(),
145+
);
146+
}
147+
Err(array) => array,
148+
};
149+
let array = match array.try_downcast::<List>() {
150+
Ok(list) => {
151+
let ListDataParts {
152+
elements,
153+
offsets,
154+
validity,
155+
..
156+
} = list.into_data_parts();
157+
let elements = rebuild_array_for_export_schema(elements, ctx)?;
158+
return Ok(ListArray::try_new(elements, offsets, validity)?.into_array());
159+
}
160+
Err(array) => array,
161+
};
162+
let array = match array.try_downcast::<FixedSizeList>() {
163+
Ok(fixed_size_list) => {
164+
let len = fixed_size_list.len();
165+
let list_size = fixed_size_list.list_size();
166+
let FixedSizeListDataParts {
167+
elements, validity, ..
168+
} = fixed_size_list.into_data_parts();
169+
let elements = rebuild_array_for_export_schema(elements, ctx)?;
170+
return Ok(
171+
FixedSizeListArray::try_new(elements, list_size, validity, len)?.into_array(),
172+
);
173+
}
174+
Err(array) => array,
175+
};
176+
let array = match array.try_downcast::<ListView>() {
177+
Ok(listview)
178+
if listview.as_ref().is_host() && listview.elements().as_opt::<Dict>().is_some() =>
179+
{
180+
return rebuild_array_for_export_schema(
181+
list_from_list_view(listview, ctx)?.into_array(),
182+
ctx,
183+
);
184+
}
185+
Ok(listview) => return Ok(listview.into_array()),
186+
Err(array) => array,
187+
};
188+
189+
Ok(array)
99190
}
100191

101192
/// Export arrays whose Arrow layout depends on their concrete children before CUDA
@@ -2136,7 +2227,7 @@ mod tests {
21362227
}
21372228

21382229
#[crate::test]
2139-
async fn test_export_host_non_contiguous_dictionary_list_view_preserves_dictionary_child()
2230+
async fn test_export_host_non_contiguous_dictionary_list_view_schema_matches_rebuilt_child()
21402231
-> VortexResult<()> {
21412232
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
21422233
.vortex_expect("failed to create execution context");
@@ -2162,7 +2253,13 @@ mod tests {
21622253
"",
21632254
Field::new(
21642255
Field::LIST_FIELD_DEFAULT_NAME,
2165-
DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Int32)),
2256+
DataType::Dictionary(
2257+
Box::new(DataType::Int64),
2258+
Box::new(DataType::Dictionary(
2259+
Box::new(DataType::Int16),
2260+
Box::new(DataType::Int32),
2261+
)),
2262+
),
21662263
true,
21672264
),
21682265
false,
@@ -2177,6 +2274,57 @@ mod tests {
21772274
assert!(!dict_child.dictionary.is_null());
21782275
assert_eq!(dict_child.length, 5);
21792276
assert_eq!(dict_child.n_buffers, 2);
2277+
let nested_dict = unsafe { &*dict_child.dictionary };
2278+
assert!(!nested_dict.dictionary.is_null());
2279+
2280+
unsafe { release_exported_array(&raw mut exported.array.array) };
2281+
Ok(())
2282+
}
2283+
2284+
// Regression test: with an average list size >= 128 the host list-view rebuild picks its
2285+
// list-by-list strategy, which may canonicalize Dict elements. The schema must describe the
2286+
// rebuilt child layout.
2287+
#[crate::test]
2288+
async fn test_export_host_large_lists_dictionary_list_view_schema_matches_rebuilt_child()
2289+
-> VortexResult<()> {
2290+
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
2291+
.vortex_expect("failed to create execution context");
2292+
2293+
let elements = DictArray::try_new(
2294+
PrimitiveArray::from_option_iter(
2295+
(0..256u32).map(|i| (i % 5 != 0).then_some((i % 3) as u8)),
2296+
)
2297+
.into_array(),
2298+
PrimitiveArray::from_iter([10i32, 20, 30]).into_array(),
2299+
)?
2300+
.into_array();
2301+
let array = ListViewArray::new(
2302+
elements,
2303+
PrimitiveArray::from_iter([128i32, 0]).into_array(),
2304+
PrimitiveArray::from_iter([128i32, 128]).into_array(),
2305+
Validity::NonNullable,
2306+
)
2307+
.into_array();
2308+
let mut exported = array.export_device_array_with_schema(&mut ctx).await?;
2309+
2310+
let field = Field::try_from(&exported.schema)?;
2311+
assert_eq!(
2312+
field,
2313+
Field::new_list(
2314+
"",
2315+
Field::new(Field::LIST_FIELD_DEFAULT_NAME, DataType::Int32, true),
2316+
false,
2317+
)
2318+
);
2319+
assert_eq!(
2320+
private_data_buffer_i32_values(&exported.array.array, 1)?,
2321+
[0, 128, 256]
2322+
);
2323+
let list_children = unsafe { std::slice::from_raw_parts(exported.array.array.children, 1) };
2324+
let child = unsafe { &*list_children[0] };
2325+
assert!(child.dictionary.is_null());
2326+
assert_eq!(child.length, 256);
2327+
assert_eq!(child.n_buffers, 2);
21802328

21812329
unsafe { release_exported_array(&raw mut exported.array.array) };
21822330
Ok(())

vortex-cuda/src/arrow/mod.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,13 @@ impl DeviceArrayExt for ArrayRef {
255255
self,
256256
ctx: &mut CudaExecutionCtx,
257257
) -> VortexResult<ArrowDeviceArrayWithSchema> {
258-
let schema = arrow_schema_for_array(&self, ctx)?;
259-
let array = self.export_device_array(ctx).await?;
260-
Ok(ArrowDeviceArrayWithSchema { schema, array })
258+
let exporter = Arc::clone(ctx.exporter());
259+
exporter.export_device_array_with_schema(self, ctx).await
261260
}
262261
}
263262

264263
/// Build the Arrow C schema that describes the exported device array.
265-
fn arrow_schema_for_array(
264+
pub(crate) fn arrow_schema_for_array(
266265
array: &ArrayRef,
267266
ctx: &mut CudaExecutionCtx,
268267
) -> VortexResult<FFI_ArrowSchema> {
@@ -479,4 +478,15 @@ pub trait ExportDeviceArray: Debug + Send + Sync + 'static {
479478
array: ArrayRef,
480479
ctx: &mut CudaExecutionCtx,
481480
) -> VortexResult<ArrowDeviceArray>;
481+
482+
/// Export a Vortex array as an [`ArrowDeviceArray`] with a matching Arrow C schema.
483+
async fn export_device_array_with_schema(
484+
&self,
485+
array: ArrayRef,
486+
ctx: &mut CudaExecutionCtx,
487+
) -> VortexResult<ArrowDeviceArrayWithSchema> {
488+
let schema = arrow_schema_for_array(&array, ctx)?;
489+
let array = self.export_device_array(array, ctx).await?;
490+
Ok(ArrowDeviceArrayWithSchema { schema, array })
491+
}
482492
}

vortex-ffi/src/error.rs

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,30 @@ pub(crate) fn vx_error_new(message: &str) -> *mut vx_error {
2727
}
2828

2929
/// Write an error message to `error` which has not been populated before.
30+
/// A null `error` pointer discards the message.
3031
pub(crate) fn write_error(error: *mut *mut vx_error, message: &str) {
31-
assert!(!error.is_null());
32+
if error.is_null() {
33+
return;
34+
}
3235
unsafe { error.write(vx_error_new(message)) };
3336
}
3437

38+
/// Clear `*error_out` to null unless `error_out` itself is null.
39+
fn clear_error(error_out: *mut *mut vx_error) {
40+
if error_out.is_null() {
41+
return;
42+
}
43+
unsafe { error_out.write(ptr::null_mut()) };
44+
}
45+
3546
#[inline]
3647
pub fn try_or_default<T: Default>(
3748
error_out: *mut *mut vx_error,
3849
function: impl FnOnce() -> VortexResult<T>,
3950
) -> T {
4051
match function() {
4152
Ok(value) => {
42-
unsafe { error_out.write(ptr::null_mut()) };
53+
clear_error(error_out);
4354
value
4455
}
4556
Err(err) => {
@@ -51,19 +62,16 @@ pub fn try_or_default<T: Default>(
5162

5263
/// Run `function`, returning its value on success and `error_value` on failure.
5364
///
54-
/// On success `*error_out` is cleared to null; on failure the error is written to `*error_out`
55-
/// when it is non-null.
56-
// Writes through `error_out` but stays safe like the other error-out helpers here; the raw-pointer
57-
// contract is documented at the C boundary.
58-
#[allow(clippy::not_unsafe_ptr_arg_deref)]
65+
/// `error_out` may be null, in which case error details are discarded. When it is non-null,
66+
/// `*error_out` is cleared to null on success and set to an owned `vx_error` on failure.
5967
pub fn try_or<T>(
6068
error_out: *mut *mut vx_error,
6169
error_value: T,
6270
function: impl FnOnce() -> VortexResult<T>,
6371
) -> T {
6472
match function() {
6573
Ok(value) => {
66-
unsafe { error_out.write(ptr::null_mut()) };
74+
clear_error(error_out);
6775
value
6876
}
6977
Err(err) => {
@@ -81,3 +89,41 @@ pub fn try_or<T>(
8189
pub unsafe extern "C-unwind" fn vx_error_get_message(error: *const vx_error) -> *const vx_string {
8290
vx_string::new_ref(&vx_error::as_ref(error).message)
8391
}
92+
93+
#[cfg(test)]
94+
mod tests {
95+
use std::ptr;
96+
97+
use vortex::error::vortex_err;
98+
99+
use super::*;
100+
use crate::error::vx_error_free;
101+
102+
#[test]
103+
fn test_try_or_null_error_out() {
104+
// A null error_out must be tolerated on both the success and failure paths.
105+
assert_eq!(try_or(ptr::null_mut(), -1, || Ok(42)), 42);
106+
assert_eq!(try_or(ptr::null_mut(), -1, || Err(vortex_err!("boom"))), -1);
107+
}
108+
109+
#[test]
110+
fn test_try_or_default_null_error_out() {
111+
assert_eq!(try_or_default(ptr::null_mut(), || Ok(42)), 42);
112+
assert_eq!(
113+
try_or_default::<i32>(ptr::null_mut(), || Err(vortex_err!("boom"))),
114+
0
115+
);
116+
}
117+
118+
#[test]
119+
fn test_try_or_writes_and_clears_error_out() {
120+
let mut error: *mut vx_error = ptr::null_mut();
121+
122+
assert_eq!(try_or(&raw mut error, -1, || Err(vortex_err!("boom"))), -1);
123+
assert!(!error.is_null());
124+
unsafe { vx_error_free(error) };
125+
126+
assert_eq!(try_or(&raw mut error, -1, || Ok(42)), 42);
127+
assert!(error.is_null());
128+
}
129+
}

0 commit comments

Comments
 (0)