@@ -156,6 +156,27 @@ fn normalize_slice<'src, 'dst>(
156
156
op. dispatch ( )
157
157
}
158
158
159
+ /// Normalize each channel separately in an `(N, C, ...)` tensor.
160
+ fn normalize_each_channel < ' a > (
161
+ input : & mut Tensor ,
162
+ chan_opts : impl Fn ( usize ) -> NormalizeOptions < ' a > ,
163
+ ) {
164
+ let batch = input. size ( 0 ) ;
165
+ let chans = input. size ( 1 ) ;
166
+
167
+ input. make_contiguous ( ) ;
168
+ let chunk_len = input. len ( ) / ( batch * chans) ;
169
+ let mut chunks = input. data_mut ( ) . unwrap ( ) . chunks_mut ( chunk_len) ;
170
+
171
+ for _ in 0 ..batch {
172
+ for c in 0 ..chans {
173
+ let chan_data = chunks. next ( ) . unwrap ( ) ;
174
+ let opts = chan_opts ( c) ;
175
+ normalize_slice ( chan_data. into ( ) , opts) ;
176
+ }
177
+ }
178
+ }
179
+
159
180
/// Perform in-place batch normalization on the `NC*` tensor `out`.
160
181
///
161
182
/// See <https://github.com/onnx/onnx/blob/main/docs/Operators.md#batchnormalization>.
@@ -171,37 +192,16 @@ pub fn batch_norm_in_place(
171
192
return Err ( OpError :: InvalidValue ( "Input must have at least 3 dims" ) ) ;
172
193
}
173
194
174
- let batch = input. size ( 0 ) ;
175
- let chans = input. size ( 1 ) ;
176
-
177
- input. make_contiguous ( ) ;
178
-
179
- let chunk_len = input. len ( ) / ( batch * chans) ;
180
- let mut chunks = input. data_mut ( ) . unwrap ( ) . chunks_mut ( chunk_len) ;
181
-
182
- for _ in 0 ..batch {
183
- for c in 0 ..chans {
184
- let chan_mean = mean[ c] ;
185
- let chan_var = var[ c] ;
186
- let chan_scale = scale[ c] ;
187
- let chan_bias = bias[ c] ;
188
- let chan_data = chunks. next ( ) . unwrap ( ) ;
189
-
190
- normalize_slice (
191
- chan_data. into ( ) ,
192
- NormalizeOptions {
193
- mean_normalize : MeanNormalize :: Static {
194
- mean : chan_mean,
195
- variance : chan_var,
196
- } ,
197
- epsilon,
198
- scale : chan_scale,
199
- bias : chan_bias,
200
- ..Default :: default ( )
201
- } ,
202
- ) ;
203
- }
204
- }
195
+ normalize_each_channel ( input, |chan| NormalizeOptions {
196
+ mean_normalize : MeanNormalize :: Static {
197
+ mean : mean[ chan] ,
198
+ variance : var[ chan] ,
199
+ } ,
200
+ epsilon,
201
+ scale : scale[ chan] ,
202
+ bias : bias[ chan] ,
203
+ ..Default :: default ( )
204
+ } ) ;
205
205
206
206
Ok ( ( ) )
207
207
}
@@ -300,7 +300,7 @@ pub fn instance_normalization_in_place(
300
300
bias : NdTensorView < f32 , 1 > ,
301
301
epsilon : Option < f32 > ,
302
302
) -> Result < ( ) , OpError > {
303
- let & [ batch , chans, ..] = input. shape ( ) else {
303
+ let & [ _batch , chans, ..] = input. shape ( ) else {
304
304
return Err ( OpError :: InvalidValue ( "expected input with >= 2 dims" ) ) ;
305
305
} ;
306
306
@@ -319,25 +319,12 @@ pub fn instance_normalization_in_place(
319
319
) ) ;
320
320
}
321
321
322
- // Needed for `vec_*` ops below.
323
- input. make_contiguous ( ) ;
324
-
325
- for n in 0 ..batch {
326
- for c in 0 ..chans {
327
- let mut slice = input. slice_mut ( [ n, c] ) ;
328
- let chan_data = slice. data_mut ( ) . unwrap ( ) ;
329
-
330
- normalize_slice (
331
- chan_data. into ( ) ,
332
- NormalizeOptions {
333
- epsilon,
334
- scale : scale[ c] ,
335
- bias : bias[ c] ,
336
- ..Default :: default ( )
337
- } ,
338
- ) ;
339
- }
340
- }
322
+ normalize_each_channel ( input, |chan| NormalizeOptions {
323
+ epsilon,
324
+ scale : scale[ chan] ,
325
+ bias : bias[ chan] ,
326
+ ..Default :: default ( )
327
+ } ) ;
341
328
342
329
Ok ( ( ) )
343
330
}
0 commit comments