Skip to content

Commit 81e8f39

Browse files
committed
fix other things too
1 parent 93af275 commit 81e8f39

4 files changed

Lines changed: 21 additions & 17 deletions

File tree

  • candle-core/src/cpu_backend
  • candle-examples/examples/metavoice
  • candle-transformers/src/generation
  • candle-wasm-examples/whisper/src

candle-core/src/cpu_backend/mod.rs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,19 +3553,21 @@ impl BackendDevice for CpuDevice {
35533553
}
35543554
DType::BF16 => {
35553555
let mut data = Vec::with_capacity(elem_count);
3556-
let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
3557-
.map_err(Error::wrap)?;
3556+
let normal: rand_distr::Uniform<f32> =
3557+
rand_distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
35583558
for _i in 0..elem_count {
3559-
data.push(rng.sample::<bf16, _>(uniform))
3559+
let sample: f32 = normal.sample(&mut rng);
3560+
data.push(bf16::from_f32(sample));
35603561
}
35613562
Ok(CpuStorage::BF16(data))
35623563
}
35633564
DType::F16 => {
35643565
let mut data = Vec::with_capacity(elem_count);
3565-
let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
3566-
.map_err(Error::wrap)?;
3566+
let normal: rand_distr::Uniform<f32> =
3567+
rand_distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
35673568
for _i in 0..elem_count {
3568-
data.push(rng.sample::<f16, _>(uniform))
3569+
let sample: f32 = normal.sample(&mut rng);
3570+
data.push(f16::from_f32(sample));
35693571
}
35703572
Ok(CpuStorage::F16(data))
35713573
}
@@ -3610,19 +3612,21 @@ impl BackendDevice for CpuDevice {
36103612
}
36113613
DType::BF16 => {
36123614
let mut data = Vec::with_capacity(elem_count);
3613-
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
3614-
.map_err(Error::wrap)?;
3615+
let normal: rand_distr::Normal<f32> =
3616+
rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
36153617
for _i in 0..elem_count {
3616-
data.push(normal.sample(&mut rng))
3618+
let sample: f32 = normal.sample(&mut rng);
3619+
data.push(bf16::from_f32(sample));
36173620
}
36183621
Ok(CpuStorage::BF16(data))
36193622
}
36203623
DType::F16 => {
36213624
let mut data = Vec::with_capacity(elem_count);
3622-
let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
3623-
.map_err(Error::wrap)?;
3625+
let normal: rand_distr::Normal<f32> =
3626+
rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
36243627
for _i in 0..elem_count {
3625-
data.push(normal.sample(&mut rng))
3628+
let sample: f32 = normal.sample(&mut rng);
3629+
data.push(f16::from_f32(sample));
36263630
}
36273631
Ok(CpuStorage::F16(data))
36283632
}

candle-examples/examples/metavoice/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ fn main() -> Result<()> {
250250
let logits = logits.i(step)?.to_dtype(DType::F32)?;
251251
let logits = &(&logits / 1.0)?;
252252
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
253-
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
253+
let distr = rand::distr::weighted::WeightedIndex::new(prs.as_slice())?;
254254
let sample = distr.sample(&mut rng) as u32;
255255
codes_.push(sample)
256256
}

candle-transformers/src/generation/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use candle::{DType, Error, Result, Tensor};
2-
use rand::{distributions::Distribution, SeedableRng};
2+
use rand::{distr::Distribution, SeedableRng};
33

44
#[derive(Clone, PartialEq, Debug)]
55
pub enum Sampling {
@@ -45,7 +45,7 @@ impl LogitsProcessor {
4545
}
4646

4747
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
48-
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
48+
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
4949
let next_token = distr.sample(&mut self.rng) as u32;
5050
Ok(next_token)
5151
}

candle-wasm-examples/whisper/src/worker.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use anyhow::Error as E;
33
use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
44
use candle_nn::{ops::softmax, VarBuilder};
55
pub use candle_transformers::models::whisper::{self as m, Config};
6-
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
6+
use rand::{distr::Distribution, rngs::StdRng, SeedableRng};
77
use serde::{Deserialize, Serialize};
88
use tokenizers::Tokenizer;
99
use wasm_bindgen::prelude::*;
@@ -221,7 +221,7 @@ impl Decoder {
221221
let next_token = if t > 0f64 {
222222
let prs = softmax(&(&logits / t)?, 0)?;
223223
let logits_v: Vec<f32> = prs.to_vec1()?;
224-
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
224+
let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
225225
distr.sample(&mut self.rng) as u32
226226
} else {
227227
let logits_v: Vec<f32> = logits.to_vec1()?;

0 commit comments

Comments
 (0)