Skip to content

Commit 671d3d4

Browse files
committed
feat(driver): prevent pinned host allocation size overflow
Compute pinned allocation bytes without wrapping before calling CUDA while preserving the existing panic-based slice-size invariant. Add a regression test for overflowing typed lengths.
1 parent 2afd12c commit 671d3d4

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

src/driver/safe/core.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,10 +1433,13 @@ impl CudaContext {
14331433
flags: u32,
14341434
) -> Result<PinnedHostSlice<T>, DriverError> {
14351435
self.bind_to_thread()?;
1436-
let ptr = result::malloc_host(len * std::mem::size_of::<T>(), flags)?;
1436+
let num_bytes = len
1437+
.checked_mul(std::mem::size_of::<T>())
1438+
.expect("Pinned host allocation size overflow");
1439+
assert!(num_bytes < isize::MAX as usize);
1440+
let ptr = result::malloc_host(num_bytes, flags)?;
14371441
let ptr = ptr as *mut T;
14381442
assert!(!ptr.is_null());
1439-
assert!(len * std::mem::size_of::<T>() < isize::MAX as usize);
14401443
assert!(ptr.is_aligned());
14411444
let event = self.new_event(Some(sys::CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
14421445
Ok(PinnedHostSlice { ptr, len, event })
@@ -2717,6 +2720,13 @@ mod tests {
27172720
assert_eq!(&host, &truth);
27182721
}
27192722

2723+
#[test]
2724+
#[should_panic(expected = "Pinned host allocation size overflow")]
2725+
fn test_alloc_pinned_panics_on_size_overflow() {
2726+
let ctx = CudaContext::new(0).unwrap();
2727+
let _ = unsafe { ctx.alloc_pinned_with_flags::<u32>(usize::MAX, 0) };
2728+
}
2729+
27202730
#[test]
27212731
fn test_default_pinned_host_reads_are_faster_than_write_combined() {
27222732
fn timed_host_reads(values: &[u32], n_samples: usize) -> (std::time::Duration, u64) {

0 commit comments

Comments
 (0)