Skip to content

Commit 3ebfe70

Browse files
authored
Merge pull request #7 from spiceai/jeadie/25-04-15/upstream-spiceai
fix other things too
2 parents 1b02ddb + 8823085 commit 3ebfe70

15 files changed

Lines changed: 44 additions & 41 deletions

File tree

candle-core/src/cpu_backend/mod.rs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,19 +3553,21 @@ impl BackendDevice for CpuDevice {
35533553
}
35543554
DType::BF16 => {
35553555
let mut data = Vec::with_capacity(elem_count);
3556-
let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
3557-
.map_err(Error::wrap)?;
3556+
let normal: rand_distr::Uniform<f32> =
3557+
rand_distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
35583558
for _i in 0..elem_count {
3559-
data.push(rng.sample::<bf16, _>(uniform))
3559+
let sample: f32 = normal.sample(&mut rng);
3560+
data.push(bf16::from_f32(sample));
35603561
}
35613562
Ok(CpuStorage::BF16(data))
35623563
}
35633564
DType::F16 => {
35643565
let mut data = Vec::with_capacity(elem_count);
3565-
let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
3566-
.map_err(Error::wrap)?;
3566+
let normal: rand_distr::Uniform<f32> =
3567+
rand_distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
35673568
for _i in 0..elem_count {
3568-
data.push(rng.sample::<f16, _>(uniform))
3569+
let sample: f32 = normal.sample(&mut rng);
3570+
data.push(f16::from_f32(sample));
35693571
}
35703572
Ok(CpuStorage::F16(data))
35713573
}
@@ -3610,19 +3612,21 @@ impl BackendDevice for CpuDevice {
36103612
}
36113613
DType::BF16 => {
36123614
let mut data = Vec::with_capacity(elem_count);
3613-
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
3614-
.map_err(Error::wrap)?;
3615+
let normal: rand_distr::Normal<f32> =
3616+
rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
36153617
for _i in 0..elem_count {
3616-
data.push(normal.sample(&mut rng))
3618+
let sample: f32 = normal.sample(&mut rng);
3619+
data.push(bf16::from_f32(sample));
36173620
}
36183621
Ok(CpuStorage::BF16(data))
36193622
}
36203623
DType::F16 => {
36213624
let mut data = Vec::with_capacity(elem_count);
3622-
let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
3623-
.map_err(Error::wrap)?;
3625+
let normal: rand_distr::Normal<f32> =
3626+
rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
36243627
for _i in 0..elem_count {
3625-
data.push(normal.sample(&mut rng))
3628+
let sample: f32 = normal.sample(&mut rng);
3629+
data.push(f16::from_f32(sample));
36263630
}
36273631
Ok(CpuStorage::F16(data))
36283632
}

candle-core/src/pickle.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ impl PthTensors {
792792
/// # Arguments
793793
/// * `path` - Path to the pth file.
794794
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
795-
/// contains multiple objects and the `state_dict` is the one we are interested in.
795+
/// contains multiple objects and the `state_dict` is the one we are interested in.
796796
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
797797
path: P,
798798
key: Option<&str>,

candle-core/tests/quantized_tests.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use candle_core::{
66
DType, Device, IndexOp, Module, Result, Tensor, Var,
77
};
88
use quantized::{k_quants, GgmlType};
9-
use rand::prelude::*;
9+
use rand::{prelude::*, random};
1010

1111
const GGML_TEST_SIZE: usize = 32 * 128;
1212

@@ -1110,13 +1110,11 @@ fn get_random_tensors(
11101110
n: usize,
11111111
device: &Device,
11121112
) -> Result<(Tensor, Tensor, Tensor)> {
1113-
let mut rng = StdRng::seed_from_u64(314159265358979);
1114-
11151113
let lhs = (0..m * k)
1116-
.map(|_| rng.gen::<f32>() - 0.5)
1114+
.map(|_| random::<f32>() - 0.5)
11171115
.collect::<Vec<_>>();
11181116
let rhs = (0..n * k)
1119-
.map(|_| rng.gen::<f32>() - 0.5)
1117+
.map(|_| random::<f32>() - 0.5)
11201118
.collect::<Vec<_>>();
11211119

11221120
let lhs = Tensor::from_vec(lhs, (m, k), device)?;

candle-examples/examples/mamba-minimal/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ impl Config {
2121
}
2222

2323
fn dt_rank(&self) -> usize {
24-
(self.d_model + 15) / 16
24+
self.d_model.div_ceil(16)
2525
}
2626

2727
fn d_conv(&self) -> usize {

candle-examples/examples/metavoice/main.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ extern crate accelerate_src;
66

77
use anyhow::Result;
88
use clap::Parser;
9+
use rand::distr::Distribution;
910
use std::io::Write;
1011

1112
use candle_transformers::generation::LogitsProcessor;
@@ -16,7 +17,7 @@ use candle_transformers::models::quantized_metavoice::transformer as qtransforme
1617
use candle::{DType, IndexOp, Tensor};
1718
use candle_nn::VarBuilder;
1819
use hf_hub::api::sync::Api;
19-
use rand::{distributions::Distribution, SeedableRng};
20+
use rand::SeedableRng;
2021

2122
pub const ENCODEC_NTOKENS: u32 = 1024;
2223

@@ -250,7 +251,7 @@ fn main() -> Result<()> {
250251
let logits = logits.i(step)?.to_dtype(DType::F32)?;
251252
let logits = &(&logits / 1.0)?;
252253
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
253-
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
254+
let distr = rand::distr::weighted::WeightedIndex::new(prs.as_slice())?;
254255
let sample = distr.sample(&mut rng) as u32;
255256
codes_.push(sample)
256257
}

candle-nn/src/loss.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use candle::{Result, Tensor};
77
/// Arguments
88
///
99
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
10-
/// of categories. This is expected to contain log probabilities.
10+
/// of categories. This is expected to contain log probabilities.
1111
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
1212
///
1313
/// The resulting tensor is a scalar containing the average value over the batch.
@@ -34,7 +34,7 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
3434
/// Arguments
3535
///
3636
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
37-
/// of categories. This is expected to raw logits.
37+
/// of categories. This is expected to raw logits.
3838
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
3939
///
4040
/// The resulting tensor is a scalar containing the average value over the batch.
@@ -56,9 +56,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
5656
/// Arguments
5757
///
5858
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
59-
/// of categories. This is expected to raw logits.
59+
/// of categories. This is expected to raw logits.
6060
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
61-
/// of categories.
61+
/// of categories.
6262
///
6363
/// The resulting tensor is a scalar containing the average value over the batch.
6464
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {

candle-nn/src/var_builder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ impl SimpleBackend for candle::npy::NpzTensors {
423423
}
424424

425425
fn contains_tensor(&self, name: &str) -> bool {
426-
self.get(name).map_or(false, |v| v.is_some())
426+
self.get(name).is_ok_and(|v| v.is_some())
427427
}
428428
}
429429

@@ -461,7 +461,7 @@ impl SimpleBackend for candle::pickle::PthTensors {
461461
}
462462

463463
fn contains_tensor(&self, name: &str) -> bool {
464-
self.get(name).map_or(false, |v| v.is_some())
464+
self.get(name).is_ok_and(|v| v.is_some())
465465
}
466466
}
467467

candle-transformers/src/generation/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use candle::{DType, Error, Result, Tensor};
2-
use rand::{distributions::Distribution, SeedableRng};
2+
use rand::{distr::Distribution, SeedableRng};
33

44
#[derive(Clone, PartialEq, Debug)]
55
pub enum Sampling {
@@ -45,7 +45,7 @@ impl LogitsProcessor {
4545
}
4646

4747
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
48-
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
48+
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
4949
let next_token = distr.sample(&mut self.rng) as u32;
5050
Ok(next_token)
5151
}

candle-transformers/src/models/dac.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ impl EncoderBlock {
9999
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
100100
let cfg1 = Conv1dConfig {
101101
stride,
102-
padding: (stride + 1) / 2,
102+
padding: stride.div_ceil(2),
103103
..Default::default()
104104
};
105105
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
@@ -191,7 +191,7 @@ impl DecoderBlock {
191191
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
192192
let cfg = ConvTranspose1dConfig {
193193
stride,
194-
padding: (stride + 1) / 2,
194+
padding: stride.div_ceil(2),
195195
..Default::default()
196196
};
197197
let conv_tr1 = encodec::conv_transpose1d_weight_norm(

candle-transformers/src/models/flux/sampling.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ pub fn get_noise(
66
width: usize,
77
device: &Device,
88
) -> Result<Tensor> {
9-
let height = (height + 15) / 16 * 2;
10-
let width = (width + 15) / 16 * 2;
9+
let height = height.div_ceil(16) * 2;
10+
let width = width.div_ceil(16) * 2;
1111
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
1212
}
1313

@@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
8484

8585
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
8686
let (b, _h_w, c_ph_pw) = xs.dims3()?;
87-
let height = (height + 15) / 16;
88-
let width = (width + 15) / 16;
87+
let height = height.div_ceil(16);
88+
let width = width.div_ceil(16);
8989
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
9090
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
9191
.reshape((b, c_ph_pw / 4, height * 2, width * 2))

0 commit comments

Comments
 (0)