@@ -17,20 +17,18 @@ use cudarc::driver::LaunchArgs;
1717use cudarc:: driver:: result;
1818use cudarc:: driver:: result:: memcpy_htod_async;
1919use cudarc:: driver:: sys;
20- use cudarc:: driver:: sys:: CUevent_flags ;
2120use futures:: future:: BoxFuture ;
2221use kanal:: Sender ;
2322use result:: stream;
2423use vortex_array:: Array ;
2524use vortex_array:: ArrayRef ;
2625use vortex_array:: Canonical ;
27- use vortex_array:: VortexSessionExecute ;
26+ use vortex_array:: ExecutionCtx ;
2827use vortex_array:: buffer:: BufferHandle ;
2928use vortex_buffer:: Buffer ;
3029use vortex_dtype:: PType ;
3130use vortex_error:: VortexResult ;
3231use vortex_error:: vortex_err;
33- use vortex_session:: VortexSession ;
3432
3533use crate :: CudaDeviceBuffer ;
3634use crate :: CudaSession ;
@@ -115,109 +113,23 @@ pub struct CudaKernelEvents {
115113 pub after_launch : CudaEvent ,
116114}
117115
118- /// Convenience macro to launch a CUDA kernel.
119- ///
120- /// The kernel gets launched on the stream of the execution context.
121- ///
122- /// The kernel launch config:
123- /// LaunchConfig {
124- /// grid_dim: (array.len() / 2048, 1, 1),
125- /// block_dim: (64, 1, 1),
126- /// shared_mem_bytes: 0,
127- /// };
128- /// 64 threads are used per block which corresponds to 2 warps.
129- /// Each block handles 2048 elements. Each thread handles 32 elements.
130- /// The last block and thread are allowed to have less elements.
131- ///
132- /// Note: A macro is necessary to unroll the launch builder arguments.
133- ///
134- /// # Returns
135- ///
136- /// A pair of CUDA events submitted before and after the kernel.
137- /// Depending on `CUevent_flags` these events can contain timestamps. Use
138- /// `CU_EVENT_DISABLE_TIMING` for minimal overhead and `CU_EVENT_DEFAULT` to
139- /// enable timestamps.
140- #[ macro_export]
141- macro_rules! launch_cuda_kernel {
142- (
143- execution_ctx: $ctx: expr,
144- module: $module: expr,
145- ptypes: $ptypes: expr,
146- launch_args: [ $( $arg: expr) ,* $( , ) ?] ,
147- event_recording: $event_recording: expr,
148- array_len: $len: expr
149- ) => { {
150- let cuda_function = $ctx. load_function( $module, $ptypes) ?;
151- let mut launch_builder = $ctx. launch_builder( & cuda_function) ;
152-
153- $(
154- launch_builder. arg( & $arg) ;
155- ) *
156-
157- $crate:: executor:: launch_cuda_kernel_impl( & mut launch_builder, $event_recording, $len) ?
158- } } ;
159- }
160-
161- /// Launches a CUDA kernel with the passed launch builder.
162- ///
163- /// # Arguments
164- ///
165- /// * `launch_builder` - Configured launch builder
166- /// * `array_len` - Length of the array to process
167- ///
168- /// # Returns
169- ///
170- /// A pair of CUDA events submitted before and after the kernel.
171- /// Depending on `CUevent_flags` these events can contain timestamps. Use
172- /// `CU_EVENT_DISABLE_TIMING` for minimal overhead and `CU_EVENT_DEFAULT` to
173- /// enable timestamps.
174- pub fn launch_cuda_kernel_impl (
175- launch_builder : & mut LaunchArgs ,
176- event_flags : CUevent_flags ,
177- array_len : usize ,
178- ) -> VortexResult < CudaKernelEvents > {
179- let num_chunks = u32:: try_from ( array_len. div_ceil ( 2048 ) ) ?;
180-
181- let config = cudarc:: driver:: LaunchConfig {
182- grid_dim : ( num_chunks, 1 , 1 ) ,
183- block_dim : ( 64 , 1 , 1 ) ,
184- shared_mem_bytes : 0 ,
185- } ;
186-
187- launch_builder. record_kernel_launch ( event_flags) ;
188-
189- unsafe {
190- launch_builder
191- . launch ( config)
192- . map_err ( |e| vortex_err ! ( "Failed to launch kernel: {}" , e) )
193- . and_then ( |events| {
194- events
195- . ok_or_else ( || vortex_err ! ( "CUDA events not recorded" ) )
196- . map ( |( before_launch, after_launch) | CudaKernelEvents {
197- before_launch,
198- after_launch,
199- } )
200- } )
201- }
202- }
203-
204116/// CUDA execution context.
205117///
206118/// Provides access to the CUDA context and stream for kernel execution.
207119/// Handles memory allocation and data transfers between host and device.
208120pub struct CudaExecutionCtx {
209121 stream : Arc < CudaStream > ,
210- vortex_session : VortexSession ,
122+ ctx : ExecutionCtx ,
211123 cuda_session : CudaSession ,
212124}
213125
214126impl CudaExecutionCtx {
215127 /// Creates a new CUDA execution context.
216- pub ( crate ) fn new ( stream : Arc < CudaStream > , vortex_session : VortexSession ) -> Self {
217- let cuda_session = vortex_session . cuda_session ( ) . clone ( ) ;
128+ pub ( crate ) fn new ( stream : Arc < CudaStream > , ctx : ExecutionCtx ) -> Self {
129+ let cuda_session = ctx . session ( ) . cuda_session ( ) . clone ( ) ;
218130 Self {
219131 stream,
220- vortex_session ,
132+ ctx ,
221133 cuda_session,
222134 }
223135 }
@@ -351,17 +263,16 @@ pub trait CudaArrayExt: Array {
351263#[ async_trait]
352264impl CudaArrayExt for ArrayRef {
353265 async fn execute_cuda ( self , ctx : & mut CudaExecutionCtx ) -> VortexResult < Canonical > {
354- if self . is_canonical ( ) {
355- return self . to_canonical ( ) ;
266+ if self . is_canonical ( ) || self . is_empty ( ) {
267+ return self . execute ( & mut ctx . ctx ) ;
356268 }
357269
358270 let Some ( support) = ctx. cuda_session . kernel ( & self . encoding_id ( ) ) else {
359271 tracing:: debug!(
360272 encoding = %self . encoding_id( ) ,
361273 "No CUDA support registered for encoding, falling back to CPU execution"
362274 ) ;
363- let mut array_ctx = ctx. vortex_session . create_execution_ctx ( ) ;
364- return self . execute ( & mut array_ctx) ;
275+ return self . execute ( & mut ctx. ctx ) ;
365276 } ;
366277
367278 tracing:: debug!(
0 commit comments