Skip to content

Commit fcc2f7e

Browse files
committed
Merge branch 'spiceai' into spiceai-0.9.1
2 parents cd96fa8 + 3ebfe70 commit fcc2f7e

240 files changed

Lines changed: 35455 additions & 4580 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitmodules

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
[submodule "candle-examples/examples/flash-attn/cutlass"]
22
path = candle-flash-attn/cutlass
33
url = https://github.com/NVIDIA/cutlass.git
4+
[submodule "candle-flash-attn-v3/cutlass"]
5+
url = https://github.com/NVIDIA/cutlass.git
6+
path = candle-flash-attn-v3/cutlass
7+
[submodule "candle-flash-mla/cutlass"]
8+
path = candle-flash-mla/cutlass
9+
url = https://github.com/NVIDIA/cutlass

.vscode/settings.json

Lines changed: 0 additions & 11 deletions
This file was deleted.

Cargo.toml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ members = [
1313
exclude = [
1414
"candle-book",
1515
"candle-flash-attn",
16+
"candle-flash-attn-v3",
1617
"candle-kernels",
1718
"candle-metal-kernels",
1819
"candle-onnx",
20+
"candle-flash-mla",
1921
]
2022
resolver = "2"
2123

@@ -36,18 +38,21 @@ byteorder = "1.4.3"
3638
candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
3739
candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
3840
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
41+
candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.9.1" }
42+
candle-flash-mla = { path = "./candle-flash-mla", version = "0.9.1" }
3943
candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
4044
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
4145
candle-nn = { path = "./candle-nn", version = "0.9.1" }
4246
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
4347
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
4448
clap = { version = "4.2.4", features = ["derive"] }
4549
criterion = { version = "0.5.1", default-features=false }
46-
cudarc = { version = "0.16.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
50+
cudarc = { version = "0.13.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
4751
fancy-regex = "0.13.0"
4852
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
49-
hf-hub = "0.4.1"
50-
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
53+
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
54+
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
55+
float8 = { version = "0.2.0", features = ["num-traits", "rand_distr"] }
5156
hound = "3.5.1"
5257
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
5358
imageproc = { version = "0.24.0", default-features = false }
@@ -59,7 +64,7 @@ num_cpus = "1.15.0"
5964
num-traits = "0.2.15"
6065
parquet = { version = "51.0.0" }
6166
rand = "0.9.0"
62-
rand_distr = "0.5.1"
67+
rand_distr = "0.5"
6368
rayon = "1.7.0"
6469
safetensors = "0.4.1"
6570
serde = { version = "1.0.171", features = ["derive"] }
@@ -70,9 +75,6 @@ tokenizers = { version = "0.21.0", default-features = false }
7075
tracing = "0.1.37"
7176
tracing-chrome = "0.7.1"
7277
tracing-subscriber = "0.3.7"
73-
ug = "0.4.0"
74-
ug-cuda = "0.4.0"
75-
ug-metal = "0.4.0"
7678
yoke = { version = "0.7.2", features = ["derive"] }
7779
zip = { version = "1.1.1", default-features = false }
7880
metal = { version = "0.27.0", features = ["mps"]}

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
66
[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE)
77

8+
**This is an optimized implementation by Eric Buehler.**
9+
810
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
911
and ease of use. Try our online demos:
1012
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),

candle-core/Cargo.toml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ metal = { workspace = true, optional = true }
1818
cudarc = { workspace = true, optional = true }
1919
gemm = { workspace = true }
2020
half = { workspace = true }
21+
float8 = { workspace = true }
2122
intel-mkl-src = { workspace = true, optional = true }
2223
libc = { workspace = true, optional = true }
2324
memmap2 = { workspace = true }
@@ -28,26 +29,22 @@ rand_distr = { workspace = true }
2829
rayon = { workspace = true }
2930
safetensors = { workspace = true }
3031
thiserror = { workspace = true }
31-
ug-cuda = { workspace = true, optional = true }
32-
ug-metal = { workspace = true, optional = true }
3332
yoke = { workspace = true }
3433
zip = { workspace = true }
3534

36-
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
37-
ug = { workspace = true }
38-
3935
[dev-dependencies]
4036
anyhow = { workspace = true }
4137
clap = { workspace = true }
4238
criterion = { workspace = true }
4339

4440
[features]
4541
default = []
46-
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
42+
cuda = ["cudarc", "dep:candle-kernels", "float8/cuda"]
4743
cudnn = ["cuda", "cudarc/cudnn"]
44+
nccl = ["cuda", "cudarc/nccl"]
4845
mkl = ["dep:libc", "dep:intel-mkl-src"]
4946
accelerate = ["dep:libc", "dep:accelerate-src"]
50-
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
47+
metal = ["dep:metal", "dep:candle-metal-kernels"]
5148

5249
[[bench]]
5350
name = "bench_main"

candle-core/benches/benchmarks/mod.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ impl BenchDevice for Device {
2222
Device::Cpu => Ok(()),
2323
Device::Cuda(device) => {
2424
#[cfg(feature = "cuda")]
25-
return Ok(device
26-
.synchronize()
27-
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
25+
{
26+
use candle_core::cuda::WrapErr;
27+
return Ok(device.synchronize().w()?);
28+
}
2829
#[cfg(not(feature = "cuda"))]
2930
panic!("Cuda device without cuda feature enabled: {:?}", device)
3031
}
3132
Device::Metal(device) => {
3233
#[cfg(feature = "metal")]
33-
return Ok(device.wait_until_completed()?);
34+
return device.wait_until_completed();
3435
#[cfg(not(feature = "metal"))]
3536
panic!("Metal device without metal feature enabled: {:?}", device)
3637
}

candle-core/benches/benchmarks/where_cond.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ const M: usize = 1024;
2222
const K: usize = 1024;
2323
const SIZE: usize = B * M * K;
2424

25-
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
25+
static DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
2626

2727
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
2828
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();

candle-core/src/backend.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,23 @@ pub trait BackendStorage: Sized {
103103
_: usize,
104104
) -> Result<Self>;
105105

106-
fn matmul(
106+
#[allow(clippy::too_many_arguments)]
107+
fn matmul_with_alpha_beta(
108+
&self,
109+
_: &Self,
110+
_: &mut Self,
111+
_: Option<f64>,
112+
_: (usize, usize, usize, usize),
113+
_: &Layout,
114+
_: &Layout,
115+
_: &Layout,
116+
) -> Result<()>;
117+
118+
#[allow(clippy::too_many_arguments)]
119+
fn matmul_with_alpha(
107120
&self,
108121
_: &Self,
122+
_: Option<f64>,
109123
_: (usize, usize, usize, usize),
110124
_: &Layout,
111125
_: &Layout,
@@ -158,6 +172,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
158172
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
159173

160174
fn set_seed(&self, _: u64) -> Result<()>;
175+
fn get_current_seed(&self) -> Result<u64>;
161176

162177
/// Synchronize should block until all the operations on the device are completed.
163178
fn synchronize(&self) -> Result<()>;

candle-core/src/convert.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Implement conversion traits for tensors
22
use crate::{DType, Device, Error, Tensor, WithDType};
3+
use float8::F8E4M3;
34
use half::{bf16, f16, slice::HalfFloatSliceExt};
45
use std::convert::TryFrom;
56

@@ -130,6 +131,16 @@ impl Tensor {
130131
f.write_u32::<LittleEndian>(v)?
131132
}
132133
}
134+
DType::I16 => {
135+
for v in vs.to_vec1::<i16>()? {
136+
f.write_i16::<LittleEndian>(v)?
137+
}
138+
}
139+
DType::I32 => {
140+
for v in vs.to_vec1::<i32>()? {
141+
f.write_i32::<LittleEndian>(v)?
142+
}
143+
}
133144
DType::I64 => {
134145
for v in vs.to_vec1::<i64>()? {
135146
f.write_i64::<LittleEndian>(v)?
@@ -139,6 +150,11 @@ impl Tensor {
139150
let vs = vs.to_vec1::<u8>()?;
140151
f.write_all(&vs)?;
141152
}
153+
DType::F8E4M3 => {
154+
for v in vs.to_vec1::<F8E4M3>()? {
155+
f.write_u8(v.to_bits())?
156+
}
157+
}
142158
}
143159
Ok(())
144160
}

candle-core/src/cpu/avx.rs

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
use super::{Cpu, CpuF16};
1+
use super::{Cpu, CpuBF16, CpuF16};
22
#[cfg(target_arch = "x86")]
33
use core::arch::x86::*;
44
#[cfg(target_arch = "x86_64")]
55
use core::arch::x86_64::*;
66

7-
use half::f16;
7+
use half::{bf16, f16};
88

99
pub struct CurrentCpu {}
1010

@@ -146,3 +146,82 @@ impl CpuF16<ARR> for CurrentCpuF16 {
146146
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
147147
}
148148
}
149+
150+
pub struct CurrentCpuBF16 {}
151+
impl CpuBF16<ARR> for CurrentCpuBF16 {
152+
type Unit = __m256;
153+
type Array = [__m256; ARR];
154+
155+
const STEP: usize = STEP;
156+
const EPR: usize = EPR;
157+
158+
fn n() -> usize {
159+
ARR
160+
}
161+
162+
unsafe fn zero() -> Self::Unit {
163+
_mm256_setzero_ps()
164+
}
165+
166+
unsafe fn zero_array() -> Self::Array {
167+
[Self::zero(); ARR]
168+
}
169+
170+
unsafe fn from_f32(v: f32) -> Self::Unit {
171+
_mm256_set1_ps(v)
172+
}
173+
174+
#[cfg(target_feature = "f16c")]
175+
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
176+
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
177+
}
178+
179+
#[cfg(not(target_feature = "f16c"))]
180+
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
181+
let mut tmp = [0.0f32; 8];
182+
for i in 0..8 {
183+
tmp[i] = (*mem_addr.add(i)).to_f32();
184+
}
185+
_mm256_loadu_ps(tmp.as_ptr())
186+
}
187+
188+
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
189+
_mm256_add_ps(a, b)
190+
}
191+
192+
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
193+
_mm256_add_ps(_mm256_mul_ps(b, c), a)
194+
}
195+
196+
#[cfg(target_feature = "f16c")]
197+
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
198+
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
199+
}
200+
201+
#[cfg(not(target_feature = "f16c"))]
202+
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
203+
let mut tmp = [0.0f32; 8];
204+
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
205+
for i in 0..8 {
206+
*mem_addr.add(i) = bf16::from_f32(tmp[i]);
207+
}
208+
}
209+
210+
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
211+
let mut offset = ARR >> 1;
212+
for i in 0..offset {
213+
x[i] = _mm256_add_ps(x[i], x[offset + i]);
214+
}
215+
offset >>= 1;
216+
for i in 0..offset {
217+
x[i] = _mm256_add_ps(x[i], x[offset + i]);
218+
}
219+
offset >>= 1;
220+
for i in 0..offset {
221+
x[i] = _mm256_add_ps(x[i], x[offset + i]);
222+
}
223+
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
224+
let t1 = _mm_hadd_ps(t0, t0);
225+
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
226+
}
227+
}

0 commit comments

Comments
 (0)