Skip to content

Commit 9197108

Browse files
authored
Merge pull request #664 from robertknight/unary-op-simplify
Refactor some duplication in unary ops
2 parents f5c1d04 + 7176da9 commit 9197108

File tree

1 file changed

+12
-52
lines changed

1 file changed

+12
-52
lines changed

Diff for: src/ops/unary_elementwise.rs

+12 -52
Original file line number | Diff line number | Diff line change
@@ -179,10 +179,10 @@ fn par_unary_op_in_place<T: Copy + Send, VF: Fn(&mut [T]) + Send + Sync, SF: Fn(
179179
}
180180
}
181181

182-
/// Define an operator which supports float tensors and is optimize using SIMD
182+
/// Define an operator which supports float tensors and is optimized using SIMD
183183
/// and multithreading.
184184
macro_rules! parallel_unary_float_op {
185-
($op_name:ident, $func_name:ident, $in_place_func_name:ident, $impl_func:expr, $impl_in_place_func:expr, $impl_scalar:expr) => {
185+
($op_name:ident, $func_name:ident, $in_place_func_name:ident, $simd_kernel:expr) => {
186186
#[derive(Debug)]
187187
pub struct $op_name {}
188188

@@ -214,11 +214,13 @@ macro_rules! parallel_unary_float_op {
214214
}
215215

216216
pub fn $func_name(pool: &TensorPool, input: TensorView) -> Tensor {
217-
par_unary_op(pool, input, $impl_func)
217+
let kernel = $simd_kernel;
218+
par_unary_op(pool, input, |src, dst| kernel.map(src, dst))
218219
}
219220

220221
pub fn $in_place_func_name(input: TensorViewMut) {
221-
par_unary_op_in_place(input, $impl_in_place_func, $impl_scalar);
222+
let kernel = $simd_kernel;
223+
par_unary_op_in_place(input, |dst| kernel.map_mut(dst), |x| kernel.scalar_eval(x));
222224
}
223225
};
224226
}
@@ -399,32 +401,11 @@ pub fn elu_in_place(input: TensorViewMut, alpha: f32) {
399401
Elu { alpha }.apply(input)
400402
}
401403

402-
parallel_unary_float_op!(
403-
Erf,
404-
erf,
405-
erf_in_place,
406-
|src, dest| vecmath::Erf {}.map(src, dest),
407-
|src| vecmath::Erf {}.map_mut(src),
408-
|x| vecmath::Erf {}.scalar_eval(x)
409-
);
410-
parallel_unary_float_op!(
411-
Exp,
412-
exp,
413-
exp_in_place,
414-
|src, dest| vecmath::Exp {}.map(src, dest),
415-
|src| vecmath::Exp {}.map_mut(src),
416-
|x| vecmath::Exp {}.scalar_eval(x)
417-
);
404+
parallel_unary_float_op!(Erf, erf, erf_in_place, vecmath::Erf {});
405+
parallel_unary_float_op!(Exp, exp, exp_in_place, vecmath::Exp {});
418406
unary_float_op!(Floor, floor, floor_in_place, |val: f32| val.floor());
419407

420-
parallel_unary_float_op!(
421-
Gelu,
422-
gelu,
423-
gelu_in_place,
424-
|src, dest| vecmath::Gelu {}.map(src, dest),
425-
|src| vecmath::Gelu {}.map_mut(src),
426-
|x| vecmath::Gelu {}.scalar_eval(x)
427-
);
408+
parallel_unary_float_op!(Gelu, gelu, gelu_in_place, vecmath::Gelu {});
428409

429410
#[derive(Debug)]
430411
pub struct HardSigmoid {
@@ -575,14 +556,7 @@ pub fn round_in_place(x: TensorViewMut) {
575556
Round {}.apply(x)
576557
}
577558

578-
parallel_unary_float_op!(
579-
Sigmoid,
580-
sigmoid,
581-
sigmoid_in_place,
582-
|src, dest| vecmath::Sigmoid {}.map(src, dest),
583-
|src| vecmath::Sigmoid {}.map_mut(src),
584-
|x| vecmath::Sigmoid {}.scalar_eval(x)
585-
);
559+
parallel_unary_float_op!(Sigmoid, sigmoid, sigmoid_in_place, vecmath::Sigmoid {});
586560

587561
// Sigmoid Linear Unit (SiLU) function.
588562
//
@@ -591,14 +565,7 @@ parallel_unary_float_op!(
591565
//
592566
// Not an official ONNX operator, but used in popular object detection models.
593567
// See https://github.com/onnx/onnx/issues/4854.
594-
parallel_unary_float_op!(
595-
Silu,
596-
silu,
597-
silu_in_place,
598-
|src, dest| vecmath::Silu {}.map(src, dest),
599-
|src| vecmath::Silu {}.map_mut(src),
600-
|x| vecmath::Silu {}.scalar_eval(x)
601-
);
568+
parallel_unary_float_op!(Silu, silu, silu_in_place, vecmath::Silu {});
602569

603570
/// Swish function (<https://en.wikipedia.org/wiki/Swish_function>).
604571
///
@@ -682,14 +649,7 @@ unary_float_op!(Softplus, softplus, softplus_in_place, |val: f32| {
682649
val.exp().ln_1p()
683650
});
684651
unary_float_op!(Tan, tan, tan_in_place, |val: f32| val.tan());
685-
parallel_unary_float_op!(
686-
Tanh,
687-
tanh,
688-
tanh_in_place,
689-
|src, dest| vecmath::Tanh {}.map(src, dest),
690-
|src| vecmath::Tanh {}.map_mut(src),
691-
|x| vecmath::Tanh {}.scalar_eval(x)
692-
);
652+
parallel_unary_float_op!(Tanh, tanh, tanh_in_place, vecmath::Tanh {});
693653

694654
#[cfg(test)]
695655
mod tests {

0 commit comments

Comments
 (0)