Add safe API for CUDA constant memory operations#478
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a safe API for CUDA constant memory, which is a valuable addition. The implementation is well-structured, with a new CudaSymbol type, associated memory copy functions, and a clear example demonstrating its usage. My review focuses on improving error handling in the new public APIs. I've identified a few places where the code could panic on invalid user input. By replacing these panics with Result::Err returns, the library will become more robust and user-friendly. These are high-severity issues for a library.
| assert!( | ||
| symbol.bytes >= src_bytes, | ||
| "Symbol size ({} bytes) is smaller than source data ({} bytes)", | ||
| symbol.bytes, | ||
| src_bytes | ||
| ); |
There was a problem hiding this comment.
Similar to memcpy_htos, using assert! here can cause a panic if the source data is larger than the symbol's capacity. It's better to return a Result::Err to allow for graceful error handling by the caller.
| assert!( | |
| symbol.bytes >= src_bytes, | |
| "Symbol size ({} bytes) is smaller than source data ({} bytes)", | |
| symbol.bytes, | |
| src_bytes | |
| ); | |
| if symbol.bytes < src_bytes { | |
| return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_VALUE)); | |
| } |
There was a problem hiding this comment.
Maybe we should do this on all of the memcpy assertions, seems reasonable
There was a problem hiding this comment.
Potentially - either way should be a separate PR
There was a problem hiding this comment.
Will keep this for now, and if it's merged, I'll get another pr to change asserts to results
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
| /// let symbol = module.get_global("my_const")?; | ||
| /// stream.memcpy_htos(&data, &symbol)?; | ||
| /// ``` | ||
| pub fn get_global(self: &Arc<Self>, name: &str) -> Result<CudaSymbol, DriverError> { |
There was a problem hiding this comment.
I'm thinking it might be better to just return a CudaSlice<u8> from this function. This has the following benefits:
- We don't need the
CudaSymbolstruct at all & we can use the existing memcpy stuff. - CudaSlice can already be transmuted to different types
- We have built in support for event tracking - which right now CudaSymbol does not track
- We don't need to add support to CudaSymbol for multi stream synchronization
Thoughts? Open to discussion, but definitely leaning towards CudaSlice<u8>
There was a problem hiding this comment.
CudaSlice seems right, will check the impl details and give you another commit
There was a problem hiding this comment.
get_global should now return a CudaSlice :)
|
Love this addition - just need to discuss the return from get_global. Thank you for this work |
Adds support for accessing and copying data to
__constant__memory in CUDA modules.Changes
module::get_global()wrapper forcuModuleGetGlobal_v2CudaSymboltype to represent global/constant symbolsCudaStream::memcpy_htos()andmemcpy_dtos()for copying to symbols09-constant-memory.rsdemonstrating polynomial evaluation using constant memoryUsage
Closes the gap in constant memory support by providing a safe, ergonomic API for symbol access and data transfer.