Skip to content

Commit 308efaa

Browse files
chore[cuda]: clean up and move around kernel (vortex-data#6082)
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 2b36b2e commit 308efaa

11 files changed

Lines changed: 202 additions & 139 deletions

File tree

vortex-array/src/arrays/dict/array.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ pub struct DictArray {
4545
pub(super) all_values_referenced: bool,
4646
}
4747

48+
pub struct DictArrayParts {
49+
pub codes: ArrayRef,
50+
pub values: ArrayRef,
51+
pub dtype: DType,
52+
}
53+
4854
impl DictArray {
4955
/// Build a new `DictArray` without validating the codes or values.
5056
///
@@ -114,8 +120,12 @@ impl DictArray {
114120
Ok(unsafe { Self::new_unchecked(codes, values) })
115121
}
116122

117-
pub fn into_parts(self) -> (ArrayRef, ArrayRef) {
118-
(self.codes, self.values)
123+
pub fn into_parts(self) -> DictArrayParts {
124+
DictArrayParts {
125+
codes: self.codes,
126+
values: self.values,
127+
dtype: self.dtype,
128+
}
119129
}
120130

121131
#[inline]

vortex-array/src/arrow/executor/dictionary.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use vortex_error::vortex_bail;
1515
use crate::ArrayRef;
1616
use crate::ExecutionCtx;
1717
use crate::arrays::DictArray;
18+
use crate::arrays::DictArrayParts;
1819
use crate::arrays::DictVTable;
1920
use crate::arrow::ArrowArrayExecutor;
2021

@@ -47,7 +48,7 @@ fn dict_to_dict(
4748
values_type: &DataType,
4849
ctx: &mut ExecutionCtx,
4950
) -> VortexResult<ArrowArrayRef> {
50-
let (codes, values) = array.into_parts();
51+
let DictArrayParts { codes, values, .. } = array.into_parts();
5152
let codes = codes.execute_arrow(Some(codes_type), ctx)?;
5253
let values = values.execute_arrow(Some(values_type), ctx)?;
5354

vortex-cuda/benches/for_cuda.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ fn benchmark_for_u8(c: &mut Criterion) {
199199
&for_array,
200200
|b, for_array| {
201201
b.iter_custom(|iters| {
202-
let mut cuda_ctx = CudaSession::new_ctx(VortexSession::empty())
202+
let mut cuda_ctx = CudaSession::create_execution_ctx(VortexSession::empty())
203203
.vortex_expect("failed to create execution context");
204204

205205
let encoded = for_array.encoded();
@@ -248,7 +248,7 @@ fn benchmark_for_u16(c: &mut Criterion) {
248248
&for_array,
249249
|b, for_array| {
250250
b.iter_custom(|iters| {
251-
let mut cuda_ctx = CudaSession::new_ctx(VortexSession::empty())
251+
let mut cuda_ctx = CudaSession::create_execution_ctx(VortexSession::empty())
252252
.vortex_expect("failed to create execution context");
253253

254254
let encoded = for_array.encoded();
@@ -297,7 +297,7 @@ fn benchmark_for_u32(c: &mut Criterion) {
297297
&for_array,
298298
|b, for_array| {
299299
b.iter_custom(|iters| {
300-
let mut cuda_ctx = CudaSession::new_ctx(VortexSession::empty())
300+
let mut cuda_ctx = CudaSession::create_execution_ctx(VortexSession::empty())
301301
.vortex_expect("failed to create execution context");
302302

303303
let encoded = for_array.encoded();
@@ -346,7 +346,7 @@ fn benchmark_for_u64(c: &mut Criterion) {
346346
&for_array,
347347
|b, for_array| {
348348
b.iter_custom(|iters| {
349-
let mut cuda_ctx = CudaSession::new_ctx(VortexSession::empty())
349+
let mut cuda_ctx = CudaSession::create_execution_ctx(VortexSession::empty())
350350
.vortex_expect("failed to create execution context");
351351

352352
let encoded = for_array.encoded();

vortex-cuda/src/executor.rs

Lines changed: 8 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,18 @@ use cudarc::driver::LaunchArgs;
1717
use cudarc::driver::result;
1818
use cudarc::driver::result::memcpy_htod_async;
1919
use cudarc::driver::sys;
20-
use cudarc::driver::sys::CUevent_flags;
2120
use futures::future::BoxFuture;
2221
use kanal::Sender;
2322
use result::stream;
2423
use vortex_array::Array;
2524
use vortex_array::ArrayRef;
2625
use vortex_array::Canonical;
27-
use vortex_array::VortexSessionExecute;
26+
use vortex_array::ExecutionCtx;
2827
use vortex_array::buffer::BufferHandle;
2928
use vortex_buffer::Buffer;
3029
use vortex_dtype::PType;
3130
use vortex_error::VortexResult;
3231
use vortex_error::vortex_err;
33-
use vortex_session::VortexSession;
3432

3533
use crate::CudaDeviceBuffer;
3634
use crate::CudaSession;
@@ -115,109 +113,23 @@ pub struct CudaKernelEvents {
115113
pub after_launch: CudaEvent,
116114
}
117115

118-
/// Convenience macro to launch a CUDA kernel.
119-
///
120-
/// The kernel gets launched on the stream of the execution context.
121-
///
122-
/// The kernel launch config:
123-
/// LaunchConfig {
124-
/// grid_dim: (array.len() / 2048, 1, 1),
125-
/// block_dim: (64, 1, 1),
126-
/// shared_mem_bytes: 0,
127-
/// };
128-
/// 64 threads are used per block which corresponds to 2 warps.
129-
/// Each block handles 2048 elements. Each thread handles 32 elements.
130-
/// The last block and thread are allowed to have less elements.
131-
///
132-
/// Note: A macro is necessary to unroll the launch builder arguments.
133-
///
134-
/// # Returns
135-
///
136-
/// A pair of CUDA events submitted before and after the kernel.
137-
/// Depending on `CUevent_flags` these events can contain timestamps. Use
138-
/// `CU_EVENT_DISABLE_TIMING` for minimal overhead and `CU_EVENT_DEFAULT` to
139-
/// enable timestamps.
140-
#[macro_export]
141-
macro_rules! launch_cuda_kernel {
142-
(
143-
execution_ctx: $ctx:expr,
144-
module: $module:expr,
145-
ptypes: $ptypes:expr,
146-
launch_args: [$($arg:expr),* $(,)?],
147-
event_recording: $event_recording:expr,
148-
array_len: $len:expr
149-
) => {{
150-
let cuda_function = $ctx.load_function($module, $ptypes)?;
151-
let mut launch_builder = $ctx.launch_builder(&cuda_function);
152-
153-
$(
154-
launch_builder.arg(&$arg);
155-
)*
156-
157-
$crate::executor::launch_cuda_kernel_impl(&mut launch_builder, $event_recording, $len)?
158-
}};
159-
}
160-
161-
/// Launches a CUDA kernel with the passed launch builder.
162-
///
163-
/// # Arguments
164-
///
165-
/// * `launch_builder` - Configured launch builder
166-
/// * `array_len` - Length of the array to process
167-
///
168-
/// # Returns
169-
///
170-
/// A pair of CUDA events submitted before and after the kernel.
171-
/// Depending on `CUevent_flags` these events can contain timestamps. Use
172-
/// `CU_EVENT_DISABLE_TIMING` for minimal overhead and `CU_EVENT_DEFAULT` to
173-
/// enable timestamps.
174-
pub fn launch_cuda_kernel_impl(
175-
launch_builder: &mut LaunchArgs,
176-
event_flags: CUevent_flags,
177-
array_len: usize,
178-
) -> VortexResult<CudaKernelEvents> {
179-
let num_chunks = u32::try_from(array_len.div_ceil(2048))?;
180-
181-
let config = cudarc::driver::LaunchConfig {
182-
grid_dim: (num_chunks, 1, 1),
183-
block_dim: (64, 1, 1),
184-
shared_mem_bytes: 0,
185-
};
186-
187-
launch_builder.record_kernel_launch(event_flags);
188-
189-
unsafe {
190-
launch_builder
191-
.launch(config)
192-
.map_err(|e| vortex_err!("Failed to launch kernel: {}", e))
193-
.and_then(|events| {
194-
events
195-
.ok_or_else(|| vortex_err!("CUDA events not recorded"))
196-
.map(|(before_launch, after_launch)| CudaKernelEvents {
197-
before_launch,
198-
after_launch,
199-
})
200-
})
201-
}
202-
}
203-
204116
/// CUDA execution context.
205117
///
206118
/// Provides access to the CUDA context and stream for kernel execution.
207119
/// Handles memory allocation and data transfers between host and device.
208120
pub struct CudaExecutionCtx {
209121
stream: Arc<CudaStream>,
210-
vortex_session: VortexSession,
122+
ctx: ExecutionCtx,
211123
cuda_session: CudaSession,
212124
}
213125

214126
impl CudaExecutionCtx {
215127
/// Creates a new CUDA execution context.
216-
pub(crate) fn new(stream: Arc<CudaStream>, vortex_session: VortexSession) -> Self {
217-
let cuda_session = vortex_session.cuda_session().clone();
128+
pub(crate) fn new(stream: Arc<CudaStream>, ctx: ExecutionCtx) -> Self {
129+
let cuda_session = ctx.session().cuda_session().clone();
218130
Self {
219131
stream,
220-
vortex_session,
132+
ctx,
221133
cuda_session,
222134
}
223135
}
@@ -351,17 +263,16 @@ pub trait CudaArrayExt: Array {
351263
#[async_trait]
352264
impl CudaArrayExt for ArrayRef {
353265
async fn execute_cuda(self, ctx: &mut CudaExecutionCtx) -> VortexResult<Canonical> {
354-
if self.is_canonical() {
355-
return self.to_canonical();
266+
if self.is_canonical() || self.is_empty() {
267+
return self.execute(&mut ctx.ctx);
356268
}
357269

358270
let Some(support) = ctx.cuda_session.kernel(&self.encoding_id()) else {
359271
tracing::debug!(
360272
encoding = %self.encoding_id(),
361273
"No CUDA support registered for encoding, falling back to CPU execution"
362274
);
363-
let mut array_ctx = ctx.vortex_session.create_execution_ctx();
364-
return self.execute(&mut array_ctx);
275+
return self.execute(&mut ctx.ctx);
365276
};
366277

367278
tracing::debug!(
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use async_trait::async_trait;
5+
use vortex_array::ArrayRef;
6+
use vortex_array::Canonical;
7+
use vortex_array::arrays::DictVTable;
8+
use vortex_error::VortexExpect;
9+
use vortex_error::VortexResult;
10+
11+
use crate::executor::CudaExecute;
12+
use crate::executor::CudaExecutionCtx;
13+
14+
/// CUDA executor for dictionary-encoded arrays.
15+
#[derive(Debug)]
16+
pub struct DictExecutor;
17+
18+
#[async_trait]
19+
impl CudaExecute for DictExecutor {
20+
async fn execute(
21+
&self,
22+
array: ArrayRef,
23+
_ctx: &mut CudaExecutionCtx,
24+
) -> VortexResult<Canonical> {
25+
let _dict_array = array
26+
.try_into::<DictVTable>()
27+
.ok()
28+
.vortex_expect("Array is not a Dict array");
29+
30+
todo!()
31+
}
32+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod dict;
5+
pub use dict::DictExecutor;

0 commit comments

Comments
 (0)