File tree Expand file tree Collapse file tree 2 files changed +20
-1
lines changed Expand file tree Collapse file tree 2 files changed +20
-1
lines changed Original file line number Diff line number Diff line change @@ -21,7 +21,9 @@ impl BenchDevice for Device {
2121 Device :: Cpu => Ok ( ( ) ) ,
2222 Device :: Cuda ( device) => {
2323 #[ cfg( feature = "cuda" ) ]
24- return Ok ( device. synchronize ( ) ?) ;
24+ return Ok ( device
25+ . synchronize ( )
26+ . map_err ( |e| candle_core:: Error :: Cuda ( Box :: new ( e) ) ) ?) ;
2527 #[ cfg( not( feature = "cuda" ) ) ]
2628 panic ! ( "Cuda device without cuda feature enabled: {:?}" , device)
2729 }
Original file line number Diff line number Diff line change @@ -1001,6 +1001,7 @@ pub struct CudaStorage {
10011001
10021002pub trait CudaDType : Sized {
10031003 fn as_cuda_slice ( s : & CudaStorage ) -> Result < & CudaSlice < Self > > ;
1004+ fn as_cuda_slice_mut ( s : & mut CudaStorage ) -> Result < & mut CudaSlice < Self > > ;
10041005 fn wrap_cuda_slice ( s : CudaSlice < Self > , dev : CudaDevice ) -> CudaStorage ;
10051006}
10061007
@@ -1019,6 +1020,18 @@ macro_rules! cuda_dtype {
10191020 }
10201021 }
10211022
1023+ fn as_cuda_slice_mut( s: & mut CudaStorage ) -> Result <& mut CudaSlice <Self >> {
1024+ match s. slice {
1025+ CudaStorageSlice :: $dtype( ref mut data) => Ok ( data) ,
1026+ _ => Err ( crate :: Error :: UnexpectedDType {
1027+ expected: DType :: $dtype,
1028+ got: s. dtype( ) ,
1029+ msg: "unexpected dtype" ,
1030+ }
1031+ . bt( ) ) ,
1032+ }
1033+ }
1034+
10221035 fn wrap_cuda_slice( slice: CudaSlice <Self >, device: CudaDevice ) -> CudaStorage {
10231036 let slice = CudaStorageSlice :: $dtype( slice) ;
10241037 CudaStorage { slice, device }
@@ -1042,6 +1055,10 @@ impl CudaStorage {
10421055 pub fn as_cuda_slice < T : CudaDType > ( & self ) -> Result < & CudaSlice < T > > {
10431056 T :: as_cuda_slice ( self )
10441057 }
1058+
1059+ pub fn as_cuda_slice_mut < T : CudaDType > ( & mut self ) -> Result < & mut CudaSlice < T > > {
1060+ T :: as_cuda_slice_mut ( self )
1061+ }
10451062}
10461063
10471064fn gemm_config < T > (
You can’t perform that action at this time.
0 commit comments