Skip to content

Commit 913e18c

Browse files
EricLBuehlerlukekim
authored andcommitted
Add varbuilder get_unchecked (huggingface#52)
1 parent 2a2b051 commit 913e18c

2 files changed

Lines changed: 120 additions & 45 deletions

File tree

candle-nn/src/var_builder.rs

Lines changed: 115 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
//! A `VarBuilder` for variable retrieval from models
2-
//!
31
//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
42
//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized
53
//! for training, e.g. using `VarBuilder::from_varmap`.
@@ -59,6 +57,9 @@ pub trait Backend: Send + Sync {
5957
dev: &Device,
6058
) -> Result<Tensor>;
6159

60+
/// Retrieve a tensor based on the name.
61+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;
62+
6263
fn contains_tensor(&self, name: &str) -> bool;
6364
}
6465

@@ -73,6 +74,9 @@ pub trait SimpleBackend: Send + Sync {
7374
dev: &Device,
7475
) -> Result<Tensor>;
7576

77+
/// Retrieve a tensor based on the name.
78+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;
79+
7680
fn contains_tensor(&self, name: &str) -> bool;
7781
}
7882

@@ -89,6 +93,10 @@ impl Backend for Box<dyn SimpleBackend + '_> {
8993
self.as_ref().get(s, name, h, dtype, dev)
9094
}
9195

96+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
97+
self.as_ref().get_unchecked(name, dtype, dev)
98+
}
99+
92100
fn contains_tensor(&self, name: &str) -> bool {
93101
self.as_ref().contains_tensor(name)
94102
}
@@ -194,14 +202,27 @@ impl<B: Backend> VarBuilderArgs<'_, B> {
194202
name: &str,
195203
hints: B::Hints,
196204
) -> Result<Tensor> {
197-
self.get_with_hints_dtype(s, name, hints, self.dtype)
205+
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
198206
}
199207

200208
/// Retrieve the tensor associated with the given name at the current path.
201209
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
202210
self.get_with_hints(s, name, Default::default())
203211
}
204212

213+
/// Retrieve the tensor associated with the given name at the current path.
214+
pub fn get_unchecked(&self, name: &str) -> Result<Tensor> {
215+
self.get_unchecked_dtype(name, self.data.dtype)
216+
}
217+
218+
/// Retrieve the tensor associated with the given name & dtype at the current path.
219+
pub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result<Tensor> {
220+
let name = self.path(name);
221+
self.data
222+
.backend
223+
.get_unchecked(&name, dtype, &self.data.device)
224+
}
225+
205226
/// Retrieve the tensor associated with the given name & dtype at the current path.
206227
pub fn get_with_hints_dtype<S: Into<Shape>>(
207228
&self,
@@ -224,6 +245,12 @@ impl SimpleBackend for Zeros {
224245
Tensor::zeros(s, dtype, dev)
225246
}
226247

248+
fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
249+
candle::bail!(
250+
"`Zeros` requires a shape for tensor retrieval, use `get` instead of `get_unchecked`"
251+
)
252+
}
253+
227254
fn contains_tensor(&self, _name: &str) -> bool {
228255
true
229256
}
@@ -238,6 +265,19 @@ impl SimpleBackend for HashMap<String, Tensor> {
238265
dtype: DType,
239266
dev: &Device,
240267
) -> Result<Tensor> {
268+
let tensor = self.get_unchecked(name, dtype, dev)?;
269+
if tensor.shape() != &s {
270+
Err(candle::Error::UnexpectedShape {
271+
msg: format!("shape mismatch for {name}"),
272+
expected: s,
273+
got: tensor.shape().clone(),
274+
}
275+
.bt())?
276+
}
277+
Ok(tensor)
278+
}
279+
280+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
241281
let tensor = self
242282
.get(name)
243283
.ok_or_else(|| {
@@ -247,14 +287,6 @@ impl SimpleBackend for HashMap<String, Tensor> {
247287
.bt()
248288
})?
249289
.clone();
250-
if tensor.shape() != &s {
251-
Err(candle::Error::UnexpectedShape {
252-
msg: format!("shape mismatch for {name}"),
253-
expected: s,
254-
got: tensor.shape().clone(),
255-
}
256-
.bt())?
257-
}
258290
tensor.to_device(dev)?.to_dtype(dtype)
259291
}
260292

@@ -275,6 +307,10 @@ impl SimpleBackend for VarMap {
275307
VarMap::get(self, s, name, h, dtype, dev)
276308
}
277309

310+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
311+
VarMap::get_unchecked(self, name, dtype, dev)
312+
}
313+
278314
fn contains_tensor(&self, name: &str) -> bool {
279315
self.data().lock().unwrap().contains_key(name)
280316
}
@@ -290,11 +326,24 @@ impl SimpleBackend for SafeTensorWithRouting<'_> {
290326
fn get(
291327
&self,
292328
s: Shape,
293-
path: &str,
329+
name: &str,
294330
_: crate::Init,
295331
dtype: DType,
296332
dev: &Device,
297333
) -> Result<Tensor> {
334+
let tensor = self.get_unchecked(name, dtype, dev)?;
335+
if tensor.shape() != &s {
336+
Err(candle::Error::UnexpectedShape {
337+
msg: format!("shape mismatch for {name}"),
338+
expected: s,
339+
got: tensor.shape().clone(),
340+
}
341+
.bt())?
342+
}
343+
Ok(tensor)
344+
}
345+
346+
fn get_unchecked(&self, path: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
298347
let index = self.routing.get(path).ok_or_else(|| {
299348
Error::CannotFindTensor {
300349
path: path.to_string(),
@@ -305,14 +354,6 @@ impl SimpleBackend for SafeTensorWithRouting<'_> {
305354
.tensor(path)?
306355
.load(dev)?
307356
.to_dtype(dtype)?;
308-
if tensor.shape() != &s {
309-
Err(candle::Error::UnexpectedShape {
310-
msg: format!("shape mismatch for {path}"),
311-
expected: s,
312-
got: tensor.shape().clone(),
313-
}
314-
.bt())?
315-
}
316357
Ok(tensor)
317358
}
318359

@@ -325,22 +366,15 @@ impl SimpleBackend for candle::npy::NpzTensors {
325366
fn get(
326367
&self,
327368
s: Shape,
328-
path: &str,
369+
name: &str,
329370
_: crate::Init,
330371
dtype: DType,
331372
dev: &Device,
332373
) -> Result<Tensor> {
333-
let tensor = match self.get(path)? {
334-
None => Err(Error::CannotFindTensor {
335-
path: path.to_string(),
336-
}
337-
.bt())?,
338-
Some(tensor) => tensor,
339-
};
340-
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
374+
let tensor = self.get_unchecked(name, dtype, dev)?;
341375
if tensor.shape() != &s {
342376
Err(candle::Error::UnexpectedShape {
343-
msg: format!("shape mismatch for {path}"),
377+
msg: format!("shape mismatch for {name}"),
344378
expected: s,
345379
got: tensor.shape().clone(),
346380
}
@@ -349,6 +383,18 @@ impl SimpleBackend for candle::npy::NpzTensors {
349383
Ok(tensor)
350384
}
351385

386+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
387+
let tensor = match self.get(name)? {
388+
None => Err(Error::CannotFindTensor {
389+
path: name.to_string(),
390+
}
391+
.bt())?,
392+
Some(tensor) => tensor,
393+
};
394+
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
395+
Ok(tensor)
396+
}
397+
352398
fn contains_tensor(&self, name: &str) -> bool {
353399
self.get(name).is_ok_and(|v| v.is_some())
354400
}
@@ -358,22 +404,15 @@ impl SimpleBackend for candle::pickle::PthTensors {
358404
fn get(
359405
&self,
360406
s: Shape,
361-
path: &str,
407+
name: &str,
362408
_: crate::Init,
363409
dtype: DType,
364410
dev: &Device,
365411
) -> Result<Tensor> {
366-
let tensor = match self.get(path)? {
367-
None => Err(Error::CannotFindTensor {
368-
path: path.to_string(),
369-
}
370-
.bt())?,
371-
Some(tensor) => tensor,
372-
};
373-
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
412+
let tensor = self.get_unchecked(name, dtype, dev)?;
374413
if tensor.shape() != &s {
375414
Err(candle::Error::UnexpectedShape {
376-
msg: format!("shape mismatch for {path}"),
415+
msg: format!("shape mismatch for {name}"),
377416
expected: s,
378417
got: tensor.shape().clone(),
379418
}
@@ -382,6 +421,18 @@ impl SimpleBackend for candle::pickle::PthTensors {
382421
Ok(tensor)
383422
}
384423

424+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
425+
let tensor = match self.get(name)? {
426+
None => Err(Error::CannotFindTensor {
427+
path: name.to_string(),
428+
}
429+
.bt())?,
430+
Some(tensor) => tensor,
431+
};
432+
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
433+
Ok(tensor)
434+
}
435+
385436
fn contains_tensor(&self, name: &str) -> bool {
386437
self.get(name).is_ok_and(|v| v.is_some())
387438
}
@@ -396,7 +447,7 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors {
396447
dtype: DType,
397448
dev: &Device,
398449
) -> Result<Tensor> {
399-
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
450+
let tensor = self.get_unchecked(name, dtype, dev)?;
400451
if tensor.shape() != &s {
401452
Err(candle::Error::UnexpectedShape {
402453
msg: format!("shape mismatch for {name}"),
@@ -408,6 +459,10 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors {
408459
Ok(tensor)
409460
}
410461

462+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
463+
self.load(name, dev)?.to_dtype(dtype)
464+
}
465+
411466
fn contains_tensor(&self, name: &str) -> bool {
412467
self.get(name).is_ok()
413468
}
@@ -422,7 +477,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
422477
dtype: DType,
423478
dev: &Device,
424479
) -> Result<Tensor> {
425-
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
480+
let tensor = self.get_unchecked(name, dtype, dev)?;
426481
if tensor.shape() != &s {
427482
Err(candle::Error::UnexpectedShape {
428483
msg: format!("shape mismatch for {name}"),
@@ -434,6 +489,10 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
434489
Ok(tensor)
435490
}
436491

492+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
493+
self.load(name, dev)?.to_dtype(dtype)
494+
}
495+
437496
fn contains_tensor(&self, name: &str) -> bool {
438497
self.get(name).is_ok()
439498
}
@@ -448,7 +507,7 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
448507
dtype: DType,
449508
dev: &Device,
450509
) -> Result<Tensor> {
451-
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
510+
let tensor = self.get_unchecked(name, dtype, dev)?;
452511
if tensor.shape() != &s {
453512
Err(candle::Error::UnexpectedShape {
454513
msg: format!("shape mismatch for {name}"),
@@ -460,6 +519,10 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
460519
Ok(tensor)
461520
}
462521

522+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
523+
self.load(name, dev)?.to_dtype(dtype)
524+
}
525+
463526
fn contains_tensor(&self, name: &str) -> bool {
464527
self.get(name).is_ok()
465528
}
@@ -714,6 +777,10 @@ impl Backend for ShardedSafeTensors {
714777
Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
715778
}
716779

780+
fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
781+
candle::bail!("`get_unchecked` does not make sense for `ShardedSafeTensors`, use `get`.");
782+
}
783+
717784
fn contains_tensor(&self, name: &str) -> bool {
718785
self.0.get(name).is_ok()
719786
}
@@ -747,6 +814,11 @@ impl<R: Renamer + Sync + Send> SimpleBackend for Rename<'_, R> {
747814
.to_device(dev)
748815
}
749816

817+
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
818+
let name = self.renamer.rename(name);
819+
self.inner.get_unchecked_dtype(&name, dtype)?.to_device(dev)
820+
}
821+
750822
fn contains_tensor(&self, name: &str) -> bool {
751823
let name = self.renamer.rename(name);
752824
self.inner.contains_tensor(&name)

candle-nn/src/var_map.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
//! A `VarMap` is a store that holds named variables.
2-
//!
31
use candle::{DType, Device, Result, Shape, Tensor, Var};
42
use std::collections::HashMap;
53
use std::sync::{Arc, Mutex};
@@ -115,6 +113,11 @@ impl VarMap {
115113
Ok(tensor)
116114
}
117115

116+
/// Retrieve or add a new variable.
117+
pub fn get_unchecked(&self, _path: &str, _dtype: DType, _device: &Device) -> Result<Tensor> {
118+
candle::bail!("`get_unchecked` does not make sense for `VarMap`, use `get`.");
119+
}
120+
118121
pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
119122
&self.data
120123
}

0 commit comments

Comments
 (0)