File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -258,7 +258,7 @@ pub fn batch_norm2d(
258258 let storage = Storage :: new ( output_data) ;
259259 let mut tensor = Tensor :: new_with_storage ( storage, shape) ;
260260
261- if training && ( input. requires_grad ( ) || gamma. map_or ( false , |g| g. requires_grad ( ) ) || beta. is_some_and ( |b| b. requires_grad ( ) ) ) {
261+ if training && ( input. requires_grad ( ) || gamma. is_some_and ( |g| g. requires_grad ( ) ) || beta. is_some_and ( |b| b. requires_grad ( ) ) ) {
262262 tensor. set_requires_grad_mut ( true ) ;
263263 // Need to pass mean/inv_std to backward
264264 // But mean/inv_std are Vec<f32>.
@@ -491,7 +491,7 @@ pub fn layer_norm(
491491 let storage = Storage :: new ( output_data) ;
492492 let mut tensor = Tensor :: new_with_storage ( storage, shape) ;
493493
494- if input. requires_grad ( ) || weight. map_or ( false , |w| w. requires_grad ( ) ) || bias. is_some_and ( |b| b. requires_grad ( ) ) {
494+ if input. requires_grad ( ) || weight. is_some_and ( |w| w. requires_grad ( ) ) || bias. is_some_and ( |b| b. requires_grad ( ) ) {
495495 tensor. set_requires_grad_mut ( true ) ;
496496 // Store mean/inv_std for backward
497497 // They are (OuterDim). We can store as (OuterDim) tensor.
Original file line number Diff line number Diff line change 11use std:: sync:: { Arc , Mutex , RwLockReadGuard , RwLockWriteGuard } ;
22use std:: fmt;
33use std:: ops:: { Add , Mul , Sub } ;
4- use rand:: Rng ;
4+ // use rand::Rng;
55use rand_distr:: { Normal , Uniform , Distribution } ;
66// use rayon::prelude::*;
77// use rayon::iter::{IntoParallelRefIterator, ParallelIterator, IndexedParallelIterator};
You can’t perform that action at this time.
0 commit comments