Skip to content

Commit 0fb9130

Browse files
authored
Merge pull request #661 from robertknight/batch-norm-opt
Optimize BatchNormalization by avoiding tensor slicing
2 parents a6fe23a + b245b72 commit 0fb9130

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

Diff for: src/ops/norm.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,16 @@ pub fn batch_norm_in_place(
176176

177177
input.make_contiguous();
178178

179-
for n in 0..batch {
179+
let chunk_len = input.len() / (batch * chans);
180+
let mut chunks = input.data_mut().unwrap().chunks_mut(chunk_len);
181+
182+
for _ in 0..batch {
180183
for c in 0..chans {
181184
let chan_mean = mean[c];
182185
let chan_var = var[c];
183186
let chan_scale = scale[c];
184187
let chan_bias = bias[c];
185-
let mut chan = input.slice_mut([n, c]);
186-
let chan_data = chan.data_mut().unwrap();
188+
let chan_data = chunks.next().unwrap();
187189

188190
normalize_slice(
189191
chan_data.into(),

0 commit comments

Comments
 (0)