Skip to content

Commit 848fd89

Browse files
authored
fix[cuda]: Fix gpu_scan test (#6208)
- Slice not being applied lazily - execute_cuda for struct needs to pushdown to fields - CudaDeviceBuffer applying slicing at item instead of byte granularity - VarBinView had latent GPU bug leftover from BufferHandles conversion - BitPackedExecutor using wrong slice indices - `CudaDeviceBuffer` to type-erase the underlying allocation so it matches`BufferHandle` more closely --------- Signed-off-by: Andrew Duffy <[email protected]>
1 parent 4b907e5 commit 848fd89

File tree

23 files changed

+399
-112
lines changed

23 files changed

+399
-112
lines changed

encodings/alp/src/alp/array.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
use std::fmt::Debug;
55
use std::hash::Hash;
6-
use std::ops::Range;
76

87
use vortex_array::Array;
98
use vortex_array::ArrayBufferVisitor;
@@ -18,6 +17,7 @@ use vortex_array::IntoArray;
1817
use vortex_array::Precision;
1918
use vortex_array::ProstMetadata;
2019
use vortex_array::SerializeMetadata;
20+
use vortex_array::arrays::SliceVTable;
2121
use vortex_array::buffer::BufferHandle;
2222
use vortex_array::patches::Patches;
2323
use vortex_array::patches::PatchesMetadata;
@@ -174,9 +174,18 @@ impl VTable for ALPVTable {
174174
)?))
175175
}
176176

177-
fn slice(array: &Self::Array, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
178-
Ok(Some(
179-
ALPArray::new(
177+
fn execute_parent(
178+
array: &Self::Array,
179+
parent: &ArrayRef,
180+
_child_idx: usize,
181+
ctx: &mut ExecutionCtx,
182+
) -> VortexResult<Option<Canonical>> {
183+
// CPU-only: if parent is SliceArray, perform slicing of the buffer and any patches
184+
// Note that this triggers compute (binary searching Patches) which we cannot do when the
185+
// buffers live in GPU memory.
186+
if let Some(slice_array) = parent.as_opt::<SliceVTable>() {
187+
let range = slice_array.slice_range().clone();
188+
let sliced_alp = ALPArray::new(
180189
array.encoded().slice(range.clone())?,
181190
array.exponents(),
182191
array
@@ -185,8 +194,11 @@ impl VTable for ALPVTable {
185194
.transpose()?
186195
.flatten(),
187196
)
188-
.into_array(),
189-
))
197+
.into_array();
198+
return Ok(Some(sliced_alp.execute::<Canonical>(ctx)?));
199+
}
200+
201+
Ok(None)
190202
}
191203
}
192204

vortex-array/src/array/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,4 +845,14 @@ impl<V: VTable> ArrayVisitor for ArrayAdapter<V> {
845845
Ok(metadata) => Debug::fmt(&metadata, f),
846846
}
847847
}
848+
849+
fn is_host(&self) -> bool {
850+
for array in self.depth_first_traversal() {
851+
if !array.buffer_handles().iter().all(BufferHandle::is_on_host) {
852+
return false;
853+
}
854+
}
855+
856+
true
857+
}
848858
}

vortex-array/src/array/visitor.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ pub trait ArrayVisitor {
4949

5050
/// Formats a human-readable metadata description.
5151
fn metadata_fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result;
52+
53+
/// Checks if all buffers in the array tree are host-resident.
54+
///
55+
/// This will fail if any buffers of self or child arrays are GPU-resident.
56+
fn is_host(&self) -> bool;
5257
}
5358

5459
impl ArrayVisitor for Arc<dyn Array> {
@@ -95,6 +100,10 @@ impl ArrayVisitor for Arc<dyn Array> {
95100
fn metadata_fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
96101
self.as_ref().metadata_fmt(f)
97102
}
103+
104+
fn is_host(&self) -> bool {
105+
self.as_ref().is_host()
106+
}
98107
}
99108

100109
pub trait ArrayVisitorExt: Array {

vortex-array/src/arrays/filter/rules.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,20 @@ use crate::ArrayRef;
88
use crate::IntoArray;
99
use crate::arrays::FilterArray;
1010
use crate::arrays::FilterVTable;
11+
use crate::arrays::StructArray;
12+
use crate::arrays::StructArrayParts;
13+
use crate::arrays::StructVTable;
1114
use crate::matchers::Exact;
1215
use crate::optimizer::rules::ArrayParentReduceRule;
16+
use crate::optimizer::rules::ArrayReduceRule;
1317
use crate::optimizer::rules::ParentRuleSet;
18+
use crate::optimizer::rules::ReduceRuleSet;
1419

1520
pub(super) const PARENT_RULES: ParentRuleSet<FilterVTable> =
1621
ParentRuleSet::new(&[ParentRuleSet::lift(&FilterFilterRule)]);
1722

23+
pub(super) const RULES: ReduceRuleSet<FilterVTable> = ReduceRuleSet::new(&[&FilterStructRule]);
24+
1825
/// A simple redecution rule that simplifies a [`FilterArray`] whose child is also a
1926
/// [`FilterArray`].
2027
#[derive(Debug)]
@@ -39,3 +46,41 @@ impl ArrayParentReduceRule<FilterVTable> for FilterFilterRule {
3946
Ok(Some(new_array.into_array()))
4047
}
4148
}
49+
50+
/// A reduce rule that pushes a filter down into the fields of a StructArray.
51+
#[derive(Debug)]
52+
struct FilterStructRule;
53+
54+
impl ArrayReduceRule<FilterVTable> for FilterStructRule {
55+
fn reduce(&self, array: &FilterArray) -> VortexResult<Option<ArrayRef>> {
56+
let mask = array.filter_mask();
57+
let Some(struct_array) = array.child().as_opt::<StructVTable>() else {
58+
return Ok(None);
59+
};
60+
61+
let len = mask.true_count();
62+
let StructArrayParts {
63+
fields,
64+
struct_fields,
65+
validity,
66+
..
67+
} = struct_array.clone().into_parts();
68+
69+
let filtered_validity = validity.filter(mask)?;
70+
71+
let filtered_fields = fields
72+
.iter()
73+
.map(|field| field.filter(mask.clone()))
74+
.collect::<VortexResult<Vec<_>>>()?;
75+
76+
Ok(Some(
77+
StructArray::new(
78+
struct_fields.names().clone(),
79+
filtered_fields,
80+
len,
81+
filtered_validity,
82+
)
83+
.into_array(),
84+
))
85+
}
86+
}

vortex-array/src/arrays/filter/vtable.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use crate::arrays::filter::array::FilterArray;
2525
use crate::arrays::filter::execute::execute_filter;
2626
use crate::arrays::filter::execute::execute_filter_fast_paths;
2727
use crate::arrays::filter::rules::PARENT_RULES;
28+
use crate::arrays::filter::rules::RULES;
2829
use crate::buffer::BufferHandle;
2930
use crate::executor::ExecutionCtx;
3031
use crate::serde::ArrayChildren;
@@ -140,6 +141,10 @@ impl VTable for FilterVTable {
140141
) -> VortexResult<Option<ArrayRef>> {
141142
PARENT_RULES.evaluate(array, parent, child_idx)
142143
}
144+
145+
fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
146+
RULES.evaluate(array)
147+
}
143148
}
144149

145150
impl BaseArrayVTable<FilterVTable> for FilterVTable {

vortex-array/src/arrays/slice/array.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ pub struct SliceArray {
1515
pub(super) stats: ArrayStats,
1616
}
1717

18+
pub struct SliceArrayParts {
19+
pub child: ArrayRef,
20+
pub range: Range<usize>,
21+
}
22+
1823
impl SliceArray {
1924
pub fn new(child: ArrayRef, range: Range<usize>) -> Self {
2025
if range.end > child.len() {
@@ -40,4 +45,12 @@ impl SliceArray {
4045
pub fn child(&self) -> &ArrayRef {
4146
&self.child
4247
}
48+
49+
/// Consume the slice array and return its components.
50+
pub fn into_parts(self) -> SliceArrayParts {
51+
SliceArrayParts {
52+
child: self.child,
53+
range: self.range,
54+
}
55+
}
4356
}

vortex-array/src/arrays/varbinview/vtable/array.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use std::hash::Hash;
55

66
use vortex_dtype::DType;
7+
use vortex_vector::binaryview::BinaryView;
78

89
use crate::Precision;
910
use crate::arrays::varbinview::VarBinViewArray;
@@ -15,7 +16,7 @@ use crate::vtable::BaseArrayVTable;
1516

1617
impl BaseArrayVTable<VarBinViewVTable> for VarBinViewVTable {
1718
fn len(array: &VarBinViewArray) -> usize {
18-
array.views().len()
19+
array.views_handle().len() / size_of::<BinaryView>()
1920
}
2021

2122
fn dtype(array: &VarBinViewArray) -> &DType {

vortex-array/src/buffer.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ pub trait DeviceBuffer: 'static + Send + Sync + Debug + DynEq + DynHash {
8383

8484
/// Create a new buffer that references a subrange of this buffer at the given
8585
/// slice indices.
86+
///
87+
/// Note that slice indices are in byte units.
8688
fn slice(&self, range: Range<usize>) -> Arc<dyn DeviceBuffer>;
8789

8890
/// Return a buffer with the given alignment. Where possible, this will be zero-copy.
@@ -93,6 +95,19 @@ pub trait DeviceBuffer: 'static + Send + Sync + Debug + DynEq + DynHash {
9395
fn aligned(self: Arc<Self>, alignment: Alignment) -> VortexResult<Arc<dyn DeviceBuffer>>;
9496
}
9597

98+
pub trait DeviceBufferExt: DeviceBuffer {
99+
/// Slice a range of elements `T` out of the device buffer.
100+
fn slice_typed<T: Sized>(&self, range: Range<usize>) -> Arc<dyn DeviceBuffer>;
101+
}
102+
103+
impl<B: DeviceBuffer> DeviceBufferExt for B {
104+
fn slice_typed<T: Sized>(&self, range: Range<usize>) -> Arc<dyn DeviceBuffer> {
105+
let start_bytes = range.start * size_of::<T>();
106+
let end_bytes = range.end * size_of::<T>();
107+
self.slice(start_bytes..end_bytes)
108+
}
109+
}
110+
96111
impl Hash for dyn DeviceBuffer {
97112
fn hash<H: Hasher>(&self, state: &mut H) {
98113
self.dyn_hash(state);

vortex-array/src/patches.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use vortex_vector::primitive::PrimitiveVectorMut;
3636

3737
use crate::Array;
3838
use crate::ArrayRef;
39+
use crate::ArrayVisitor;
3940
use crate::IntoArray;
4041
use crate::ToCanonical;
4142
use crate::arrays::PrimitiveArray;
@@ -171,25 +172,30 @@ impl Patches {
171172
"Patch indices must be non-nullable unsigned integers, got {:?}",
172173
indices.dtype()
173174
);
175+
174176
vortex_ensure!(
175177
indices.len() <= array_len,
176178
"Patch indices must be shorter than the array length"
177179
);
178180
vortex_ensure!(!indices.is_empty(), "Patch indices must not be empty");
179181

180-
let max = usize::try_from(&indices.scalar_at(indices.len() - 1)?)
181-
.map_err(|_| vortex_err!("indices must be a number"))?;
182-
vortex_ensure!(
183-
max - offset < array_len,
184-
"Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
185-
);
182+
// Perform validation of components when they are host-resident.
183+
// This is not possible to do eagerly when the data is on GPU memory.
184+
if indices.is_host() && values.is_host() {
185+
let max = usize::try_from(&indices.scalar_at(indices.len() - 1)?)
186+
.map_err(|_| vortex_err!("indices must be a number"))?;
187+
vortex_ensure!(
188+
max - offset < array_len,
189+
"Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
190+
);
186191

187-
debug_assert!(
188-
is_sorted(indices.as_ref())
189-
.unwrap_or(Some(false))
190-
.unwrap_or(false),
191-
"Patch indices must be sorted"
192-
);
192+
debug_assert!(
193+
is_sorted(indices.as_ref())
194+
.unwrap_or(Some(false))
195+
.unwrap_or(false),
196+
"Patch indices must be sorted"
197+
);
198+
}
193199

194200
Ok(Self {
195201
array_len,

vortex-cuda/benches/filter_cuda.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#![allow(clippy::cast_possible_truncation)]
88

99
use std::ffi::c_void;
10+
use std::fmt::Debug;
1011
use std::mem::size_of;
1112
use std::time::Duration;
1213

@@ -124,7 +125,14 @@ async fn run_filter_timed<T: CubFilterable + cudarc::driver::DeviceRepr>(
124125
/// Benchmark filter for a specific type.
125126
fn benchmark_filter_type<T>(c: &mut Criterion, type_name: &str)
126127
where
127-
T: CubFilterable + cudarc::driver::DeviceRepr + From<u8> + Clone + Send + Sync + 'static,
128+
T: CubFilterable
129+
+ cudarc::driver::DeviceRepr
130+
+ From<u8>
131+
+ Debug
132+
+ Clone
133+
+ Send
134+
+ Sync
135+
+ 'static,
128136
{
129137
let mut group = c.benchmark_group(format!("Filter_cuda_{type_name}"));
130138
group.sample_size(10);
@@ -161,7 +169,7 @@ where
161169
let d_input = d_input_handle
162170
.as_device()
163171
.as_any()
164-
.downcast_ref::<CudaDeviceBuffer<T>>()
172+
.downcast_ref::<CudaDeviceBuffer>()
165173
.unwrap();
166174

167175
// Copy bitmask to device
@@ -171,7 +179,7 @@ where
171179
let d_bitmask = d_bitmask_handle
172180
.as_device()
173181
.as_any()
174-
.downcast_ref::<CudaDeviceBuffer<u8>>()
182+
.downcast_ref::<CudaDeviceBuffer>()
175183
.unwrap();
176184

177185
// Allocate output and temp buffers

0 commit comments

Comments
 (0)