Skip to content

Commit db5ead6

Browse files
committed
feat[gpu]: slice support for CUDA dyn dispatch
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 8e29827 commit db5ead6

File tree

12 files changed

+196
-126
lines changed

12 files changed

+196
-126
lines changed

encodings/fastlanes/public-api.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ impl vortex_array::arrays::filter::kernel::FilterKernel for vortex_fastlanes::Bi
196196

197197
pub fn vortex_fastlanes::BitPackedVTable::filter(array: &vortex_fastlanes::BitPackedArray, mask: &vortex_mask::Mask, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
198198

199-
impl vortex_array::arrays::slice::SliceKernel for vortex_fastlanes::BitPackedVTable
199+
impl vortex_array::arrays::slice::SliceReduce for vortex_fastlanes::BitPackedVTable
200200

201-
pub fn vortex_fastlanes::BitPackedVTable::slice(array: &vortex_fastlanes::BitPackedArray, range: core::ops::range::Range<usize>, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
201+
pub fn vortex_fastlanes::BitPackedVTable::slice(array: &vortex_fastlanes::BitPackedArray, range: core::ops::range::Range<usize>) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
202202

203203
impl vortex_array::compute::is_constant::IsConstantKernel for vortex_fastlanes::BitPackedVTable
204204

encodings/fastlanes/src/bitpacking/compute/slice.rs

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,15 @@ use std::cmp::max;
55
use std::ops::Range;
66

77
use vortex_array::ArrayRef;
8-
use vortex_array::ExecutionCtx;
98
use vortex_array::IntoArray;
10-
use vortex_array::arrays::SliceKernel;
9+
use vortex_array::arrays::SliceReduce;
1110
use vortex_error::VortexResult;
1211

1312
use crate::BitPackedArray;
1413
use crate::BitPackedVTable;
1514

16-
impl SliceKernel for BitPackedVTable {
17-
fn slice(
18-
array: &BitPackedArray,
19-
range: Range<usize>,
20-
_ctx: &mut ExecutionCtx,
21-
) -> VortexResult<Option<ArrayRef>> {
15+
impl SliceReduce for BitPackedVTable {
16+
fn slice(array: &BitPackedArray, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
2217
let offset_start = range.start + array.offset() as usize;
2318
let offset_stop = range.end + array.offset() as usize;
2419
let offset = offset_start % 1024;
@@ -51,43 +46,44 @@ impl SliceKernel for BitPackedVTable {
5146

5247
#[cfg(test)]
5348
mod tests {
54-
use std::sync::LazyLock;
55-
5649
use vortex_array::Array;
57-
use vortex_array::IntoArray;
58-
use vortex_array::VortexSessionExecute;
59-
use vortex_array::arrays::SliceArray;
60-
use vortex_array::session::ArraySession;
61-
use vortex_array::vtable::VTable;
50+
use vortex_array::arrays::SliceReduce;
51+
use vortex_array::arrays::SliceVTable;
6252
use vortex_error::VortexResult;
63-
use vortex_session::VortexSession;
6453

6554
use crate::BitPackedVTable;
6655
use crate::bitpack_compress::bitpack_encode;
6756

68-
static SESSION: LazyLock<VortexSession> =
69-
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
57+
#[test]
58+
fn test_slice_returns_bitpacked() -> VortexResult<()> {
59+
let values = vortex_array::arrays::PrimitiveArray::from_iter(0u32..2048);
60+
let bitpacked = bitpack_encode(&values, 11, None)?;
61+
62+
let result =
63+
BitPackedVTable::slice(&bitpacked, 500..1500)?.expect("expected slice to succeed");
64+
65+
assert!(result.is::<BitPackedVTable>());
66+
let result_bp = result.as_::<BitPackedVTable>();
67+
assert_eq!(result_bp.offset(), 500);
68+
assert_eq!(result.len(), 1000);
69+
70+
Ok(())
71+
}
7072

7173
#[test]
72-
fn test_execute_parent_returns_bitpacked_slice() -> VortexResult<()> {
74+
fn test_slice_via_array_trait() -> VortexResult<()> {
7375
let values = vortex_array::arrays::PrimitiveArray::from_iter(0u32..2048);
7476
let bitpacked = bitpack_encode(&values, 11, None)?;
7577

76-
let slice_array = SliceArray::new(bitpacked.clone().into_array(), 500..1500);
77-
78-
let mut ctx = SESSION.create_execution_ctx();
79-
let reduced = <BitPackedVTable as VTable>::execute_parent(
80-
&bitpacked,
81-
&slice_array.into_array(),
82-
0,
83-
&mut ctx,
84-
)?
85-
.expect("expected slice kernel to execute");
86-
87-
assert!(reduced.is::<BitPackedVTable>());
88-
let reduced_bp = reduced.as_::<BitPackedVTable>();
89-
assert_eq!(reduced_bp.offset(), 500);
90-
assert_eq!(reduced.len(), 1000);
78+
let sliced = bitpacked.as_ref().slice(500..1500)?;
79+
80+
// After optimize, the SliceArray should have been reduced away.
81+
assert!(
82+
!sliced.is::<SliceVTable>(),
83+
"expected SliceReduce to eliminate the SliceArray wrapper"
84+
);
85+
assert!(sliced.is::<BitPackedVTable>());
86+
assert_eq!(sliced.len(), 1000);
9187

9288
Ok(())
9389
}

encodings/fastlanes/src/bitpacking/vtable/kernels.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_array::arrays::FilterExecuteAdaptor;
5-
use vortex_array::arrays::SliceExecuteAdaptor;
65
use vortex_array::arrays::TakeExecuteAdaptor;
76
use vortex_array::kernel::ParentKernelSet;
87

98
use crate::BitPackedVTable;
109

1110
pub(crate) const PARENT_KERNELS: ParentKernelSet<BitPackedVTable> = ParentKernelSet::new(&[
1211
ParentKernelSet::lift(&FilterExecuteAdaptor(BitPackedVTable)),
13-
ParentKernelSet::lift(&SliceExecuteAdaptor(BitPackedVTable)),
1412
ParentKernelSet::lift(&TakeExecuteAdaptor(BitPackedVTable)),
1513
]);

encodings/fastlanes/src/bitpacking/vtable/operations.rs

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,10 @@ impl OperationsVTable<BitPackedVTable> for BitPackedVTable {
2626
#[cfg(test)]
2727
mod test {
2828
use std::ops::Range;
29-
use std::sync::LazyLock;
3029

3130
use vortex_array::Array;
3231
use vortex_array::IntoArray;
33-
use vortex_array::VortexSessionExecute;
3432
use vortex_array::arrays::PrimitiveArray;
35-
use vortex_array::arrays::SliceArray;
3633
use vortex_array::assert_arrays_eq;
3734
use vortex_array::assert_nth_scalar;
3835
use vortex_array::buffer::BufferHandle;
@@ -41,9 +38,7 @@ mod test {
4138
use vortex_array::dtype::PType;
4239
use vortex_array::patches::Patches;
4340
use vortex_array::scalar::Scalar;
44-
use vortex_array::session::ArraySession;
4541
use vortex_array::validity::Validity;
46-
use vortex_array::vtable::VTable;
4742
use vortex_buffer::Alignment;
4843
use vortex_buffer::Buffer;
4944
use vortex_buffer::ByteBuffer;
@@ -52,20 +47,8 @@ mod test {
5247
use crate::BitPackedArray;
5348
use crate::BitPackedVTable;
5449

55-
static SESSION: LazyLock<vortex_session::VortexSession> =
56-
LazyLock::new(|| vortex_session::VortexSession::empty().with::<ArraySession>());
57-
5850
fn slice_via_kernel(array: &BitPackedArray, range: Range<usize>) -> BitPackedArray {
59-
let slice_array = SliceArray::new(array.clone().into_array(), range);
60-
let mut ctx = SESSION.create_execution_ctx();
61-
let sliced = <BitPackedVTable as VTable>::execute_parent(
62-
array,
63-
&slice_array.into_array(),
64-
0,
65-
&mut ctx,
66-
)
67-
.expect("execute_parent failed")
68-
.expect("expected slice kernel to execute");
51+
let sliced = array.as_ref().slice(range).expect("slice failed");
6952
sliced.as_::<BitPackedVTable>().clone()
7053
}
7154

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex_array::arrays::SliceReduceAdaptor;
45
use vortex_array::optimizer::rules::ParentRuleSet;
56
use vortex_array::scalar_fn::fns::cast::CastReduceAdaptor;
67

78
use crate::BitPackedVTable;
89

9-
pub(crate) const RULES: ParentRuleSet<BitPackedVTable> =
10-
ParentRuleSet::new(&[ParentRuleSet::lift(&CastReduceAdaptor(BitPackedVTable))]);
10+
pub(crate) const RULES: ParentRuleSet<BitPackedVTable> = ParentRuleSet::new(&[
11+
ParentRuleSet::lift(&CastReduceAdaptor(BitPackedVTable)),
12+
ParentRuleSet::lift(&SliceReduceAdaptor(BitPackedVTable)),
13+
]);

vortex-cuda/gpu-scan-cli/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ async fn main() -> VortexResult<()> {
9494

9595
// Create a full scan that executes on the GPU
9696
let cuda_stream =
97-
VortexCudaStreamPool::new(Arc::clone(cuda_ctx.stream().context()), 1).get_stream()?;
97+
VortexCudaStreamPool::new(Arc::clone(cuda_ctx.stream().context()), 1).stream()?;
9898
let gpu_reader = CopyDeviceReadAt::new(recompressed, cuda_stream);
9999

100100
let gpu_file = session

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,10 @@ __device__ inline void dynamic_source_op(const T *__restrict input,
5656
constexpr uint32_t FL_CHUNK_SIZE = 1024;
5757
constexpr uint32_t LANES_PER_FL_BLOCK = FL_CHUNK_SIZE / T_BITS;
5858
const uint32_t bit_width = source_op.params.bitunpack.bit_width;
59+
const uint32_t element_offset = source_op.params.bitunpack.element_offset;
5960
const uint32_t packed_words_per_fl_block = LANES_PER_FL_BLOCK * bit_width;
60-
const uint64_t first_fl_block = chunk_start / FL_CHUNK_SIZE;
61+
// Shift chunk_start by the sub-block element offset.
62+
const uint64_t first_fl_block = (chunk_start + element_offset) / FL_CHUNK_SIZE;
6163

6264
// FL blocks must divide evenly. Otherwise, the last unpack would overflow smem.
6365
static_assert((ELEMENTS_PER_BLOCK % FL_CHUNK_SIZE) == 0);

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@ union SourceParams {
4444
/// Unpack bit-packed data using FastLanes layout.
4545
struct BitunpackParams {
4646
uint8_t bit_width;
47+
uint8_t _padding[3];
48+
uint32_t element_offset; // Element offset within FL block (0..1023)
4749
} bitunpack;
4850

4951
/// Copy elements verbatim from global memory to shared memory.
52+
/// The input pointer is pre-adjusted on the host to account for slicing.
5053
struct LoadParams {
5154
uint8_t _padding;
5255
} load;

vortex-cuda/src/device_buffer.rs

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ mod private {
8181
}
8282
}
8383

84-
// Get it back out as a View of u8
85-
8684
impl CudaDeviceBuffer {
8785
/// Creates a new CUDA device buffer from a [`CudaSlice<T>`].
8886
///
@@ -101,6 +99,16 @@ impl CudaDeviceBuffer {
10199
}
102100
}
103101

102+
/// Returns the byte offset within the allocated buffer.
103+
pub fn offset(&self) -> usize {
104+
self.offset
105+
}
106+
107+
/// Returns the adjusted device pointer accounting for the offset.
108+
pub fn offset_ptr(&self) -> sys::CUdeviceptr {
109+
self.device_ptr + self.offset as u64
110+
}
111+
104112
/// Returns a [`CudaView`] to the CUDA device buffer.
105113
pub fn as_view<T: DeviceRepr + 'static>(&self) -> CudaView<'_, T> {
106114
// Return a new &[T]
@@ -159,7 +167,7 @@ impl CudaBufferExt for BufferHandle {
159167
.as_any()
160168
.downcast_ref::<CudaDeviceBuffer>()
161169
.ok_or_else(|| vortex_err!("expected CudaDeviceBuffer"))?
162-
.device_ptr;
170+
.offset_ptr();
163171

164172
Ok(ptr)
165173
}
@@ -279,41 +287,6 @@ impl DeviceBuffer for CudaDeviceBuffer {
279287
}))
280288
}
281289

282-
/// Slices the CUDA device buffer to a subrange.
283-
///
284-
/// **IMPORTANT**: this is a byte range, not elements range, due to the DeviceBuffer interface.
285-
fn slice(&self, range: Range<usize>) -> Arc<dyn DeviceBuffer> {
286-
assert!(
287-
range.end <= self.len,
288-
"Slice range end {} exceeds allocation size {}",
289-
range.end,
290-
self.len
291-
);
292-
293-
let new_offset = self.offset + range.start;
294-
let new_len = range.end - range.start;
295-
296-
let trailing = (self.device_ptr + new_offset as u64).trailing_zeros();
297-
let exponent =
298-
u8::try_from(min(15, trailing)).vortex_expect("min(15, x) always fits in u8");
299-
let slice_align = Alignment::from_exponent(exponent);
300-
301-
assert!(
302-
slice_align.is_aligned_to(self.allocation.alignment()),
303-
"slice must respect minimum alignment {}, min {}",
304-
slice_align,
305-
self.allocation.alignment()
306-
);
307-
308-
Arc::new(CudaDeviceBuffer {
309-
allocation: Arc::clone(&self.allocation),
310-
offset: new_offset,
311-
len: new_len,
312-
device_ptr: self.device_ptr,
313-
alignment: self.alignment,
314-
})
315-
}
316-
317290
fn as_any(&self) -> &dyn std::any::Any {
318291
self
319292
}

0 commit comments

Comments
 (0)