Skip to content

Commit b4daa03

Browse files
authored
add as_cuda_slice_mut to CudaStorage and CudaDType (#2859)
1 parent 9541467 commit b4daa03

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

candle-core/benches/benchmarks/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ impl BenchDevice for Device {
             Device::Cpu => Ok(()),
             Device::Cuda(device) => {
                 #[cfg(feature = "cuda")]
-                return Ok(device.synchronize()?);
+                return Ok(device
+                    .synchronize()
+                    .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
                 #[cfg(not(feature = "cuda"))]
                 panic!("Cuda device without cuda feature enabled: {:?}", device)
             }

candle-core/src/cuda_backend/mod.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,7 @@ pub struct CudaStorage {
 
 pub trait CudaDType: Sized {
     fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
+    fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice<Self>>;
     fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
 }
 
@@ -1019,6 +1020,18 @@ macro_rules! cuda_dtype {
                 }
             }
 
+            fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice<Self>> {
+                match s.slice {
+                    CudaStorageSlice::$dtype(ref mut data) => Ok(data),
+                    _ => Err(crate::Error::UnexpectedDType {
+                        expected: DType::$dtype,
+                        got: s.dtype(),
+                        msg: "unexpected dtype",
+                    }
+                    .bt()),
+                }
+            }
+
             fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
                 let slice = CudaStorageSlice::$dtype(slice);
                 CudaStorage { slice, device }
@@ -1042,6 +1055,10 @@ impl CudaStorage {
     pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
         T::as_cuda_slice(self)
     }
+
+    pub fn as_cuda_slice_mut<T: CudaDType>(&mut self) -> Result<&mut CudaSlice<T>> {
+        T::as_cuda_slice_mut(self)
+    }
 }
 
 fn gemm_config<T>(

0 commit comments

Comments
 (0)