
Commit c10cf5a

Unify InstanceNormalization and BatchNormalization outer loops
This brings the improvement from b245b72 to InstanceNormalization.
1 parent: 0fb9130

1 file changed: src/ops/norm.rs (+38 -51 lines)
@@ -156,6 +156,27 @@ fn normalize_slice<'src, 'dst>(
     op.dispatch()
 }
 
+/// Normalize each channel separately in an `(N, C, ...)` tensor.
+fn normalize_each_channel<'a>(
+    input: &mut Tensor,
+    chan_opts: impl Fn(usize) -> NormalizeOptions<'a>,
+) {
+    let batch = input.size(0);
+    let chans = input.size(1);
+
+    input.make_contiguous();
+    let chunk_len = input.len() / (batch * chans);
+    let mut chunks = input.data_mut().unwrap().chunks_mut(chunk_len);
+
+    for _ in 0..batch {
+        for c in 0..chans {
+            let chan_data = chunks.next().unwrap();
+            let opts = chan_opts(c);
+            normalize_slice(chan_data.into(), opts);
+        }
+    }
+}
+
 /// Perform in-place batch normalization on the `NC*` tensor `out`.
 ///
 /// See <https://github.com/onnx/onnx/blob/main/docs/Operators.md#batchnormalization>.
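The trick this helper relies on: once `make_contiguous` has put the tensor in NC* layout, `chunks_mut` over the flat data yields exactly one mutable slice per (batch, channel) pair, in order. Below is a minimal standalone sketch of that chunking scheme, using a plain Vec<f32> in place of the crate's Tensor type (the shape and names are illustrative, not from the crate):

fn main() {
    let (batch, chans, spatial) = (2usize, 3, 4); // hypothetical (N, C, W) shape
    let mut data: Vec<f32> = (0..batch * chans * spatial).map(|i| i as f32).collect();

    // Mirrors `normalize_each_channel`: one chunk per (batch, channel) pair.
    let chunk_len = data.len() / (batch * chans);
    let mut chunks = data.chunks_mut(chunk_len);

    for n in 0..batch {
        for c in 0..chans {
            let chan_data = chunks.next().unwrap();
            assert_eq!(chan_data.len(), spatial);
            println!("n={n} c={c} first element {}", chan_data[0]);
        }
    }
}
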
@@ -171,37 +192,16 @@ pub fn batch_norm_in_place(
         return Err(OpError::InvalidValue("Input must have at least 3 dims"));
     }
 
-    let batch = input.size(0);
-    let chans = input.size(1);
-
-    input.make_contiguous();
-
-    let chunk_len = input.len() / (batch * chans);
-    let mut chunks = input.data_mut().unwrap().chunks_mut(chunk_len);
-
-    for _ in 0..batch {
-        for c in 0..chans {
-            let chan_mean = mean[c];
-            let chan_var = var[c];
-            let chan_scale = scale[c];
-            let chan_bias = bias[c];
-            let chan_data = chunks.next().unwrap();
-
-            normalize_slice(
-                chan_data.into(),
-                NormalizeOptions {
-                    mean_normalize: MeanNormalize::Static {
-                        mean: chan_mean,
-                        variance: chan_var,
-                    },
-                    epsilon,
-                    scale: chan_scale,
-                    bias: chan_bias,
-                    ..Default::default()
-                },
-            );
-        }
-    }
+    normalize_each_channel(input, |chan| NormalizeOptions {
+        mean_normalize: MeanNormalize::Static {
+            mean: mean[chan],
+            variance: var[chan],
+        },
+        epsilon,
+        scale: scale[chan],
+        bias: bias[chan],
+        ..Default::default()
+    });
 
     Ok(())
 }
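With `MeanNormalize::Static`, each channel is normalized against the precomputed statistics from the `mean` and `var` inputs, per the ONNX BatchNormalization spec linked above: y = (x - mean) / sqrt(var + epsilon) * scale + bias. A minimal sketch of that per-element computation, written as a standalone function rather than the crate's `normalize_slice` internals:

fn batch_norm_channel(chan_data: &mut [f32], mean: f32, var: f32, scale: f32, bias: f32, epsilon: f32) {
    // Hoist the reciprocal standard deviation out of the per-element loop.
    let inv_std = 1.0 / (var + epsilon).sqrt();
    for x in chan_data.iter_mut() {
        *x = (*x - mean) * inv_std * scale + bias;
    }
}
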
@@ -300,7 +300,7 @@ pub fn instance_normalization_in_place(
     bias: NdTensorView<f32, 1>,
     epsilon: Option<f32>,
 ) -> Result<(), OpError> {
-    let &[batch, chans, ..] = input.shape() else {
+    let &[_batch, chans, ..] = input.shape() else {
         return Err(OpError::InvalidValue("expected input with >= 2 dims"));
     };
 
@@ -319,25 +319,12 @@ pub fn instance_normalization_in_place(
         ));
     }
 
-    // Needed for `vec_*` ops below.
-    input.make_contiguous();
-
-    for n in 0..batch {
-        for c in 0..chans {
-            let mut slice = input.slice_mut([n, c]);
-            let chan_data = slice.data_mut().unwrap();
-
-            normalize_slice(
-                chan_data.into(),
-                NormalizeOptions {
-                    epsilon,
-                    scale: scale[c],
-                    bias: bias[c],
-                    ..Default::default()
-                },
-            );
-        }
-    }
+    normalize_each_channel(input, |chan| NormalizeOptions {
+        epsilon,
+        scale: scale[chan],
+        bias: bias[chan],
+        ..Default::default()
+    });
 
     Ok(())
 }
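
Unlike the batch-norm call, this closure leaves `mean_normalize` at its default, so the statistics come from each channel's own elements, as InstanceNormalization requires (the same behavior the removed loop had). A standalone sketch of that per-channel computation (illustrative only, not the crate's implementation):

fn instance_norm_channel(chan_data: &mut [f32], scale: f32, bias: f32, epsilon: f32) {
    let n = chan_data.len() as f32;
    // Statistics are computed from the slice itself rather than passed in.
    let mean = chan_data.iter().sum::<f32>() / n;
    let var = chan_data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + epsilon).sqrt();
    for x in chan_data.iter_mut() {
        *x = (*x - mean) * inv_std * scale + bias;
    }
}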
