Add safe API for CUDA constant memory operations #478

Merged
chelsea0x3b merged 3 commits into chelsea0x3b:main from wizenink:copy_to_global
Nov 1, 2025
@wizenink wizenink commented Nov 1, 2025

Contributor

Adds support for accessing and copying data to __constant__ memory in CUDA modules.

Changes

  • Result layer: Added module::get_global() wrapper for cuModuleGetGlobal_v2
  • Safe API: New CudaSymbol type to represent global/constant symbols
  • Memory operations: Added CudaStream::memcpy_htos() and memcpy_dtos() for copying to symbols
  • Example: Added 09-constant-memory.rs demonstrating polynomial evaluation using constant memory
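
For checking a polynomial kernel like the one in the example, a CPU reference is handy. The following is a minimal host-side sketch, not code from this PR: the function names, the coefficient ordering (constant term first), and the tolerance-based comparison are all assumptions made for illustration.

```rust
/// Evaluate c[0] + c[1]*x + c[2]*x^2 + ... using Horner's rule.
/// Assumption of this sketch: `coeffs` is ordered from the constant term upward.
fn eval_poly(coeffs: &[f32], x: f32) -> f32 {
    // Walk from the highest-degree coefficient down, folding in one multiply-add per term.
    coeffs.iter().rev().fold(0.0, |acc, &c| acc * x + c)
}

/// Compare values copied back from the device against the host reference.
/// Returns false on a length mismatch or if any element differs by more than `tol`.
fn matches_reference(coeffs: &[f32], xs: &[f32], device_out: &[f32], tol: f32) -> bool {
    xs.len() == device_out.len()
        && xs
            .iter()
            .zip(device_out)
            .all(|(&x, &y)| (eval_poly(coeffs, x) - y).abs() <= tol)
}
```

Calling `matches_reference(&coeffs, &xs, &out, 1e-5)` after copying `out` back from the GPU would confirm that the coefficients uploaded to constant memory were the ones the kernel actually read.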

Usage

let symbol = module.get_global("my_constant")?;
stream.memcpy_htos(&data, &symbol)?;

Closes the gap in constant memory support by providing a safe, ergonomic API for symbol access and data transfer.

@wizenink wizenink requested a review from chelsea0x3b as a code owner November 1, 2025 09:37

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a safe API for CUDA constant memory, which is a valuable addition. The implementation is well-structured, with a new CudaSymbol type, associated memory copy functions, and a clear example demonstrating its usage. My review focuses on improving error handling in the new public APIs. I've identified a few places where the code could panic on invalid user input. By replacing these panics with Result::Err returns, the library will become more robust and user-friendly. These are high-severity issues for a library.

Comment thread src/driver/safe/core.rs Outdated
Comment thread src/driver/safe/core.rs Outdated
Comment on lines +1347 to +1352
assert!(
    symbol.bytes >= src_bytes,
    "Symbol size ({} bytes) is smaller than source data ({} bytes)",
    symbol.bytes,
    src_bytes
);

Contributor

Severity: high

Similar to memcpy_htos, using assert! here can cause a panic if the source data is larger than the symbol's capacity. It's better to return a Result::Err to allow for graceful error handling by the caller.

Suggested change
- assert!(
-     symbol.bytes >= src_bytes,
-     "Symbol size ({} bytes) is smaller than source data ({} bytes)",
-     symbol.bytes,
-     src_bytes
- );
+ if symbol.bytes < src_bytes {
+     return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_VALUE));
+ }
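
To make the trade-off concrete, here is a small standalone sketch of the "return an error instead of panicking" pattern. It does not use cudarc: `Symbol`, `SizeMismatch`, and `check_fits` are hypothetical names invented for this illustration, whereas the suggestion above returns the library's existing `DriverError`.

```rust
use std::fmt;

/// Hypothetical error type for this sketch only.
#[derive(Debug, PartialEq, Eq)]
struct SizeMismatch {
    symbol_bytes: usize,
    src_bytes: usize,
}

impl fmt::Display for SizeMismatch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Symbol size ({} bytes) is smaller than source data ({} bytes)",
            self.symbol_bytes, self.src_bytes
        )
    }
}

impl std::error::Error for SizeMismatch {}

/// Hypothetical stand-in for a symbol: only its capacity in bytes matters here.
struct Symbol {
    bytes: usize,
}

/// Validate that `src` fits into `symbol` before any copy is attempted.
/// On success, returns the number of bytes that would be copied.
fn check_fits<T>(src: &[T], symbol: &Symbol) -> Result<usize, SizeMismatch> {
    let src_bytes = std::mem::size_of_val(src);
    if symbol.bytes < src_bytes {
        // The caller decides how to react; nothing panics here.
        return Err(SizeMismatch { symbol_bytes: symbol.bytes, src_bytes });
    }
    Ok(src_bytes)
}
```

A dedicated error type keeps the sizes that the `assert!` message carried, which a bare `CUDA_ERROR_INVALID_VALUE` would drop; whether that is worth a new error variant is a design choice for the library.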

Contributor Author

Maybe we should do this for all of the memcpy assertions; it seems reasonable.

Owner

Potentially - either way should be a separate PR

Contributor Author

I'll keep the assert for now, and if this is merged, I'll open another PR to change the asserts to Results.

Comment thread src/driver/safe/core.rs Outdated
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Comment thread src/driver/safe/core.rs Outdated
/// let symbol = module.get_global("my_const")?;
/// stream.memcpy_htos(&data, &symbol)?;
/// ```
pub fn get_global(self: &Arc<Self>, name: &str) -> Result<CudaSymbol, DriverError> {

Owner

I'm thinking it might be better to just return a CudaSlice<u8> from this function. This has the following benefits:

  1. We don't need the CudaSymbol struct at all & we can use the existing memcpy stuff.
  2. CudaSlice can already be transmuted to different types
  3. We have built in support for event tracking - which right now CudaSymbol does not track
  4. We don't need to add support to CudaSymbol for multi stream synchronization

Thoughts? Open to discussion, but definitely leaning towards CudaSlice<u8>
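
As a rough illustration of point 2, the sketch below shows on the host what treating a symbol as raw bytes and reinterpreting it later involves. It is an analogy only and does not call cudarc: `ByteRegion`, `typed_len`, and `to_f32s` are hypothetical names, and a real device buffer would not be backed by a `Vec<u8>`.

```rust
use std::mem::size_of;

/// Hypothetical host-side stand-in for an untyped allocation such as a symbol.
struct ByteRegion {
    bytes: Vec<u8>,
}

impl ByteRegion {
    /// Number of whole `T` elements the region holds, or `None` when the byte
    /// length is not a multiple of `size_of::<T>()` (or `T` is zero-sized).
    fn typed_len<T>(&self) -> Option<usize> {
        let elem = size_of::<T>();
        if elem == 0 || self.bytes.len() % elem != 0 {
            return None;
        }
        Some(self.bytes.len() / elem)
    }

    /// Decode the region as `f32` values in native byte order, copying the data.
    fn to_f32s(&self) -> Option<Vec<f32>> {
        // Refuse to decode a region that does not split evenly into f32s.
        self.typed_len::<f32>()?;
        Some(
            self.bytes
                .chunks_exact(size_of::<f32>())
                .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                .collect(),
        )
    }
}
```

The length check is the part that carries over: a byte-typed handle can serve any element type as long as the reinterpretation verifies that the sizes divide evenly.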

Contributor Author

CudaSlice seems right. I'll check the implementation details and push another commit.

Contributor Author

get_global should now return a CudaSlice :)

@chelsea0x3b

Owner

Love this addition - just need to discuss the return from get_global. Thank you for this work

@chelsea0x3b chelsea0x3b merged commit 12cbbab into chelsea0x3b:main Nov 1, 2025
33 checks passed
@chelsea0x3b chelsea0x3b mentioned this pull request Nov 1, 2025
@wizenink wizenink deleted the copy_to_global branch November 4, 2025 16:52