@@ -162,7 +162,10 @@ fn normalize_each_channel<'a>(
162
162
chan_opts : impl Fn ( usize ) -> NormalizeOptions < ' a > ,
163
163
) {
164
164
let batch = input. size ( 0 ) ;
165
- let chans = input. size ( 1 ) ;
165
+
166
+ // Per BatchNormalization spec: "The op also accepts single dimension input
167
+ // of size N in which case C is assumed to be 1"
168
+ let chans = if input. ndim ( ) >= 2 { input. size ( 1 ) } else { 1 } ;
166
169
167
170
input. make_contiguous ( ) ;
168
171
let chunk_len = input. len ( ) / ( batch * chans) ;
@@ -188,8 +191,8 @@ pub fn batch_norm_in_place(
188
191
var : & NdTensorView < f32 , 1 > ,
189
192
epsilon : f32 ,
190
193
) -> Result < ( ) , OpError > {
191
- if input. ndim ( ) < 3 {
192
- return Err ( OpError :: InvalidValue ( "Input must have at least 3 dims " ) ) ;
194
+ if input. ndim ( ) < 1 {
195
+ return Err ( OpError :: InvalidValue ( "Input must have at least 1 dim " ) ) ;
193
196
}
194
197
195
198
normalize_each_channel ( input, |chan| NormalizeOptions {
@@ -729,6 +732,14 @@ mod tests {
729
732
Case {
730
733
input : Tensor :: from_data ( & [ 1 , 2 , 1 ] , vec ! [ 1.0 , 2.0 ] ) ,
731
734
} ,
735
+ // 2D input
736
+ Case {
737
+ input : Tensor :: from_data ( & [ 1 , 2 ] , vec ! [ 1.0 , 2.0 ] ) ,
738
+ } ,
739
+ // 1D input. Channel count is implicitly 1.
740
+ Case {
741
+ input : Tensor :: from ( [ 1.0 , 2.0 ] ) ,
742
+ } ,
732
743
] ;
733
744
734
745
cases. test_each ( |Case { input } | {
@@ -737,14 +748,17 @@ mod tests {
737
748
let bias = & [ 0.1 , 0.2 ] ;
738
749
let mean = & [ 0.5 , -0.5 ] ;
739
750
let var = & [ 1.0 , 2.0 ] ;
740
-
741
751
let epsilon = 1e-5 as f32 ;
742
752
743
- let flattened = input. reshaped ( [ input. len ( ) ] ) ;
753
+ let expected = if input. ndim ( ) >= 2 {
754
+ let flattened = input. reshaped ( [ input. len ( ) ] ) ;
755
+ let y1 = ( flattened[ 0 ] - mean[ 0 ] ) / ( var[ 0 ] + epsilon) . sqrt ( ) * scale[ 0 ] + bias[ 0 ] ;
756
+ let y2 = ( flattened[ 1 ] - mean[ 1 ] ) / ( var[ 1 ] + epsilon) . sqrt ( ) * scale[ 1 ] + bias[ 1 ] ;
757
+ Tensor :: from_data ( input. shape ( ) , vec ! [ y1, y2] )
758
+ } else {
759
+ input. map ( |& x| ( x - mean[ 0 ] ) / ( var[ 0 ] + epsilon) . sqrt ( ) * scale[ 0 ] + bias[ 0 ] )
760
+ } ;
744
761
745
- let y1 = ( flattened[ 0 ] - mean[ 0 ] ) / ( var[ 0 ] + epsilon) . sqrt ( ) * scale[ 0 ] + bias[ 0 ] ;
746
- let y2 = ( flattened[ 1 ] - mean[ 1 ] ) / ( var[ 1 ] + epsilon) . sqrt ( ) * scale[ 1 ] + bias[ 1 ] ;
747
- let expected = Tensor :: from_data ( input. shape ( ) , vec ! [ y1, y2] ) ;
748
762
let result = batch_norm (
749
763
& pool,
750
764
input. view ( ) ,
@@ -767,7 +781,7 @@ mod tests {
767
781
let mean = & [ 0.5 , -0.5 ] ;
768
782
let var = & [ 1.0 , 2.0 ] ;
769
783
let epsilon = 1e-5 as f32 ;
770
- let input = Tensor :: zeros ( & [ 2 ] ) ;
784
+ let input = Tensor :: from ( 5.0 ) ;
771
785
772
786
let pool = new_pool ( ) ;
773
787
let result = batch_norm (
@@ -782,7 +796,7 @@ mod tests {
782
796
783
797
assert_eq ! (
784
798
result,
785
- Err ( OpError :: InvalidValue ( "Input must have at least 3 dims " ) )
799
+ Err ( OpError :: InvalidValue ( "Input must have at least 1 dim " ) )
786
800
) ;
787
801
}
788
802
0 commit comments