1- //! A `VarBuilder` for variable retrieval from models
2- //!
31//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
42//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized
53//! for training, e.g. using `VarBuilder::from_varmap`.
@@ -59,6 +57,9 @@ pub trait Backend: Send + Sync {
5957 dev : & Device ,
6058 ) -> Result < Tensor > ;
6159
60+ /// Retrieve a tensor based on the name.
61+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > ;
62+
6263 fn contains_tensor ( & self , name : & str ) -> bool ;
6364}
6465
@@ -73,6 +74,9 @@ pub trait SimpleBackend: Send + Sync {
7374 dev : & Device ,
7475 ) -> Result < Tensor > ;
7576
77+ /// Retrieve a tensor based on the name.
78+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > ;
79+
7680 fn contains_tensor ( & self , name : & str ) -> bool ;
7781}
7882
@@ -89,6 +93,10 @@ impl Backend for Box<dyn SimpleBackend + '_> {
8993 self . as_ref ( ) . get ( s, name, h, dtype, dev)
9094 }
9195
96+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
97+ self . as_ref ( ) . get_unchecked ( name, dtype, dev)
98+ }
99+
92100 fn contains_tensor ( & self , name : & str ) -> bool {
93101 self . as_ref ( ) . contains_tensor ( name)
94102 }
@@ -194,14 +202,27 @@ impl<B: Backend> VarBuilderArgs<'_, B> {
194202 name : & str ,
195203 hints : B :: Hints ,
196204 ) -> Result < Tensor > {
197- self . get_with_hints_dtype ( s, name, hints, self . dtype )
205+ self . get_with_hints_dtype ( s, name, hints, self . data . dtype )
198206 }
199207
200208 /// Retrieve the tensor associated with the given name at the current path.
201209 pub fn get < S : Into < Shape > > ( & self , s : S , name : & str ) -> Result < Tensor > {
202210 self . get_with_hints ( s, name, Default :: default ( ) )
203211 }
204212
213+ /// Retrieve the tensor associated with the given name at the current path.
214+ pub fn get_unchecked ( & self , name : & str ) -> Result < Tensor > {
215+ self . get_unchecked_dtype ( name, self . data . dtype )
216+ }
217+
218+ /// Retrieve the tensor associated with the given name & dtype at the current path.
219+ pub fn get_unchecked_dtype ( & self , name : & str , dtype : DType ) -> Result < Tensor > {
220+ let name = self . path ( name) ;
221+ self . data
222+ . backend
223+ . get_unchecked ( & name, dtype, & self . data . device )
224+ }
225+
205226 /// Retrieve the tensor associated with the given name & dtype at the current path.
206227 pub fn get_with_hints_dtype < S : Into < Shape > > (
207228 & self ,
@@ -224,6 +245,12 @@ impl SimpleBackend for Zeros {
224245 Tensor :: zeros ( s, dtype, dev)
225246 }
226247
248+ fn get_unchecked ( & self , _name : & str , _dtype : DType , _dev : & Device ) -> Result < Tensor > {
249+ candle:: bail!(
250+ "`Zeros` requires a shape for tensor retrieval, use `get` instead of `get_unchecked`"
251+ )
252+ }
253+
227254 fn contains_tensor ( & self , _name : & str ) -> bool {
228255 true
229256 }
@@ -238,6 +265,19 @@ impl SimpleBackend for HashMap<String, Tensor> {
238265 dtype : DType ,
239266 dev : & Device ,
240267 ) -> Result < Tensor > {
268+ let tensor = self . get_unchecked ( name, dtype, dev) ?;
269+ if tensor. shape ( ) != & s {
270+ Err ( candle:: Error :: UnexpectedShape {
271+ msg : format ! ( "shape mismatch for {name}" ) ,
272+ expected : s,
273+ got : tensor. shape ( ) . clone ( ) ,
274+ }
275+ . bt ( ) ) ?
276+ }
277+ Ok ( tensor)
278+ }
279+
280+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
241281 let tensor = self
242282 . get ( name)
243283 . ok_or_else ( || {
@@ -247,14 +287,6 @@ impl SimpleBackend for HashMap<String, Tensor> {
247287 . bt ( )
248288 } ) ?
249289 . clone ( ) ;
250- if tensor. shape ( ) != & s {
251- Err ( candle:: Error :: UnexpectedShape {
252- msg : format ! ( "shape mismatch for {name}" ) ,
253- expected : s,
254- got : tensor. shape ( ) . clone ( ) ,
255- }
256- . bt ( ) ) ?
257- }
258290 tensor. to_device ( dev) ?. to_dtype ( dtype)
259291 }
260292
@@ -275,6 +307,10 @@ impl SimpleBackend for VarMap {
275307 VarMap :: get ( self , s, name, h, dtype, dev)
276308 }
277309
310+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
311+ VarMap :: get_unchecked ( self , name, dtype, dev)
312+ }
313+
278314 fn contains_tensor ( & self , name : & str ) -> bool {
279315 self . data ( ) . lock ( ) . unwrap ( ) . contains_key ( name)
280316 }
@@ -290,11 +326,24 @@ impl SimpleBackend for SafeTensorWithRouting<'_> {
290326 fn get (
291327 & self ,
292328 s : Shape ,
293- path : & str ,
329+ name : & str ,
294330 _: crate :: Init ,
295331 dtype : DType ,
296332 dev : & Device ,
297333 ) -> Result < Tensor > {
334+ let tensor = self . get_unchecked ( name, dtype, dev) ?;
335+ if tensor. shape ( ) != & s {
336+ Err ( candle:: Error :: UnexpectedShape {
337+ msg : format ! ( "shape mismatch for {name}" ) ,
338+ expected : s,
339+ got : tensor. shape ( ) . clone ( ) ,
340+ }
341+ . bt ( ) ) ?
342+ }
343+ Ok ( tensor)
344+ }
345+
346+ fn get_unchecked ( & self , path : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
298347 let index = self . routing . get ( path) . ok_or_else ( || {
299348 Error :: CannotFindTensor {
300349 path : path. to_string ( ) ,
@@ -305,14 +354,6 @@ impl SimpleBackend for SafeTensorWithRouting<'_> {
305354 . tensor ( path) ?
306355 . load ( dev) ?
307356 . to_dtype ( dtype) ?;
308- if tensor. shape ( ) != & s {
309- Err ( candle:: Error :: UnexpectedShape {
310- msg : format ! ( "shape mismatch for {path}" ) ,
311- expected : s,
312- got : tensor. shape ( ) . clone ( ) ,
313- }
314- . bt ( ) ) ?
315- }
316357 Ok ( tensor)
317358 }
318359
@@ -325,22 +366,15 @@ impl SimpleBackend for candle::npy::NpzTensors {
325366 fn get (
326367 & self ,
327368 s : Shape ,
328- path : & str ,
369+ name : & str ,
329370 _: crate :: Init ,
330371 dtype : DType ,
331372 dev : & Device ,
332373 ) -> Result < Tensor > {
333- let tensor = match self . get ( path) ? {
334- None => Err ( Error :: CannotFindTensor {
335- path : path. to_string ( ) ,
336- }
337- . bt ( ) ) ?,
338- Some ( tensor) => tensor,
339- } ;
340- let tensor = tensor. to_device ( dev) ?. to_dtype ( dtype) ?;
374+ let tensor = self . get_unchecked ( name, dtype, dev) ?;
341375 if tensor. shape ( ) != & s {
342376 Err ( candle:: Error :: UnexpectedShape {
343- msg : format ! ( "shape mismatch for {path }" ) ,
377+ msg : format ! ( "shape mismatch for {name }" ) ,
344378 expected : s,
345379 got : tensor. shape ( ) . clone ( ) ,
346380 }
@@ -349,6 +383,18 @@ impl SimpleBackend for candle::npy::NpzTensors {
349383 Ok ( tensor)
350384 }
351385
386+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
387+ let tensor = match self . get ( name) ? {
388+ None => Err ( Error :: CannotFindTensor {
389+ path : name. to_string ( ) ,
390+ }
391+ . bt ( ) ) ?,
392+ Some ( tensor) => tensor,
393+ } ;
394+ let tensor = tensor. to_device ( dev) ?. to_dtype ( dtype) ?;
395+ Ok ( tensor)
396+ }
397+
352398 fn contains_tensor ( & self , name : & str ) -> bool {
353399 self . get ( name) . is_ok_and ( |v| v. is_some ( ) )
354400 }
@@ -358,22 +404,15 @@ impl SimpleBackend for candle::pickle::PthTensors {
358404 fn get (
359405 & self ,
360406 s : Shape ,
361- path : & str ,
407+ name : & str ,
362408 _: crate :: Init ,
363409 dtype : DType ,
364410 dev : & Device ,
365411 ) -> Result < Tensor > {
366- let tensor = match self . get ( path) ? {
367- None => Err ( Error :: CannotFindTensor {
368- path : path. to_string ( ) ,
369- }
370- . bt ( ) ) ?,
371- Some ( tensor) => tensor,
372- } ;
373- let tensor = tensor. to_device ( dev) ?. to_dtype ( dtype) ?;
412+ let tensor = self . get_unchecked ( name, dtype, dev) ?;
374413 if tensor. shape ( ) != & s {
375414 Err ( candle:: Error :: UnexpectedShape {
376- msg : format ! ( "shape mismatch for {path }" ) ,
415+ msg : format ! ( "shape mismatch for {name }" ) ,
377416 expected : s,
378417 got : tensor. shape ( ) . clone ( ) ,
379418 }
@@ -382,6 +421,18 @@ impl SimpleBackend for candle::pickle::PthTensors {
382421 Ok ( tensor)
383422 }
384423
424+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
425+ let tensor = match self . get ( name) ? {
426+ None => Err ( Error :: CannotFindTensor {
427+ path : name. to_string ( ) ,
428+ }
429+ . bt ( ) ) ?,
430+ Some ( tensor) => tensor,
431+ } ;
432+ let tensor = tensor. to_device ( dev) ?. to_dtype ( dtype) ?;
433+ Ok ( tensor)
434+ }
435+
385436 fn contains_tensor ( & self , name : & str ) -> bool {
386437 self . get ( name) . is_ok_and ( |v| v. is_some ( ) )
387438 }
@@ -396,7 +447,7 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors {
396447 dtype : DType ,
397448 dev : & Device ,
398449 ) -> Result < Tensor > {
399- let tensor = self . load ( name, dev ) ? . to_dtype ( dtype) ?;
450+ let tensor = self . get_unchecked ( name, dtype, dev ) ?;
400451 if tensor. shape ( ) != & s {
401452 Err ( candle:: Error :: UnexpectedShape {
402453 msg : format ! ( "shape mismatch for {name}" ) ,
@@ -408,6 +459,10 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors {
408459 Ok ( tensor)
409460 }
410461
462+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
463+ self . load ( name, dev) ?. to_dtype ( dtype)
464+ }
465+
411466 fn contains_tensor ( & self , name : & str ) -> bool {
412467 self . get ( name) . is_ok ( )
413468 }
@@ -422,7 +477,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
422477 dtype : DType ,
423478 dev : & Device ,
424479 ) -> Result < Tensor > {
425- let tensor = self . load ( name, dev ) ? . to_dtype ( dtype) ?;
480+ let tensor = self . get_unchecked ( name, dtype, dev ) ?;
426481 if tensor. shape ( ) != & s {
427482 Err ( candle:: Error :: UnexpectedShape {
428483 msg : format ! ( "shape mismatch for {name}" ) ,
@@ -434,6 +489,10 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
434489 Ok ( tensor)
435490 }
436491
492+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
493+ self . load ( name, dev) ?. to_dtype ( dtype)
494+ }
495+
437496 fn contains_tensor ( & self , name : & str ) -> bool {
438497 self . get ( name) . is_ok ( )
439498 }
@@ -448,7 +507,7 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
448507 dtype : DType ,
449508 dev : & Device ,
450509 ) -> Result < Tensor > {
451- let tensor = self . load ( name, dev ) ? . to_dtype ( dtype) ?;
510+ let tensor = self . get_unchecked ( name, dtype, dev ) ?;
452511 if tensor. shape ( ) != & s {
453512 Err ( candle:: Error :: UnexpectedShape {
454513 msg : format ! ( "shape mismatch for {name}" ) ,
@@ -460,6 +519,10 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
460519 Ok ( tensor)
461520 }
462521
522+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
523+ self . load ( name, dev) ?. to_dtype ( dtype)
524+ }
525+
463526 fn contains_tensor ( & self , name : & str ) -> bool {
464527 self . get ( name) . is_ok ( )
465528 }
@@ -714,6 +777,10 @@ impl Backend for ShardedSafeTensors {
714777 Tensor :: from_raw_buffer ( & raw , view_dtype, & shape, dev) ?. to_dtype ( dtype)
715778 }
716779
780+ fn get_unchecked ( & self , _name : & str , _dtype : DType , _dev : & Device ) -> Result < Tensor > {
781+ candle:: bail!( "`get_unchecked` does not make sense for `ShardedSafeTensors`, use `get`." ) ;
782+ }
783+
717784 fn contains_tensor ( & self , name : & str ) -> bool {
718785 self . 0 . get ( name) . is_ok ( )
719786 }
@@ -747,6 +814,11 @@ impl<R: Renamer + Sync + Send> SimpleBackend for Rename<'_, R> {
747814 . to_device ( dev)
748815 }
749816
817+ fn get_unchecked ( & self , name : & str , dtype : DType , dev : & Device ) -> Result < Tensor > {
818+ let name = self . renamer . rename ( name) ;
819+ self . inner . get_unchecked_dtype ( & name, dtype) ?. to_device ( dev)
820+ }
821+
750822 fn contains_tensor ( & self , name : & str ) -> bool {
751823 let name = self . renamer . rename ( name) ;
752824 self . inner . contains_tensor ( & name)
0 commit comments