Skip to content

Commit beeadf9

Browse files
committed
Formatting
1 parent 4d42f63 commit beeadf9

10 files changed

Lines changed: 146 additions & 98 deletions

File tree

candle-core/src/cpu_backend/mod.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2607,7 +2607,7 @@ impl BackendStorage for CpuStorage {
26072607
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
26082608
.transpose(1, 2)?
26092609
.broadcast_as((b, k, n))?;
2610-
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2610+
col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)?
26112611
} else {
26122612
// Make the kernel contiguous if not already the case.
26132613
let mut kernel_c = unsafe {
@@ -2618,7 +2618,7 @@ impl BackendStorage for CpuStorage {
26182618
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
26192619
.transpose(1, 2)?
26202620
.broadcast_as((b, k, n))?;
2621-
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2621+
col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)?
26222622
};
26232623
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
26242624
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
@@ -2659,8 +2659,9 @@ impl BackendStorage for CpuStorage {
26592659
vec![0, k_size * c_out, 1],
26602660
kernel_l.start_offset(),
26612661
);
2662-
self.matmul(
2662+
self.matmul_with_alpha(
26632663
kernel,
2664+
None,
26642665
(
26652666
b_size,
26662667
/* m */ l_in,
@@ -3144,11 +3145,6 @@ impl BackendDevice for CpuDevice {
31443145
Ok(storage)
31453146
}
31463147

3147-
fn get_current_seed(&self) -> Result<u64> {
3148-
// CPU backend doesn't maintain a seed state
3149-
Ok(0)
3150-
}
3151-
31523148
fn synchronize(&self) -> Result<()> {
31533149
Ok(())
31543150
}

candle-core/src/storage.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,7 @@ impl Storage {
738738
}
739739
}
740740

741+
#[allow(dead_code)]
741742
#[allow(clippy::too_many_arguments)]
742743
pub(crate) fn matmul_with_alpha_beta(
743744
&self,

candle-core/src/tensor.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,6 +1425,122 @@ impl Tensor {
14251425
Ok(from_storage(storage, c_shape, op, false))
14261426
}
14271427

1428+
/// Matrix-multiplication with a scalar multiplier (alpha).
1429+
///
1430+
/// Computes `alpha * (self @ rhs)` where `@` represents matrix multiplication.
1431+
/// If `alpha` is `None`, it defaults to 1.0.
1432+
///
1433+
/// # Arguments
1434+
///
1435+
/// * `rhs` - The right-hand side matrix.
1436+
/// * `alpha` - Optional scalar multiplier applied to the result.
1437+
pub fn matmul_with_alpha(&self, rhs: &Self, alpha: Option<f64>) -> Result<Self> {
1438+
let a_dims = self.shape().dims();
1439+
let b_dims = rhs.shape().dims();
1440+
1441+
let dim = a_dims.len();
1442+
1443+
if dim < 2 || b_dims.len() != dim {
1444+
Err(Error::ShapeMismatchBinaryOp {
1445+
lhs: self.shape().clone(),
1446+
rhs: rhs.shape().clone(),
1447+
op: "matmul_with_alpha",
1448+
}
1449+
.bt())?
1450+
}
1451+
1452+
let m = a_dims[dim - 2];
1453+
let k = a_dims[dim - 1];
1454+
let k2 = b_dims[dim - 2];
1455+
let n = b_dims[dim - 1];
1456+
1457+
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1458+
if c_shape.elem_count() == 0 || k == 0 {
1459+
return Tensor::zeros(c_shape, self.dtype(), self.device());
1460+
}
1461+
let batching: usize = a_dims[..dim - 2].iter().product();
1462+
let batching_b: usize = b_dims[..dim - 2].iter().product();
1463+
if k != k2 || batching != batching_b {
1464+
Err(Error::ShapeMismatchBinaryOp {
1465+
lhs: self.shape().clone(),
1466+
rhs: rhs.shape().clone(),
1467+
op: "matmul_with_alpha",
1468+
}
1469+
.bt())?
1470+
}
1471+
1472+
let storage = self.storage().matmul_with_alpha(
1473+
&rhs.storage(),
1474+
alpha,
1475+
(batching, m, n, k),
1476+
self.layout(),
1477+
rhs.layout(),
1478+
)?;
1479+
// Note: gradients are not tracked for the alpha-scaled matmul yet, so no backprop op is recorded.
1480+
let op = BackpropOp::none();
1481+
Ok(from_storage(storage, c_shape, op, false))
1482+
}
1483+
1484+
/// Matrix-multiplication with alpha and beta scaling, using a mutable output tensor.
1485+
///
1486+
/// Computes `c = alpha * (self @ rhs) + beta * c` where `@` represents matrix multiplication.
1487+
/// This is an in-place operation that modifies `c`.
1488+
/// If `alpha` is `None`, it defaults to 1.0. There is no `beta` argument: `beta` is fixed at 1.0, so the scaled product is added to the existing contents of `c`.
1489+
///
1490+
/// # Arguments
1491+
///
1492+
/// * `rhs` - The right-hand side matrix.
1493+
/// * `c` - The mutable output tensor that will be modified in-place.
1494+
/// * `alpha` - Optional scalar multiplier applied to the matmul result.
1495+
pub fn matmul_with_alpha_beta(
1496+
&self,
1497+
rhs: &Self,
1498+
c: &mut Self,
1499+
alpha: Option<f64>,
1500+
) -> Result<()> {
1501+
let a_dims = self.shape().dims();
1502+
let b_dims = rhs.shape().dims();
1503+
let c_dims = c.shape().dims();
1504+
1505+
let dim = a_dims.len();
1506+
1507+
if dim < 2 || b_dims.len() != dim || c_dims.len() != dim {
1508+
Err(Error::ShapeMismatchBinaryOp {
1509+
lhs: self.shape().clone(),
1510+
rhs: rhs.shape().clone(),
1511+
op: "matmul_with_alpha_beta",
1512+
}
1513+
.bt())?
1514+
}
1515+
1516+
let m = a_dims[dim - 2];
1517+
let k = a_dims[dim - 1];
1518+
let k2 = b_dims[dim - 2];
1519+
let n = b_dims[dim - 1];
1520+
1521+
let batching: usize = a_dims[..dim - 2].iter().product();
1522+
let batching_b: usize = b_dims[..dim - 2].iter().product();
1523+
if k != k2 || batching != batching_b {
1524+
Err(Error::ShapeMismatchBinaryOp {
1525+
lhs: self.shape().clone(),
1526+
rhs: rhs.shape().clone(),
1527+
op: "matmul_with_alpha_beta",
1528+
}
1529+
.bt())?
1530+
}
1531+
1532+
self.storage().matmul_with_alpha_beta(
1533+
&rhs.storage(),
1534+
&mut c.storage_mut(),
1535+
alpha,
1536+
(batching, m, n, k),
1537+
self.layout(),
1538+
rhs.layout(),
1539+
c.layout(),
1540+
)?;
1541+
Ok(())
1542+
}
1543+
14281544
/// Matrix-multiplication with broadcasting support.
14291545
///
14301546
/// Compared to `matmul` the two matrices are allowed to have different dimensions as long as

candle-core/src/tensor_indexing.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
op::{BackpropOp, Op},
66
shape::Dim,
77
tensor::from_storage,
8-
DType, Error, Result, Tensor,
8+
DType, Error, Layout, Result, Tensor,
99
};
1010

1111
/// Specialization of `std::ops::RangeBounds` for `usize` to allow trait objects.
@@ -171,8 +171,13 @@ impl Tensor {
171171
}
172172
.bt())?
173173
}
174-
let storage = self.storage().scatter_add(
175-
self.layout(),
174+
let shape = self.shape();
175+
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
176+
self.storage()
177+
.copy_strided_src(&mut storage, 0, self.layout())?;
178+
let layout = Layout::contiguous(shape);
179+
storage.scatter_add(
180+
&layout,
176181
&indexes.storage(),
177182
indexes.layout(),
178183
&source.storage(),

candle-nn/benches/benchmarks/attention.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
22
use candle::{DType, Device, Tensor};
33
use candle_nn::scaled_dot_product_attention;
4-
use criterion::{black_box, criterion_group, Criterion, Throughput};
4+
use criterion::{criterion_group, Criterion, Throughput};
5+
use std::hint::black_box;
56
use std::time::Instant;
67

78
fn run_attention(q: &Tensor, k: &Tensor, v: &Tensor, m: &Tensor, s: f64) {

candle-nn/src/layer_norm.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,13 @@ pub fn rms_norm_quant(
351351
_ghost: PhantomData,
352352
})
353353
}
354+
355+
/// Create an RmsNorm layer (non-quantized version).
356+
/// This is an alias for `rms_norm_non_quant` for convenience.
357+
pub fn rms_norm(
358+
size: usize,
359+
eps: f64,
360+
vb: crate::VarBuilder,
361+
) -> Result<RmsNorm<RmsNormNonQuantized>> {
362+
rms_norm_non_quant(size, eps, vb)
363+
}

candle-nn/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ pub use func::{func, func_t, Func, FuncT};
5252
pub use group_norm::{group_norm, GroupNorm};
5353
pub use init::Init;
5454
pub use layer_norm::{
55-
layer_norm, layer_norm_no_bias, rms_norm_non_quant, rms_norm_quant, LayerNorm, LayerNormConfig,
56-
RmsNorm,
55+
layer_norm, layer_norm_no_bias, rms_norm, rms_norm_non_quant, rms_norm_quant, LayerNorm,
56+
LayerNormConfig, RmsNorm,
5757
};
5858
pub use linear::{linear, linear_b, linear_no_bias, Linear};
59-
pub use ops::{kvconcat, Dropout};
59+
pub use ops::Dropout;
6060
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
6161
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
6262
pub use rope::RotaryEmbedding;

candle-nn/src/var_builder.rs

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -317,32 +317,6 @@ impl SimpleBackend for HashMap<String, Tensor> {
317317
tensor.to_device(dev)?.to_dtype(dtype)
318318
}
319319

320-
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
321-
let tensor = self
322-
.get(name)
323-
.ok_or_else(|| {
324-
Error::CannotFindTensor {
325-
path: name.to_string(),
326-
}
327-
.bt()
328-
})?
329-
.clone();
330-
tensor.to_device(dev)?.to_dtype(dtype)
331-
}
332-
333-
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
334-
let tensor = self
335-
.get(name)
336-
.ok_or_else(|| {
337-
Error::CannotFindTensor {
338-
path: name.to_string(),
339-
}
340-
.bt()
341-
})?
342-
.clone();
343-
tensor.to_device(dev)?.to_dtype(dtype)
344-
}
345-
346320
fn contains_tensor(&self, name: &str) -> bool {
347321
self.contains_key(name)
348322
}

candle-nn/tests/ops.rs

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ extern crate intel_mkl_src;
55
extern crate accelerate_src;
66

77
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
8-
use candle_nn::Activation;
98

109
fn softmax(device: &Device) -> Result<()> {
1110
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@@ -53,22 +52,6 @@ fn softmax(device: &Device) -> Result<()> {
5352
Ok(())
5453
}
5554

56-
fn inplace_softmax(device: &Device) -> Result<()> {
57-
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
58-
let mut tensor = Tensor::new(data, device)?.log()?;
59-
candle_nn::ops::inplace_softmax_last_dim(&mut tensor)?;
60-
assert_eq!(
61-
to_vec3_round(&tensor, 4)?,
62-
&[
63-
// (3, 1, 4) / 8, (1, 5, 9) / 15
64-
[[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
65-
// (2, 1, 7) / 10, (8, 2, 8) / 18
66-
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
67-
]
68-
);
69-
Ok(())
70-
}
71-
7255
fn rms_norm(device: &Device) -> Result<()> {
7356
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
7457
let tensor = Tensor::new(data, device)?;
@@ -341,44 +324,12 @@ fn sigmoid(device: &Device) -> Result<()> {
341324
Ok(())
342325
}
343326

344-
fn mul_and_act(device: &Device) -> Result<()> {
345-
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
346-
let cpu = Tensor::new(data, &Device::Cpu)?;
347-
let x = Tensor::new(data, device)?;
348-
349-
for act in [Activation::Gelu, Activation::Relu, Activation::Silu] {
350-
let truth = candle_nn::ops::mul_and_act(&cpu, &cpu, act)?;
351-
let test = candle_nn::ops::mul_and_act(&x, &x, act)?.to_device(&Device::Cpu)?;
352-
353-
let sum_diff = (truth - test)?.abs()?.sum_all()?.to_vec0::<f32>()?;
354-
if device.is_cpu() {
355-
assert_eq!(sum_diff, 0., "act = {act:?}");
356-
} else {
357-
assert!(sum_diff < 3e-3, "act = {act:?}");
358-
}
359-
}
360-
361-
Ok(())
362-
}
363-
364327
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
365328
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
366329
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
367330
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
368-
test_device!(
369-
inplace_softmax,
370-
inplace_softmax_cpu,
371-
inplace_softmax_gpu,
372-
inplace_softmax_metal
373-
);
374331
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
375332
test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
376333
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
377334
test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal);
378335
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
379-
test_device!(
380-
mul_and_act,
381-
mul_and_act_cpu,
382-
mul_and_act_gpu,
383-
mul_and_act_metal
384-
);

candle-pyo3/src/lib.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ extern crate intel_mkl_src;
1717
#[cfg(feature = "accelerate")]
1818
extern crate accelerate_src;
1919

20-
use candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
20+
use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
2121

2222
mod utils;
2323
use utils::wrap_err;
@@ -217,12 +217,6 @@ trait MapDType {
217217
DType::F16 => self.f::<f16>(t),
218218
DType::F32 => self.f::<f32>(t),
219219
DType::F64 => self.f::<f64>(t),
220-
DType::I16 => Err(PyErr::new::<PyTypeError, _>(
221-
"i16 dtype is not supported in Python interface",
222-
)),
223-
DType::I32 => Err(PyErr::new::<PyTypeError, _>(
224-
"i32 dtype is not supported in Python interface",
225-
)),
226220
DType::F8E4M3 => Err(PyErr::new::<PyTypeError, _>(
227221
"f8e4m3 dtype is not supported in Python interface",
228222
)),
@@ -1104,7 +1098,7 @@ impl PyTensor {
11041098
/// Quantize the tensor.
11051099
/// &RETURNS&: QTensor
11061100
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
1107-
use candle::quantized;
1101+
use ::candle::quantized;
11081102
let res = match quantized_dtype.to_lowercase().as_str() {
11091103
"q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
11101104
"q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),

0 commit comments

Comments
 (0)