@@ -3553,19 +3553,21 @@ impl BackendDevice for CpuDevice {
35533553 }
35543554 DType :: BF16 => {
35553555 let mut data = Vec :: with_capacity ( elem_count) ;
3556- let uniform = rand :: distr :: Uniform :: new ( bf16 :: from_f64 ( min ) , bf16 :: from_f64 ( max ) )
3557- . map_err ( Error :: wrap) ?;
3556+ let normal : rand_distr :: Uniform < f32 > =
3557+ rand_distr :: Uniform :: new ( min as f32 , max as f32 ) . map_err ( Error :: wrap) ?;
35583558 for _i in 0 ..elem_count {
3559- data. push ( rng. sample :: < bf16 , _ > ( uniform) )
3559+ let sample: f32 = normal. sample ( & mut rng) ;
3560+ data. push ( bf16:: from_f32 ( sample) ) ;
35603561 }
35613562 Ok ( CpuStorage :: BF16 ( data) )
35623563 }
35633564 DType :: F16 => {
35643565 let mut data = Vec :: with_capacity ( elem_count) ;
3565- let uniform = rand :: distr :: Uniform :: new ( f16 :: from_f64 ( min ) , f16 :: from_f64 ( max ) )
3566- . map_err ( Error :: wrap) ?;
3566+ let normal : rand_distr :: Uniform < f32 > =
3567+ rand_distr :: Uniform :: new ( min as f32 , max as f32 ) . map_err ( Error :: wrap) ?;
35673568 for _i in 0 ..elem_count {
3568- data. push ( rng. sample :: < f16 , _ > ( uniform) )
3569+ let sample: f32 = normal. sample ( & mut rng) ;
3570+ data. push ( f16:: from_f32 ( sample) ) ;
35693571 }
35703572 Ok ( CpuStorage :: F16 ( data) )
35713573 }
@@ -3610,19 +3612,21 @@ impl BackendDevice for CpuDevice {
36103612 }
36113613 DType :: BF16 => {
36123614 let mut data = Vec :: with_capacity ( elem_count) ;
3613- let normal = rand_distr:: Normal :: new ( bf16 :: from_f64 ( mean ) , bf16 :: from_f64 ( std ) )
3614- . map_err ( Error :: wrap) ?;
3615+ let normal: rand_distr:: Normal < f32 > =
3616+ rand_distr :: Normal :: new ( mean as f32 , std as f32 ) . map_err ( Error :: wrap) ?;
36153617 for _i in 0 ..elem_count {
3616- data. push ( normal. sample ( & mut rng) )
3618+ let sample: f32 = normal. sample ( & mut rng) ;
3619+ data. push ( bf16:: from_f32 ( sample) ) ;
36173620 }
36183621 Ok ( CpuStorage :: BF16 ( data) )
36193622 }
36203623 DType :: F16 => {
36213624 let mut data = Vec :: with_capacity ( elem_count) ;
3622- let normal = rand_distr:: Normal :: new ( f16 :: from_f64 ( mean ) , f16 :: from_f64 ( std ) )
3623- . map_err ( Error :: wrap) ?;
3625+ let normal: rand_distr:: Normal < f32 > =
3626+ rand_distr :: Normal :: new ( mean as f32 , std as f32 ) . map_err ( Error :: wrap) ?;
36243627 for _i in 0 ..elem_count {
3625- data. push ( normal. sample ( & mut rng) )
3628+ let sample: f32 = normal. sample ( & mut rng) ;
3629+ data. push ( f16:: from_f32 ( sample) ) ;
36263630 }
36273631 Ok ( CpuStorage :: F16 ( data) )
36283632 }
0 commit comments