Skip to content

Commit 1b02ddb

Browse files
authored
Merge pull request #6 from spiceai/jeadie/25-04-15/upstream-spiceai
Update spiceai from EricLBuehler/candle
2 parents e9db971 + 93af275 commit 1b02ddb

32 files changed

Lines changed: 4373 additions & 989 deletions

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@
44
[submodule "candle-flash-attn-v3/cutlass"]
55
url = https://github.com/NVIDIA/cutlass.git
66
path = candle-flash-attn-v3/cutlass
7+
[submodule "candle-flash-mla/cutlass"]
8+
path = candle-flash-mla/cutlass
9+
url = https://github.com/NVIDIA/cutlass

Cargo.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ exclude = [
1717
"candle-kernels",
1818
"candle-metal-kernels",
1919
"candle-onnx",
20+
"candle-flash-mla",
2021
]
2122
resolver = "2"
2223

@@ -38,6 +39,7 @@ candle = { path = "./candle-core", package = "candle-core", version = "0.8.0" }
3839
candle-datasets = { path = "./candle-datasets", version = "0.8.0" }
3940
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.0" }
4041
candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.8.0" }
42+
candle-flash-mla = { path = "./candle-flash-mla", version = "0.8.0" }
4143
candle-kernels = { path = "./candle-kernels", version = "0.8.0" }
4244
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.0" }
4345
candle-nn = { path = "./candle-nn", version = "0.8.0" }
@@ -50,7 +52,7 @@ fancy-regex = "0.13.0"
5052
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
5153
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
5254
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
53-
float8 = { version = "0.1.2", features = ["num-traits", "rand_distr"] }
55+
float8 = { version = "0.2.0", features = ["num-traits", "rand_distr"] }
5456
hound = "3.5.1"
5557
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
5658
imageproc = { version = "0.24.0", default-features = false }
@@ -61,8 +63,8 @@ memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
6163
num_cpus = "1.15.0"
6264
num-traits = "0.2.15"
6365
parquet = { version = "51.0.0" }
64-
rand = "0.8.5"
65-
rand_distr = "0.4.3"
66+
rand = "0.9.0"
67+
rand_distr = "0.5"
6668
rayon = "1.7.0"
6769
safetensors = "0.4.1"
6870
serde = { version = "1.0.171", features = ["derive"] }

candle-core/src/cpu_backend/mod.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3159,6 +3159,9 @@ impl BackendStorage for CpuStorage {
31593159
(Self::F64(src), Self::F64(dst)) => {
31603160
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
31613161
}
3162+
(Self::F8E4M3(src), Self::F8E4M3(dst)) => {
3163+
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
3164+
}
31623165
(_, dst) => {
31633166
return Err(Error::DTypeMismatchBinaryOp {
31643167
lhs: self.dtype(),
@@ -3182,6 +3185,9 @@ impl BackendStorage for CpuStorage {
31823185
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
31833186
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
31843187
(Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
3188+
(Self::F8E4M3(src), Self::F8E4M3(dst)) => {
3189+
copy_strided_src_(src, dst, dst_offset, src_l)
3190+
}
31853191
(_, dst) => {
31863192
// This should be covered by the dtype check above.
31873193
return Err(Error::DTypeMismatchBinaryOp {
@@ -3540,24 +3546,24 @@ impl BackendDevice for CpuDevice {
35403546
use rand::prelude::*;
35413547

35423548
let elem_count = shape.elem_count();
3543-
let mut rng = rand::thread_rng();
3549+
let mut rng = rand::rng();
35443550
match dtype {
35453551
DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => {
35463552
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
35473553
}
35483554
DType::BF16 => {
35493555
let mut data = Vec::with_capacity(elem_count);
3550-
let uniform =
3551-
rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
3556+
let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
3557+
.map_err(Error::wrap)?;
35523558
for _i in 0..elem_count {
35533559
data.push(rng.sample::<bf16, _>(uniform))
35543560
}
35553561
Ok(CpuStorage::BF16(data))
35563562
}
35573563
DType::F16 => {
35583564
let mut data = Vec::with_capacity(elem_count);
3559-
let uniform =
3560-
rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
3565+
let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
3566+
.map_err(Error::wrap)?;
35613567
for _i in 0..elem_count {
35623568
data.push(rng.sample::<f16, _>(uniform))
35633569
}
@@ -3566,23 +3572,25 @@ impl BackendDevice for CpuDevice {
35663572
DType::F8E4M3 => {
35673573
let mut data = Vec::with_capacity(elem_count);
35683574
let uniform =
3569-
rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max));
3575+
rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
3576+
.map_err(Error::wrap)?;
35703577
for _i in 0..elem_count {
35713578
data.push(rng.sample::<F8E4M3, _>(uniform))
35723579
}
35733580
Ok(CpuStorage::F8E4M3(data))
35743581
}
35753582
DType::F32 => {
35763583
let mut data = Vec::with_capacity(elem_count);
3577-
let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
3584+
let uniform =
3585+
rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
35783586
for _i in 0..elem_count {
35793587
data.push(rng.sample::<f32, _>(uniform))
35803588
}
35813589
Ok(CpuStorage::F32(data))
35823590
}
35833591
DType::F64 => {
35843592
let mut data = Vec::with_capacity(elem_count);
3585-
let uniform = rand::distributions::Uniform::new(min, max);
3593+
let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
35863594
for _i in 0..elem_count {
35873595
data.push(rng.sample::<f64, _>(uniform))
35883596
}
@@ -3595,7 +3603,7 @@ impl BackendDevice for CpuDevice {
35953603
use rand::prelude::*;
35963604

35973605
let elem_count = shape.elem_count();
3598-
let mut rng = rand::thread_rng();
3606+
let mut rng = rand::rng();
35993607
match dtype {
36003608
DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => {
36013609
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())

candle-core/src/pickle.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ impl PthTensors {
792792
/// # Arguments
793793
/// * `path` - Path to the pth file.
794794
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
795-
/// contains multiple objects and the state_dict is the one we are interested in.
795+
/// contains multiple objects and the `state_dict` is the one we are interested in.
796796
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
797797
path: P,
798798
key: Option<&str>,

candle-core/src/quantized/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ use half::{bf16, f16};
3232

3333
pub use k_quants::GgmlType;
3434

35+
/// Reinterprets the bytes held by `data` as a slice of `T` without copying.
///
/// The element count of the returned slice is `data.len() / size_of::<T>()`.
///
/// # Panics
/// Panics if the byte length is not a multiple of `size_of::<T>()` or if the
/// data pointer is not aligned to `align_of::<T>()`. The `%` by `size` also
/// panics when `T` is zero-sized.
///
/// NOTE(review): `data` is taken by value, yet the returned slice carries the
/// `Cow`'s elided lifetime. With `Cow::Borrowed` the bytes outlive this call;
/// with `Cow::Owned` the buffer is dropped when this function returns, so the
/// returned slice would dangle. Confirm every caller passes borrowed data, or
/// change the parameter to a borrow.
fn as_t_slice<T>(data: Cow<'_, [u8]>) -> &[T] {
36+
let size = std::mem::size_of::<T>();
37+
// The byte buffer must contain a whole number of `T` values.
assert_eq!(
38+
data.len() % size,
39+
0,
40+
"Data length must be a multiple of T's size"
41+
);
42+
let ptr = data.as_ptr();
43+
// A misaligned `*const T` must never be turned into a `&[T]`.
assert_eq!(
44+
(ptr as usize) % std::mem::align_of::<T>(),
45+
0,
46+
"Data pointer must be aligned to T's alignment"
47+
);
48+
// SAFETY (as far as checked here): length divisibility and alignment are
// asserted above. Not checked: that the bytes are a valid `T` (presumably
// `T` is plain-old-data; TODO confirm) and that the buffer outlives the
// returned slice (see the NOTE on the function).
unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) }
49+
}
50+
3551
pub struct QTensor {
3652
storage: QStorage,
3753
shape: Shape,
@@ -63,6 +79,46 @@ pub enum QStorage {
6379
}
6480

6581
impl QStorage {
82+
/// Builds a `QStorage` from raw bytes for the given device and GGML dtype.
///
/// * `Device::Cpu`: delegates to `dtype.from_data(data)` and wraps the result
///   in `Self::Cpu`.
/// * `Device::Metal` / `Device::Cuda`: views the bytes as a slice of the
///   element/block type matching `dtype` via `as_t_slice` and hands that slice
///   to the backend's `load_quantized`, whose result is returned as-is.
///
/// NOTE(review): `as_t_slice` asserts on length and alignment, so malformed
/// input panics here instead of surfacing as an `Err`, even though this
/// function returns `Result`.
/// NOTE(review): `data` is moved into `as_t_slice` / `dtype.from_data`; if it
/// is `Cow::Owned`, confirm the bytes are still alive while the resulting
/// slice is read.
pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
83+
match device {
84+
Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
85+
// One arm per `GgmlDType` variant: only the slice element type differs.
Device::Metal(d) => match dtype {
86+
GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
87+
GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
88+
GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
89+
GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
90+
GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
91+
GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
92+
GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
93+
GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
94+
GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
95+
GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
96+
GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
97+
GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
98+
GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
99+
GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
100+
GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
101+
},
102+
// Same dtype-to-type mapping as the Metal branch, routed to the CUDA loader.
Device::Cuda(d) => match dtype {
103+
GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
104+
GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
105+
GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
106+
GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
107+
GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
108+
GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
109+
GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
110+
GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
111+
GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
112+
GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
113+
GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
114+
GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
115+
GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
116+
GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
117+
GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
118+
},
119+
}
120+
}
121+
66122
fn block_size(&self) -> usize {
67123
match self {
68124
QStorage::Cpu(storage) => storage.block_size(),
@@ -267,6 +323,27 @@ impl GgmlDType {
267323
Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
268324
}
269325
}
326+
327+
/// Copies raw bytes into an owned, boxed buffer of the element/block type
/// selected by `self`.
///
/// Each arm views `data` as `&[T]` through `as_t_slice`, clones it with
/// `to_vec()`, and boxes the resulting `Vec<T>` as `dyn QuantizedType`.
///
/// # Panics
/// Inherits the assertions in `as_t_slice`: panics if the byte length is not a
/// multiple of the target type's size or the pointer is not suitably aligned.
///
/// NOTE(review): `data` is moved into `as_t_slice` before `to_vec()` reads the
/// slice; for a `Cow::Owned` input confirm the bytes are still alive at that
/// point.
pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
328+
match self {
329+
Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
330+
Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
331+
Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
332+
Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
333+
Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
334+
Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
335+
Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
336+
Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
337+
Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
338+
Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
339+
Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
340+
Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
341+
Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
342+
Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
343+
Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
344+
}
345+
}
346+
270347
/// The type size for blocks in bytes.
271348
pub fn type_size(&self) -> usize {
272349
use k_quants::*;

candle-flash-attn-v3/build.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ fn main() -> Result<()> {
202202
command.arg("-DCUTLASS_DEBUG_TRACE_LEVEL=0");
203203
command.arg("-DNDEBUG");
204204

205+
// https://github.com/EricLBuehler/mistral.rs/issues/941
206+
command.arg("-D_USE_MATH_DEFINES");
207+
205208
if let Some(ccbin_path) = &ccbin_env {
206209
command.arg("-allow-unsupported-compiler");
207210
command.args(["-ccbin", ccbin_path]);

candle-flash-attn/build.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ fn main() -> Result<()> {
9595
builder = builder.arg("-D_USE_MATH_DEFINES");
9696
}
9797
}
98+
// https://github.com/EricLBuehler/mistral.rs/issues/941
99+
builder = builder.arg("-D_USE_MATH_DEFINES");
98100

99101
// https://github.com/EricLBuehler/mistral.rs/issues/286
100102
// https://github.com/huggingface/candle-flash-attn-v1/pull/2

candle-flash-mla/.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
.idea
2+
target
3+
Cargo.lock
4+
.venv
5+
hkernel/build/*
6+
__pycache__
7+
*.egg-info

candle-flash-mla/.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "cutlass"]
2+
path = cutlass
3+
url = https://github.com/NVIDIA/cutlass.git

candle-flash-mla/Cargo.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
[package]
2+
name = "candle-flash-mla"
3+
version = "0.8.0"
4+
edition = "2021"
5+
6+
description = "Flash MLA layer for the candle ML framework."
7+
keywords = ["blas", "tensor", "machine-learning"]
8+
categories = ["science"]
9+
license = "MIT OR Apache-2.0"
10+
readme = "README.md"
11+
12+
[dependencies]
13+
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.0" }
14+
half = { version = "2.3.1", features = ["num-traits"] }
15+
16+
[build-dependencies]
17+
anyhow = { version = "1", features = ["backtrace"] }
18+
num_cpus = "1.15.0"
19+
rayon = "1.7.0"
20+
21+
[dev-dependencies]
22+
anyhow = { version = "1", features = ["backtrace"] }
23+
candle-nn = { path = "../candle-nn", features = ["cuda"] }
24+
rstest = "0.23"

0 commit comments

Comments
 (0)