Skip to content

Commit b7b6ac9

Browse files
EricLBuehlerlukekim
authored andcommitted
Add QStorage::from_data api
1 parent 913e18c commit b7b6ac9

1 file changed

Lines changed: 77 additions & 0 deletions

File tree

  • candle-core/src/quantized

candle-core/src/quantized/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ use half::f16;
3232

3333
pub use k_quants::GgmlType;
3434

35+
fn as_t_slice<T>(data: Cow<'_, [u8]>) -> &[T] {
36+
let size = std::mem::size_of::<T>();
37+
assert_eq!(
38+
data.len() % size,
39+
0,
40+
"Data length must be a multiple of T's size"
41+
);
42+
let ptr = data.as_ptr();
43+
assert_eq!(
44+
(ptr as usize) % std::mem::align_of::<T>(),
45+
0,
46+
"Data pointer must be aligned to T's alignment"
47+
);
48+
unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) }
49+
}
50+
3551
pub struct QTensor {
3652
storage: QStorage,
3753
shape: Shape,
@@ -63,6 +79,46 @@ pub enum QStorage {
6379
}
6480

6581
impl QStorage {
82+
pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
83+
match device {
84+
Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
85+
Device::Metal(d) => match dtype {
86+
GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
87+
GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
88+
GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
89+
GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
90+
GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
91+
GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
92+
GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
93+
GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
94+
GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
95+
GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
96+
GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
97+
GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
98+
GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
99+
GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
100+
GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
101+
},
102+
Device::Cuda(d) => match dtype {
103+
GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
104+
GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
105+
GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
106+
GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
107+
GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
108+
GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
109+
GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
110+
GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
111+
GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
112+
GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
113+
GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
114+
GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
115+
GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
116+
GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
117+
GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
118+
},
119+
}
120+
}
121+
66122
fn block_size(&self) -> usize {
67123
match self {
68124
QStorage::Cpu(storage) => storage.block_size(),
@@ -208,6 +264,27 @@ impl GgmlDType {
208264
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
209265
}
210266
}
267+
268+
pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
269+
match self {
270+
Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
271+
Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
272+
Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
273+
Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
274+
Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
275+
Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
276+
Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
277+
Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
278+
Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
279+
Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
280+
Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
281+
Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
282+
Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
283+
Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
284+
Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
285+
}
286+
}
287+
211288
/// The type size for blocks in bytes.
212289
pub fn type_size(&self) -> usize {
213290
use k_quants::*;

0 commit comments

Comments
 (0)