Skip to content

Commit 02d80c2

Browse files
committed
Merge spiceai patches into candle 0.10.1
2 parents 904bf22 + 3ebfe70 commit 02d80c2

92 files changed

Lines changed: 6056 additions & 185 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitmodules

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[submodule "candle-examples/examples/flash-attn/cutlass"]
2+
path = candle-flash-attn/cutlass
3+
url = https://github.com/NVIDIA/cutlass.git
4+
[submodule "candle-flash-attn-v3/cutlass"]
5+
url = https://github.com/NVIDIA/cutlass.git
6+
path = candle-flash-attn-v3/cutlass
7+
[submodule "candle-flash-mla/cutlass"]
8+
path = candle-flash-mla/cutlass
9+
url = https://github.com/NVIDIA/cutlass

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ exclude = [
1818
"candle-kernels",
1919
"candle-metal-kernels",
2020
"candle-onnx",
21+
"candle-flash-mla",
2122
]
2223
resolver = "2"
2324

@@ -39,6 +40,7 @@ candle = { path = "./candle-core", package = "candle-core", version = "0.10.1" }
3940
candle-datasets = { path = "./candle-datasets", version = "0.10.1" }
4041
candle-flash-attn = { path = "./candle-flash-attn", version = "0.10.1" }
4142
candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.10.1" }
43+
candle-flash-mla = { path = "./candle-flash-mla", version = "0.10.1" }
4244
candle-kernels = { path = "./candle-kernels", version = "0.10.1" }
4345
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.10.1" }
4446
candle-nn = { path = "./candle-nn", version = "0.10.1" }

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
66
[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE)
77

8+
**This is an optimized implementation by Eric Buehler.**
9+
810
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
911
and ease of use. Try our online demos:
1012
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),

candle-core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ criterion = { workspace = true }
4646

4747
[features]
4848
default = []
49-
cuda = ["cudarc", "dep:candle-kernels", "candle-ug?/cuda"]
49+
cuda = ["cudarc", "dep:candle-kernels", "float8/cuda", "candle-ug?/cuda"]
5050
cudnn = ["cuda", "cudarc/cudnn"]
5151
nccl = ["cuda", "cudarc/nccl"]
5252
mkl = ["dep:libc", "dep:intel-mkl-src"]

candle-core/src/backend.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,23 @@ pub trait BackendStorage: Sized {
112112
_: usize,
113113
) -> Result<Self>;
114114

115-
fn matmul(
115+
#[allow(clippy::too_many_arguments)]
116+
fn matmul_with_alpha_beta(
116117
&self,
117118
_: &Self,
119+
_: &mut Self,
120+
_: Option<f64>,
121+
_: (usize, usize, usize, usize),
122+
_: &Layout,
123+
_: &Layout,
124+
_: &Layout,
125+
) -> Result<()>;
126+
127+
#[allow(clippy::too_many_arguments)]
128+
fn matmul_with_alpha(
129+
&self,
130+
_: &Self,
131+
_: Option<f64>,
118132
_: (usize, usize, usize, usize),
119133
_: &Layout,
120134
_: &Layout,

candle-core/src/cpu_backend/utils.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,91 @@ pub trait Map2 {
7777
}
7878
}
7979

80+
pub trait Map2Alpha {
81+
const OP: &'static str;
82+
fn f<T: WithDType>(
83+
&self,
84+
v1: &[T],
85+
l1: &Layout,
86+
v2: &[T],
87+
l2: &Layout,
88+
alpha: f64,
89+
) -> Result<Vec<T>>;
90+
91+
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, alpha: f64) -> Result<C> {
92+
match (v1, v2) {
93+
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2, alpha)?)),
94+
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2, alpha)?)),
95+
(C::I16(v1), C::I16(v2)) => Ok(C::I16(self.f(v1, l1, v2, l2, alpha)?)),
96+
(C::I32(v1), C::I32(v2)) => Ok(C::I32(self.f(v1, l1, v2, l2, alpha)?)),
97+
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2, alpha)?)),
98+
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2, alpha)?)),
99+
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, alpha)?)),
100+
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, alpha)?)),
101+
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, alpha)?)),
102+
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2, alpha)?)),
103+
_ => Err(Error::DTypeMismatchBinaryOp {
104+
lhs: v1.dtype(),
105+
rhs: v2.dtype(),
106+
op: Self::OP,
107+
}
108+
.bt()),
109+
}
110+
}
111+
}
112+
113+
pub trait Map3 {
114+
const OP: &'static str;
115+
fn f<T: WithDType>(
116+
&self,
117+
v1: &[T],
118+
l1: &Layout,
119+
v2: &[T],
120+
l2: &Layout,
121+
v3: &[T],
122+
l3: &Layout,
123+
) -> Result<Vec<T>>;
124+
125+
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, v3: &C, l3: &Layout) -> Result<C> {
126+
match (v1, v2, v3) {
127+
(C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(C::U8(self.f(v1, l1, v2, l2, v3, l3)?)),
128+
(C::U32(v1), C::U32(v2), C::U32(v3)) => {
129+
Ok(C::U32(self.f(v1, l1, v2, l2, v3, l3)?))
130+
}
131+
(C::I16(v1), C::I16(v2), C::I16(v3)) => {
132+
Ok(C::I16(self.f(v1, l1, v2, l2, v3, l3)?))
133+
}
134+
(C::I32(v1), C::I32(v2), C::I32(v3)) => {
135+
Ok(C::I32(self.f(v1, l1, v2, l2, v3, l3)?))
136+
}
137+
(C::I64(v1), C::I64(v2), C::I64(v3)) => {
138+
Ok(C::I64(self.f(v1, l1, v2, l2, v3, l3)?))
139+
}
140+
(C::BF16(v1), C::BF16(v2), C::BF16(v3)) => {
141+
Ok(C::BF16(self.f(v1, l1, v2, l2, v3, l3)?))
142+
}
143+
(C::F16(v1), C::F16(v2), C::F16(v3)) => {
144+
Ok(C::F16(self.f(v1, l1, v2, l2, v3, l3)?))
145+
}
146+
(C::F32(v1), C::F32(v2), C::F32(v3)) => {
147+
Ok(C::F32(self.f(v1, l1, v2, l2, v3, l3)?))
148+
}
149+
(C::F64(v1), C::F64(v2), C::F64(v3)) => {
150+
Ok(C::F64(self.f(v1, l1, v2, l2, v3, l3)?))
151+
}
152+
(C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => {
153+
Ok(C::F8E4M3(self.f(v1, l1, v2, l2, v3, l3)?))
154+
}
155+
_ => Err(Error::DTypeMismatchBinaryOp {
156+
lhs: v1.dtype(),
157+
rhs: v2.dtype(),
158+
op: Self::OP,
159+
}
160+
.bt()),
161+
}
162+
}
163+
}
164+
80165
pub trait Map2InPlace {
81166
const OP: &'static str;
82167
fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;

candle-core/src/device.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ impl Device {
267267
}
268268
}
269269

270+
/// Get the current seed for the device RNG.
270271
pub fn get_current_seed(&self) -> Result<u64> {
271272
match self {
272273
Self::Cpu => CpuDevice.get_current_seed(),
@@ -465,12 +466,12 @@ impl Device {
465466
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
466467
Device::Cuda(device) => {
467468
let storage = array.to_cpu_storage();
468-
let storage = device.storage_from_cpu_storage_owned(storage)?;
469+
let storage = device.storage_from_cpu_storage(&storage)?;
469470
Ok(Storage::Cuda(storage))
470471
}
471472
Device::Metal(device) => {
472473
let storage = array.to_cpu_storage();
473-
let storage = device.storage_from_cpu_storage_owned(storage)?;
474+
let storage = device.storage_from_cpu_storage(&storage)?;
474475
Ok(Storage::Metal(storage))
475476
}
476477
}
@@ -481,12 +482,12 @@ impl Device {
481482
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
482483
Device::Cuda(device) => {
483484
let storage = S::to_cpu_storage_owned(data);
484-
let storage = device.storage_from_cpu_storage_owned(storage)?;
485+
let storage = device.storage_from_cpu_storage(&storage)?;
485486
Ok(Storage::Cuda(storage))
486487
}
487488
Device::Metal(device) => {
488489
let storage = S::to_cpu_storage_owned(data);
489-
let storage = device.storage_from_cpu_storage_owned(storage)?;
490+
let storage = device.storage_from_cpu_storage(&storage)?;
490491
Ok(Storage::Metal(storage))
491492
}
492493
}

candle-core/src/dtype.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Types for elements that can be stored and manipulated using tensors.
22
#![allow(clippy::redundant_closure_call)]
33
use crate::backend::BackendStorage;
4+
use crate::cpu::kernels::VecOps;
45
use crate::{CpuStorage, CpuStorageRef, Error, Result};
56

67
/// The different types of elements allowed in tensors.
@@ -96,6 +97,7 @@ impl DType {
9697
pub fn size_in_bytes(&self) -> usize {
9798
match self {
9899
Self::U8 => 1,
100+
Self::F8E4M3 => 1,
99101
Self::U32 => 4,
100102
Self::I16 => 2,
101103
Self::I32 => 4,
@@ -235,13 +237,14 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
235237
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
236238
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
237239
with_dtype!(f8e4m3, F8E4M3, f8e4m3::from_f64, |v: f8e4m3| v.to_f64());
240+
with_dtype!(f8e4m3, F8E4M3, f8e4m3::from_f64, |v: f8e4m3| v.to_f64());
238241

239242
pub trait IntDType: WithDType + num_traits::Bounded {
240243
fn is_true(&self) -> bool;
241244
fn as_usize(&self) -> usize;
242245
}
243246

244-
impl IntDType for i64 {
247+
impl IntDType for i16 {
245248
fn is_true(&self) -> bool {
246249
*self != 0
247250
}
@@ -250,7 +253,7 @@ impl IntDType for i64 {
250253
}
251254
}
252255

253-
impl IntDType for u32 {
256+
impl IntDType for i32 {
254257
fn is_true(&self) -> bool {
255258
*self != 0
256259
}
@@ -259,7 +262,7 @@ impl IntDType for u32 {
259262
}
260263
}
261264

262-
impl IntDType for u8 {
265+
impl IntDType for i64 {
263266
fn is_true(&self) -> bool {
264267
*self != 0
265268
}
@@ -268,7 +271,7 @@ impl IntDType for u8 {
268271
}
269272
}
270273

271-
impl IntDType for i16 {
274+
impl IntDType for u32 {
272275
fn is_true(&self) -> bool {
273276
*self != 0
274277
}
@@ -277,7 +280,7 @@ impl IntDType for i16 {
277280
}
278281
}
279282

280-
impl IntDType for i32 {
283+
impl IntDType for u8 {
281284
fn is_true(&self) -> bool {
282285
*self != 0
283286
}

candle-core/src/dummy_cuda_backend.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,23 @@ impl crate::backend::BackendStorage for CudaStorage {
175175
Err(Error::NotCompiledWithCudaSupport)
176176
}
177177

178-
fn matmul(
178+
fn matmul_with_alpha_beta(
179179
&self,
180180
_: &Self,
181+
_: &mut Self,
182+
_: Option<f64>,
183+
_: (usize, usize, usize, usize),
184+
_: &Layout,
185+
_: &Layout,
186+
_: &Layout,
187+
) -> Result<()> {
188+
Err(Error::NotCompiledWithCudaSupport)
189+
}
190+
191+
fn matmul_with_alpha(
192+
&self,
193+
_: &Self,
194+
_: Option<f64>,
181195
_: (usize, usize, usize, usize),
182196
_: &Layout,
183197
_: &Layout,

candle-core/src/dummy_metal_backend.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,23 @@ impl crate::backend::BackendStorage for MetalStorage {
168168
Err(Error::NotCompiledWithMetalSupport)
169169
}
170170

171-
fn matmul(
171+
fn matmul_with_alpha_beta(
172172
&self,
173173
_: &Self,
174+
_: &mut Self,
175+
_: Option<f64>,
176+
_: (usize, usize, usize, usize),
177+
_: &Layout,
178+
_: &Layout,
179+
_: &Layout,
180+
) -> Result<()> {
181+
Err(Error::NotCompiledWithMetalSupport)
182+
}
183+
184+
fn matmul_with_alpha(
185+
&self,
186+
_: &Self,
187+
_: Option<f64>,
174188
_: (usize, usize, usize, usize),
175189
_: &Layout,
176190
_: &Layout,

0 commit comments

Comments
 (0)