Skip to content

Commit 2b3806b

Browse files
committed
Accept 1D and 2D inputs in BatchNormalization
Per the ONNX spec, BatchNormalization should accept inputs of rank 1 and up. The model at https://huggingface.co/timm/levit_128s.fb_dist_in1k, for example, applies batch norm to 2D inputs. Normalization of 1D and 2D inputs is currently very inefficient, due to the overhead of applying slice-oriented kernels to single-element slices, but at least it works now.
1 parent d4b2f19 commit 2b3806b

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

src/ops/norm.rs

+24-10
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ fn normalize_each_channel<'a>(
162162
chan_opts: impl Fn(usize) -> NormalizeOptions<'a>,
163163
) {
164164
let batch = input.size(0);
165-
let chans = input.size(1);
165+
166+
// Per BatchNormalization spec: "The op also accepts single dimension input
167+
// of size N in which case C is assumed to be 1"
168+
let chans = if input.ndim() >= 2 { input.size(1) } else { 1 };
166169

167170
input.make_contiguous();
168171
let chunk_len = input.len() / (batch * chans);
@@ -188,8 +191,8 @@ pub fn batch_norm_in_place(
188191
var: &NdTensorView<f32, 1>,
189192
epsilon: f32,
190193
) -> Result<(), OpError> {
191-
if input.ndim() < 3 {
192-
return Err(OpError::InvalidValue("Input must have at least 3 dims"));
194+
if input.ndim() < 1 {
195+
return Err(OpError::InvalidValue("Input must have at least 1 dim"));
193196
}
194197

195198
normalize_each_channel(input, |chan| NormalizeOptions {
@@ -729,6 +732,14 @@ mod tests {
729732
Case {
730733
input: Tensor::from_data(&[1, 2, 1], vec![1.0, 2.0]),
731734
},
735+
// 2D input
736+
Case {
737+
input: Tensor::from_data(&[1, 2], vec![1.0, 2.0]),
738+
},
739+
// 1D input. Channel count is implicitly 1.
740+
Case {
741+
input: Tensor::from([1.0, 2.0]),
742+
},
732743
];
733744

734745
cases.test_each(|Case { input }| {
@@ -737,14 +748,17 @@ mod tests {
737748
let bias = &[0.1, 0.2];
738749
let mean = &[0.5, -0.5];
739750
let var = &[1.0, 2.0];
740-
741751
let epsilon = 1e-5 as f32;
742752

743-
let flattened = input.reshaped([input.len()]);
753+
let expected = if input.ndim() >= 2 {
754+
let flattened = input.reshaped([input.len()]);
755+
let y1 = (flattened[0] - mean[0]) / (var[0] + epsilon).sqrt() * scale[0] + bias[0];
756+
let y2 = (flattened[1] - mean[1]) / (var[1] + epsilon).sqrt() * scale[1] + bias[1];
757+
Tensor::from_data(input.shape(), vec![y1, y2])
758+
} else {
759+
input.map(|&x| (x - mean[0]) / (var[0] + epsilon).sqrt() * scale[0] + bias[0])
760+
};
744761

745-
let y1 = (flattened[0] - mean[0]) / (var[0] + epsilon).sqrt() * scale[0] + bias[0];
746-
let y2 = (flattened[1] - mean[1]) / (var[1] + epsilon).sqrt() * scale[1] + bias[1];
747-
let expected = Tensor::from_data(input.shape(), vec![y1, y2]);
748762
let result = batch_norm(
749763
&pool,
750764
input.view(),
@@ -767,7 +781,7 @@ mod tests {
767781
let mean = &[0.5, -0.5];
768782
let var = &[1.0, 2.0];
769783
let epsilon = 1e-5 as f32;
770-
let input = Tensor::zeros(&[2]);
784+
let input = Tensor::from(5.0);
771785

772786
let pool = new_pool();
773787
let result = batch_norm(
@@ -782,7 +796,7 @@ mod tests {
782796

783797
assert_eq!(
784798
result,
785-
Err(OpError::InvalidValue("Input must have at least 3 dims"))
799+
Err(OpError::InvalidValue("Input must have at least 1 dim"))
786800
);
787801
}
788802

0 commit comments

Comments
 (0)