Skip to content

Commit 12cbbab

Browse files
Add safe API for CUDA constant memory operations (#478)
* Add safe API for CUDA constant memory operations * Don't panic on invalid name Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Moved get_global to return CudaSlice<u8> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e08a924 commit 12cbbab

5 files changed

Lines changed: 202 additions & 0 deletions

File tree

examples/09-constant-memory.rs

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,73 @@
1+
use cudarc::{
2+
driver::{CudaContext, DriverError, LaunchConfig, PushKernelArg},
3+
nvrtc::Ptx,
4+
};
5+
6+
fn main() -> Result<(), DriverError> {
7+
let ctx = CudaContext::new(0)?;
8+
let stream = ctx.default_stream();
9+
10+
// Load the module containing the kernel with constant memory
11+
let module = ctx.load_module(Ptx::from_file("./examples/constant_memory.ptx"))?;
12+
13+
// Get the constant memory symbol as a CudaSlice<u8>
14+
let mut coefficients_symbol = module.get_global("coefficients", &stream)?;
15+
println!(
16+
"Constant memory symbol 'coefficients' has {} bytes",
17+
coefficients_symbol.len()
18+
);
19+
20+
// Set up polynomial coefficients: 1.0 + 2.0*x + 3.0*x^2 + 4.0*x^3
21+
let coefficients = [1.0f32, 2.0, 3.0, 4.0];
22+
23+
// Transmute the symbol to f32 and copy coefficients to constant memory
24+
let mut symbol_view = coefficients_symbol.as_view_mut();
25+
let mut symbol_f32 = unsafe { symbol_view.transmute_mut::<f32>(4).unwrap() };
26+
stream.memcpy_htod(&coefficients, &mut symbol_f32)?;
27+
28+
// Load the kernel function
29+
let polynomial_kernel = module.load_function("polynomial_kernel")?;
30+
31+
// Prepare input data
32+
let input = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
33+
let n = input.len();
34+
35+
// Copy input to device
36+
let input_dev = stream.memcpy_stod(&input)?;
37+
let mut output_dev = stream.alloc_zeros::<f32>(n)?;
38+
39+
// Launch kernel
40+
let cfg = LaunchConfig::for_num_elems(n as u32);
41+
unsafe {
42+
stream
43+
.launch_builder(&polynomial_kernel)
44+
.arg(&mut output_dev)
45+
.arg(&input_dev)
46+
.arg(&(n as i32))
47+
.launch(cfg)
48+
}?;
49+
50+
// Copy results back
51+
let output = stream.memcpy_dtov(&output_dev)?;
52+
53+
// Verify results
54+
println!("\nPolynomial evaluation (1.0 + 2.0*x + 3.0*x^2 + 4.0*x^3):");
55+
for (i, (&x, &y)) in input.iter().zip(output.iter()).enumerate() {
56+
let expected = coefficients[0]
57+
+ coefficients[1] * x
58+
+ coefficients[2] * x * x
59+
+ coefficients[3] * x * x * x;
60+
println!(" f({:.1}) = {:.1} (expected {:.1})", x, y, expected);
61+
assert!(
62+
(y - expected).abs() < 1e-4,
63+
"Mismatch at index {}: got {}, expected {}",
64+
i,
65+
y,
66+
expected
67+
);
68+
}
69+
70+
println!("\nAll results match expected values!");
71+
72+
Ok(())
73+
}

examples/constant_memory.cu

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,19 @@
1+
// Polynomial coefficients kept in constant memory, which is faster than
// global memory for read-only data accessed by all threads.
__constant__ float coefficients[4];

// Evaluates, for each of the first `numel` elements x = inp[i]:
//   out[i] = coefficients[0] + coefficients[1]*x
//          + coefficients[2]*x^2 + coefficients[3]*x^3
// Each thread handles at most one element.
extern "C" __global__ void polynomial_kernel(
    float *out,
    const float *inp,
    int numel
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Threads whose index falls past the data have nothing to compute.
    if (idx >= numel) {
        return;
    }

    const float x = inp[idx];
    out[idx] = coefficients[0] +
               coefficients[1] * x +
               coefficients[2] * x * x +
               coefficients[3] * x * x * x;
}

examples/constant_memory.ptx

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,60 @@
1+
//
2+
// Generated by NVIDIA NVVM Compiler
3+
//
4+
// Compiler Build ID: CL-36424714
5+
// Cuda compilation tools, release 13.0, V13.0.88
6+
// Based on NVVM 7.0.1
7+
//
8+
9+
.version 9.0
10+
.target sm_75
11+
.address_size 64
12+
13+
// .globl polynomial_kernel
14+
.const .align 4 .b8 coefficients[16];
15+
16+
// Entry point for polynomial_kernel. param_0 and param_1 are 64-bit values
// used as addresses below; param_2 is a 32-bit value used as the index bound.
// They correspond to `out`, `inp` and `numel` in constant_memory.cu.
.visible .entry polynomial_kernel(
17+
.param .u64 polynomial_kernel_param_0,
18+
.param .u64 polynomial_kernel_param_1,
19+
.param .u32 polynomial_kernel_param_2
20+
)
21+
{
22+
.reg .pred %p<2>;
23+
.reg .f32 %f<12>;
24+
.reg .b32 %r<6>;
25+
.reg .b64 %rd<8>;
26+
27+
28+
// Load the three kernel parameters into registers.
ld.param.u64 %rd1, [polynomial_kernel_param_0];
29+
ld.param.u64 %rd2, [polynomial_kernel_param_1];
30+
ld.param.u32 %r2, [polynomial_kernel_param_2];
31+
// %r1 = %ctaid.x * %ntid.x + %tid.x, the index this thread works on.
mov.u32 %r3, %ctaid.x;
32+
mov.u32 %r4, %ntid.x;
33+
mov.u32 %r5, %tid.x;
34+
mad.lo.s32 %r1, %r3, %r4, %r5;
35+
// Jump straight to the exit label when %r1 >= param_2.
setp.ge.s32 %p1, %r1, %r2;
36+
@%p1 bra $L__BB0_2;
37+
38+
// %rd4 = %r1 * 4 (byte offset of a 4-byte element); %rd5 = param_1 + %rd4.
cvta.to.global.u64 %rd3, %rd2;
39+
mul.wide.s32 %rd4, %r1, 4;
40+
add.s64 %rd5, %rd3, %rd4;
41+
// With x = %f2 (loaded from %rd5) and c[k] read from coefficients + 4*k:
//   %f4  = x*c[1] + c[0]
//   %f7  = x*(x*c[2]) + %f4
//   %f11 = x*(x*(x*c[3])) + %f7
ld.const.f32 %f1, [coefficients+4];
42+
ld.global.f32 %f2, [%rd5];
43+
ld.const.f32 %f3, [coefficients];
44+
fma.rn.f32 %f4, %f2, %f1, %f3;
45+
ld.const.f32 %f5, [coefficients+8];
46+
mul.f32 %f6, %f2, %f5;
47+
fma.rn.f32 %f7, %f2, %f6, %f4;
48+
ld.const.f32 %f8, [coefficients+12];
49+
mul.f32 %f9, %f2, %f8;
50+
mul.f32 %f10, %f2, %f9;
51+
fma.rn.f32 %f11, %f2, %f10, %f7;
52+
// Store %f11 at param_0 + %rd4, the same element offset as the input.
cvta.to.global.u64 %rd6, %rd1;
53+
add.s64 %rd7, %rd6, %rd4;
54+
st.global.f32 [%rd7], %f11;
55+
56+
$L__BB0_2:
57+
ret;
58+
59+
}
60+

src/driver/result.rs

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -965,6 +965,24 @@ pub mod module {
965965
Ok(func.assume_init())
966966
}
967967

968+
/// Returns a pointer to a global/constant symbol in the module.
969+
///
970+
/// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html#group__CUDA__MODULE_1gf3e43972c23c2d5c8a662f2d9a4d0c24)
971+
///
972+
/// # Safety
973+
/// `module` must be a properly allocated and not freed module.
974+
pub unsafe fn get_global(
975+
module: sys::CUmodule,
976+
name: CString,
977+
) -> Result<(sys::CUdeviceptr, usize), DriverError> {
978+
let name_ptr = name.as_c_str().as_ptr();
979+
let mut dptr = MaybeUninit::uninit();
980+
let mut bytes = MaybeUninit::uninit();
981+
sys::cuModuleGetGlobal_v2(dptr.as_mut_ptr(), bytes.as_mut_ptr(), module, name_ptr)
982+
.result()?;
983+
Ok((dptr.assume_init(), bytes.assume_init()))
984+
}
985+
968986
/// Unloads a module.
969987
///
970988
/// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html#group__CUDA__MODULE_1g8ea3d716524369de3763104ced4ea57b)

src/driver/safe/core.rs

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1761,6 +1761,38 @@ impl CudaModule {
17611761
module: self.clone(),
17621762
})
17631763
}
1764+
1765+
/// Gets a global/constant symbol from the loaded module as a [CudaSlice<u8>].
1766+
///
1767+
/// This can be used to access `__constant__` memory declared in CUDA kernels.
1768+
/// The returned slice can be transmuted to the appropriate type via views.
1769+
///
1770+
/// # Example
1771+
///
1772+
/// ```ignore
1773+
/// // In CUDA: __constant__ float my_const[4];
1774+
/// let symbol = module.get_global("my_const", &stream)?;
1775+
/// let mut symbol_view = symbol.as_view_mut();
1776+
/// let mut symbol_f32 = unsafe { symbol_view.transmute_mut::<f32>(4).unwrap() };
1777+
/// stream.memcpy_htod(&[1.0f32, 2.0, 3.0, 4.0], &mut symbol_f32)?;
1778+
/// ```
1779+
pub fn get_global(
1780+
self: &Arc<Self>,
1781+
name: &str,
1782+
stream: &Arc<CudaStream>,
1783+
) -> Result<CudaSlice<u8>, DriverError> {
1784+
let name_c =
1785+
CString::new(name).map_err(|_| DriverError(sys::CUresult::CUDA_ERROR_INVALID_VALUE))?;
1786+
let (cu_device_ptr, bytes) = unsafe { result::module::get_global(self.cu_module, name_c) }?;
1787+
Ok(CudaSlice {
1788+
cu_device_ptr,
1789+
len: bytes,
1790+
read: None,
1791+
write: None,
1792+
stream: stream.clone(),
1793+
marker: PhantomData,
1794+
})
1795+
}
17641796
}
17651797

17661798
impl CudaFunction {

0 commit comments

Comments
 (0)