Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "candle-flash-attn-v3/cutlass"]
url = https://github.com/NVIDIA/cutlass.git
path = candle-flash-attn-v3/cutlass
[submodule "candle-flash-mla/cutlass"]
path = candle-flash-mla/cutlass
url = https://github.com/NVIDIA/cutlass
8 changes: 5 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ exclude = [
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
"candle-flash-mla",
]
resolver = "2"

Expand All @@ -38,6 +39,7 @@ candle = { path = "./candle-core", package = "candle-core", version = "0.8.0" }
candle-datasets = { path = "./candle-datasets", version = "0.8.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.0" }
candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.8.0" }
candle-flash-mla = { path = "./candle-flash-mla", version = "0.8.0" }
candle-kernels = { path = "./candle-kernels", version = "0.8.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.0" }
candle-nn = { path = "./candle-nn", version = "0.8.0" }
Expand All @@ -50,7 +52,7 @@ fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
float8 = { version = "0.1.2", features = ["num-traits", "rand_distr"] }
float8 = { version = "0.2.0", features = ["num-traits", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
Expand All @@ -61,8 +63,8 @@ memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "51.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rand = "0.9.0"
rand_distr = "0.5"
rayon = "1.7.0"
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
Expand Down
26 changes: 17 additions & 9 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3159,6 +3159,9 @@ impl BackendStorage for CpuStorage {
(Self::F64(src), Self::F64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F8E4M3(src), Self::F8E4M3(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(_, dst) => {
return Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
Expand All @@ -3182,6 +3185,9 @@ impl BackendStorage for CpuStorage {
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F8E4M3(src), Self::F8E4M3(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_l)
}
(_, dst) => {
// This should be covered by the dtype check above.
return Err(Error::DTypeMismatchBinaryOp {
Expand Down Expand Up @@ -3540,24 +3546,24 @@ impl BackendDevice for CpuDevice {
use rand::prelude::*;

let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
match dtype {
DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
}
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<bf16, _>(uniform))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f16, _>(uniform))
}
Expand All @@ -3566,23 +3572,25 @@ impl BackendDevice for CpuDevice {
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max));
rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<F8E4M3, _>(uniform))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
let uniform =
rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(uniform))
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min, max);
let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(uniform))
}
Expand All @@ -3595,7 +3603,7 @@ impl BackendDevice for CpuDevice {
use rand::prelude::*;

let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
match dtype {
DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
Expand Down
2 changes: 1 addition & 1 deletion candle-core/src/pickle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ impl PthTensors {
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
/// contains multiple objects and the `state_dict` is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,
Expand Down
77 changes: 77 additions & 0 deletions candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ use half::{bf16, f16};

pub use k_quants::GgmlType;

fn as_t_slice<T>(data: Cow<'_, [u8]>) -> &[T] {
let size = std::mem::size_of::<T>();
assert_eq!(
data.len() % size,
0,
"Data length must be a multiple of T's size"
);
let ptr = data.as_ptr();
assert_eq!(
(ptr as usize) % std::mem::align_of::<T>(),
0,
"Data pointer must be aligned to T's alignment"
);
unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) }
}

pub struct QTensor {
storage: QStorage,
shape: Shape,
Expand Down Expand Up @@ -63,6 +79,46 @@ pub enum QStorage {
}

impl QStorage {
pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
match device {
Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
Device::Metal(d) => match dtype {
GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
},
Device::Cuda(d) => match dtype {
GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
},
}
}

fn block_size(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.block_size(),
Expand Down Expand Up @@ -267,6 +323,27 @@ impl GgmlDType {
Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
}
}

pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
match self {
Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
}
}

/// The type size for blocks in bytes.
pub fn type_size(&self) -> usize {
use k_quants::*;
Expand Down
3 changes: 3 additions & 0 deletions candle-flash-attn-v3/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ fn main() -> Result<()> {
command.arg("-DCUTLASS_DEBUG_TRACE_LEVEL=0");
command.arg("-DNDEBUG");

// https://github.com/EricLBuehler/mistral.rs/issues/941
command.arg("-D_USE_MATH_DEFINES");

if let Some(ccbin_path) = &ccbin_env {
command.arg("-allow-unsupported-compiler");
command.args(["-ccbin", ccbin_path]);
Expand Down
2 changes: 2 additions & 0 deletions candle-flash-attn/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ fn main() -> Result<()> {
builder = builder.arg("-D_USE_MATH_DEFINES");
}
}
// https://github.com/EricLBuehler/mistral.rs/issues/941
builder = builder.arg("-D_USE_MATH_DEFINES");

// https://github.com/EricLBuehler/mistral.rs/issues/286
// https://github.com/huggingface/candle-flash-attn-v1/pull/2
Expand Down
7 changes: 7 additions & 0 deletions candle-flash-mla/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.idea
target
Cargo.lock
.venv
hkernel/build/*
__pycache__
*.egg-info
3 changes: 3 additions & 0 deletions candle-flash-mla/.gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "cutlass"]
path = cutlass
url = https://github.com/NVIDIA/cutlass.git
24 changes: 24 additions & 0 deletions candle-flash-mla/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[package]
name = "candle-flash-mla"
version = "0.8.0"
edition = "2021"

description = "Flash MLA layer for the candle ML framework."
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.0" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", features = ["cuda"] }
rstest = "0.23"
Loading