-
-
Notifications
You must be signed in to change notification settings - Fork 159
Add safe API for CUDA constant memory operations #478
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| use cudarc::{ | ||
| driver::{CudaContext, DriverError, LaunchConfig, PushKernelArg}, | ||
| nvrtc::Ptx, | ||
| }; | ||
|
|
||
| /// Evaluates a cubic polynomial on the GPU with its coefficients stored in | ||
| /// `__constant__` memory, then checks every result against a host computation. | ||
| fn main() -> Result<(), DriverError> { | ||
| // Everything below runs on device 0 through its default stream. | ||
| let ctx = CudaContext::new(0)?; | ||
| let stream = ctx.default_stream(); | ||
|
|
||
| // This PTX module holds both the kernel and the `coefficients` constant array. | ||
| let ptx = Ptx::from_file("./examples/constant_memory.ptx"); | ||
| let module = ctx.load_module(ptx)?; | ||
|
|
||
| // Look up the constant-memory symbol and report how large it is. | ||
| let coeff_symbol = module.get_global("coefficients")?; | ||
| let symbol_bytes = coeff_symbol.num_bytes(); | ||
| println!( | ||
| "Constant memory symbol 'coefficients' has {} bytes", | ||
| symbol_bytes | ||
| ); | ||
|
|
||
| // Polynomial 1.0 + 2.0*x + 3.0*x^2 + 4.0*x^3, lowest-order coefficient first. | ||
| let coefficients = [1.0f32, 2.0, 3.0, 4.0]; | ||
| // The coefficients are uploaded before the kernel that reads them is launched. | ||
| stream.memcpy_htos(&coefficients, &coeff_symbol)?; | ||
|
|
||
| let polynomial_kernel = module.load_function("polynomial_kernel")?; | ||
|
|
||
| // Host inputs, their device copy, and a zeroed device buffer for the outputs. | ||
| let xs = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; | ||
| let n = xs.len(); | ||
| let xs_dev = stream.memcpy_stod(&xs)?; | ||
| let mut ys_dev = stream.alloc_zeros::<f32>(n)?; | ||
|
|
||
| // Kernel arguments follow the kernel's parameter order: out, inp, numel. | ||
| let launch_cfg = LaunchConfig::for_num_elems(n as u32); | ||
| let numel = n as i32; | ||
| unsafe { | ||
| stream | ||
| .launch_builder(&polynomial_kernel) | ||
| .arg(&mut ys_dev) | ||
| .arg(&xs_dev) | ||
| .arg(&numel) | ||
| .launch(launch_cfg) | ||
| }?; | ||
|
|
||
| // Bring the results back to the host. | ||
| let ys = stream.memcpy_dtov(&ys_dev)?; | ||
|
|
||
| // Recompute each value on the host and compare it with the device result. | ||
| println!("\nPolynomial evaluation (1.0 + 2.0*x + 3.0*x^2 + 4.0*x^3):"); | ||
| let [c0, c1, c2, c3] = coefficients; | ||
| for (idx, (x, y)) in xs.iter().copied().zip(ys.iter().copied()).enumerate() { | ||
| let expected = c0 + c1 * x + c2 * x * x + c3 * x * x * x; | ||
| println!( | ||
| " f({:.1}) = {:.1} (expected {:.1})", | ||
| x, y, expected | ||
| ); | ||
| let error = (y - expected).abs(); | ||
| assert!( | ||
| error < 1e-4, | ||
| "Mismatch at index {}: got {}, expected {}", | ||
| idx, | ||
| y, | ||
| expected | ||
| ); | ||
| } | ||
|
|
||
| println!("\nAll results match expected values!"); | ||
|
|
||
| Ok(()) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| // Polynomial coefficients, lowest order first. Kept in constant memory, | ||
| // which is faster than global memory for read-only data accessed by all threads. | ||
| __constant__ float coefficients[4]; | ||
|
|
||
| // Evaluates the cubic polynomial defined by `coefficients` at each of the | ||
| // first `numel` elements of `inp`, writing the results to `out`. | ||
| extern "C" __global__ void polynomial_kernel( | ||
| float *out, | ||
| const float *inp, | ||
| int numel | ||
| ) { | ||
| const int idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
| // Threads whose index is past the end of the input have nothing to do. | ||
| if (idx >= numel) { | ||
| return; | ||
| } | ||
| const float x = inp[idx]; | ||
| // coefficients[0] + coefficients[1]*x + coefficients[2]*x^2 + coefficients[3]*x^3 | ||
| out[idx] = coefficients[0] + | ||
| coefficients[1] * x + | ||
| coefficients[2] * x * x + | ||
| coefficients[3] * x * x * x; | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| // | ||
| // Generated by NVIDIA NVVM Compiler | ||
| // | ||
| // Compiler Build ID: CL-36424714 | ||
| // Cuda compilation tools, release 13.0, V13.0.88 | ||
| // Based on NVVM 7.0.1 | ||
| // | ||
|
|
||
| .version 9.0 | ||
| .target sm_75 | ||
| .address_size 64 | ||
|
|
||
| // .globl polynomial_kernel | ||
| .const .align 4 .b8 coefficients[16]; | ||
|
|
||
| .visible .entry polynomial_kernel( | ||
| .param .u64 polynomial_kernel_param_0, | ||
| .param .u64 polynomial_kernel_param_1, | ||
| .param .u32 polynomial_kernel_param_2 | ||
| ) | ||
| { | ||
| .reg .pred %p<2>; | ||
| .reg .f32 %f<12>; | ||
| .reg .b32 %r<6>; | ||
| .reg .b64 %rd<8>; | ||
|
|
||
|
|
||
| ld.param.u64 %rd1, [polynomial_kernel_param_0]; | ||
| ld.param.u64 %rd2, [polynomial_kernel_param_1]; | ||
| ld.param.u32 %r2, [polynomial_kernel_param_2]; | ||
| mov.u32 %r3, %ctaid.x; | ||
| mov.u32 %r4, %ntid.x; | ||
| mov.u32 %r5, %tid.x; | ||
| mad.lo.s32 %r1, %r3, %r4, %r5; | ||
| setp.ge.s32 %p1, %r1, %r2; | ||
| @%p1 bra $L__BB0_2; | ||
|
|
||
| cvta.to.global.u64 %rd3, %rd2; | ||
| mul.wide.s32 %rd4, %r1, 4; | ||
| add.s64 %rd5, %rd3, %rd4; | ||
| ld.const.f32 %f1, [coefficients+4]; | ||
| ld.global.f32 %f2, [%rd5]; | ||
| ld.const.f32 %f3, [coefficients]; | ||
| fma.rn.f32 %f4, %f2, %f1, %f3; | ||
| ld.const.f32 %f5, [coefficients+8]; | ||
| mul.f32 %f6, %f2, %f5; | ||
| fma.rn.f32 %f7, %f2, %f6, %f4; | ||
| ld.const.f32 %f8, [coefficients+12]; | ||
| mul.f32 %f9, %f2, %f8; | ||
| mul.f32 %f10, %f2, %f9; | ||
| fma.rn.f32 %f11, %f2, %f10, %f7; | ||
| cvta.to.global.u64 %rd6, %rd1; | ||
| add.s64 %rd7, %rd6, %rd4; | ||
| st.global.f32 [%rd7], %f11; | ||
|
|
||
| $L__BB0_2: | ||
| ret; | ||
|
|
||
| } | ||
|
|
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1298,6 +1298,62 @@ impl CudaStream { | |||||||||||||||||||
| unsafe { result::memcpy_dtod_async(dst, src, num_bytes, self.cu_stream) } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| /// Copies host data (`[T]`, `Vec<T>`, or [`PinnedHostSlice<T>`]) into a module | ||||||||||||||||||||
| /// global such as a `__constant__` array declared in CUDA code. | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// The copy is issued on this stream and writes from the start of the symbol. | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// # Panics | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// Panics if `src` holds more bytes than the symbol can store. | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// # Example | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// ```ignore | ||||||||||||||||||||
| /// // In CUDA: __constant__ float my_const[256]; | ||||||||||||||||||||
| /// let symbol = module.get_global("my_const")?; | ||||||||||||||||||||
| /// let data = vec![1.0f32; 256]; | ||||||||||||||||||||
| /// stream.memcpy_htos(&data, &symbol)?; | ||||||||||||||||||||
| /// ``` | ||||||||||||||||||||
| pub fn memcpy_htos<T: DeviceRepr, Src: HostSlice<T> + ?Sized>( | ||||||||||||||||||||
| self: &Arc<Self>, | ||||||||||||||||||||
| src: &Src, | ||||||||||||||||||||
| symbol: &CudaSymbol, | ||||||||||||||||||||
| ) -> Result<(), DriverError> { | ||||||||||||||||||||
| let capacity = symbol.bytes; | ||||||||||||||||||||
| let required = src.len() * std::mem::size_of::<T>(); | ||||||||||||||||||||
| assert!( | ||||||||||||||||||||
| capacity >= required, | ||||||||||||||||||||
| "Symbol size ({} bytes) is smaller than source data ({} bytes)", | ||||||||||||||||||||
| capacity, | ||||||||||||||||||||
| required | ||||||||||||||||||||
| ); | ||||||||||||||||||||
| // `_record_src` is bound rather than discarded, so it is dropped only after the copy call below. | ||||||||||||||||||||
| let (host, _record_src) = unsafe { src.stream_synced_slice(self) }; | ||||||||||||||||||||
| unsafe { result::memcpy_htod_async(symbol.cu_device_ptr, host, self.cu_stream) } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| /// Copy a [`CudaSlice`]/[`CudaView`] to the start of a global/constant symbol on this stream. | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// This is used to copy data into `__constant__` memory declared in CUDA kernels. All `src.num_bytes()` bytes are copied; panics if the symbol is smaller than that. | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// # Example | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// ```ignore | ||||||||||||||||||||
| /// // In CUDA: __constant__ float my_const[256]; | ||||||||||||||||||||
| /// let symbol = module.get_global("my_const")?; | ||||||||||||||||||||
| /// let device_data = stream.memcpy_stod(&vec![1.0f32; 256])?; | ||||||||||||||||||||
| /// stream.memcpy_dtos(&device_data, &symbol)?; | ||||||||||||||||||||
| /// ``` | ||||||||||||||||||||
| pub fn memcpy_dtos<T, Src: DevicePtr<T>>( | ||||||||||||||||||||
| self: &Arc<Self>, | ||||||||||||||||||||
| src: &Src, | ||||||||||||||||||||
| symbol: &CudaSymbol, | ||||||||||||||||||||
| ) -> Result<(), DriverError> { | ||||||||||||||||||||
| let src_bytes = src.num_bytes(); | ||||||||||||||||||||
| assert!( | ||||||||||||||||||||
| symbol.bytes >= src_bytes, | ||||||||||||||||||||
| "Symbol size ({} bytes) is smaller than source data ({} bytes)", | ||||||||||||||||||||
| symbol.bytes, | ||||||||||||||||||||
| src_bytes | ||||||||||||||||||||
| ); | ||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should do this on all of the memcpy assertions, seems reasonable
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potentially - either way should be a separate PR
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will keep this for now, and if it's merged, I'll get another pr to change asserts to results |
||||||||||||||||||||
| let (src, _record_src) = src.device_ptr(self); | ||||||||||||||||||||
| unsafe { result::memcpy_dtod_async(symbol.cu_device_ptr, src, src_bytes, self.cu_stream) } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| /// Copy a [`CudaSlice`]/[`CudaView`] to a new [`CudaSlice`]. | ||||||||||||||||||||
| pub fn clone_dtod<T: DeviceRepr, Src: DevicePtr<T>>( | ||||||||||||||||||||
| self: &Arc<Self>, | ||||||||||||||||||||
|
|
@@ -1750,6 +1806,50 @@ impl CudaModule { | |||||||||||||||||||
| module: self.clone(), | ||||||||||||||||||||
| }) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| /// Gets a global/constant symbol, looked up by `name`, from the loaded module. | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// This can be used to access `__constant__` memory declared in CUDA kernels. The returned [`CudaSymbol`] holds the symbol's device pointer and byte size plus a clone of this module's `Arc`; a failed lookup is returned as an error. Panics if `name` contains an interior NUL byte. | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// # Example | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// ```ignore | ||||||||||||||||||||
| /// // In CUDA: __constant__ float my_const[256]; | ||||||||||||||||||||
| /// let symbol = module.get_global("my_const")?; | ||||||||||||||||||||
| /// stream.memcpy_htos(&data, &symbol)?; | ||||||||||||||||||||
| /// ``` | ||||||||||||||||||||
| pub fn get_global(self: &Arc<Self>, name: &str) -> Result<CudaSymbol, DriverError> { | ||||||||||||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm thinking it might be better to just return a
Thoughts? Open to discussion, but definitely leaning towards
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CudaSlice seems right, will check the impl details and give you another commit
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. get_global should now return a CudaSlice :) |
||||||||||||||||||||
| let name_c = CString::new(name).unwrap(); | ||||||||||||||||||||
|
wizenink marked this conversation as resolved.
Outdated
|
||||||||||||||||||||
| let (cu_device_ptr, bytes) = | ||||||||||||||||||||
| unsafe { result::module::get_global(self.cu_module, name_c) }?; | ||||||||||||||||||||
| Ok(CudaSymbol { | ||||||||||||||||||||
| cu_device_ptr, | ||||||||||||||||||||
| bytes, | ||||||||||||||||||||
| module: self.clone(), | ||||||||||||||||||||
| }) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| /// Handle to a global/constant symbol that lives inside a loaded CUDA module. | ||||||||||||||||||||
| /// | ||||||||||||||||||||
| /// Obtained from [CudaModule::get_global()]; copy data into it with | ||||||||||||||||||||
| /// [CudaStream::memcpy_htos()] or [CudaStream::memcpy_dtos()]. | ||||||||||||||||||||
| #[derive(Clone, Debug)] | ||||||||||||||||||||
| pub struct CudaSymbol { | ||||||||||||||||||||
| /// Device address of the symbol. | ||||||||||||||||||||
| pub(crate) cu_device_ptr: sys::CUdeviceptr, | ||||||||||||||||||||
| /// Size of the symbol in bytes. | ||||||||||||||||||||
| pub(crate) bytes: usize, | ||||||||||||||||||||
| /// Module the symbol came from. Never read (hence `allow(unused)`); presumably held | ||||||||||||||||||||
| /// to keep the module loaded while the symbol is in use — TODO confirm. | ||||||||||||||||||||
| #[allow(unused)] | ||||||||||||||||||||
| pub(crate) module: Arc<CudaModule>, | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // NOTE(review): these impls assert that a `CudaSymbol` (raw device pointer plus module handle) | ||||||||||||||||||||
| // may be moved to and shared between threads; the justification is not visible here — confirm. | ||||||||||||||||||||
| unsafe impl Sync for CudaSymbol {} | ||||||||||||||||||||
| unsafe impl Send for CudaSymbol {} | ||||||||||||||||||||
|
|
||||||||||||||||||||
| impl CudaSymbol { | ||||||||||||||||||||
| /// Size in bytes of the underlying symbol, as recorded when it was looked up. | ||||||||||||||||||||
| pub fn num_bytes(&self) -> usize { | ||||||||||||||||||||
| let Self { bytes, .. } = self; | ||||||||||||||||||||
| *bytes | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| impl CudaFunction { | ||||||||||||||||||||
|
|
||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.