@@ -179,10 +179,10 @@ fn par_unary_op_in_place<T: Copy + Send, VF: Fn(&mut [T]) + Send + Sync, SF: Fn(
179
179
}
180
180
}
181
181
182
- /// Define an operator which supports float tensors and is optimize using SIMD
182
+ /// Define an operator which supports float tensors and is optimized using SIMD
183
183
/// and multithreading.
184
184
macro_rules! parallel_unary_float_op {
185
- ( $op_name: ident, $func_name: ident, $in_place_func_name: ident, $impl_func : expr , $impl_in_place_func : expr , $impl_scalar : expr) => {
185
+ ( $op_name: ident, $func_name: ident, $in_place_func_name: ident, $simd_kernel : expr) => {
186
186
#[ derive( Debug ) ]
187
187
pub struct $op_name { }
188
188
@@ -214,11 +214,13 @@ macro_rules! parallel_unary_float_op {
214
214
}
215
215
216
216
pub fn $func_name( pool: & TensorPool , input: TensorView ) -> Tensor {
217
- par_unary_op( pool, input, $impl_func)
217
+ let kernel = $simd_kernel;
218
+ par_unary_op( pool, input, |src, dst| kernel. map( src, dst) )
218
219
}
219
220
220
221
pub fn $in_place_func_name( input: TensorViewMut ) {
221
- par_unary_op_in_place( input, $impl_in_place_func, $impl_scalar) ;
222
+ let kernel = $simd_kernel;
223
+ par_unary_op_in_place( input, |dst| kernel. map_mut( dst) , |x| kernel. scalar_eval( x) ) ;
222
224
}
223
225
} ;
224
226
}
@@ -399,32 +401,11 @@ pub fn elu_in_place(input: TensorViewMut, alpha: f32) {
399
401
Elu { alpha } . apply ( input)
400
402
}
401
403
402
- parallel_unary_float_op ! (
403
- Erf ,
404
- erf,
405
- erf_in_place,
406
- |src, dest| vecmath:: Erf { } . map( src, dest) ,
407
- |src| vecmath:: Erf { } . map_mut( src) ,
408
- |x| vecmath:: Erf { } . scalar_eval( x)
409
- ) ;
410
- parallel_unary_float_op ! (
411
- Exp ,
412
- exp,
413
- exp_in_place,
414
- |src, dest| vecmath:: Exp { } . map( src, dest) ,
415
- |src| vecmath:: Exp { } . map_mut( src) ,
416
- |x| vecmath:: Exp { } . scalar_eval( x)
417
- ) ;
404
+ parallel_unary_float_op ! ( Erf , erf, erf_in_place, vecmath:: Erf { } ) ;
405
+ parallel_unary_float_op ! ( Exp , exp, exp_in_place, vecmath:: Exp { } ) ;
418
406
unary_float_op ! ( Floor , floor, floor_in_place, |val: f32 | val. floor( ) ) ;
419
407
420
- parallel_unary_float_op ! (
421
- Gelu ,
422
- gelu,
423
- gelu_in_place,
424
- |src, dest| vecmath:: Gelu { } . map( src, dest) ,
425
- |src| vecmath:: Gelu { } . map_mut( src) ,
426
- |x| vecmath:: Gelu { } . scalar_eval( x)
427
- ) ;
408
+ parallel_unary_float_op ! ( Gelu , gelu, gelu_in_place, vecmath:: Gelu { } ) ;
428
409
429
410
#[ derive( Debug ) ]
430
411
pub struct HardSigmoid {
@@ -575,14 +556,7 @@ pub fn round_in_place(x: TensorViewMut) {
575
556
Round { } . apply ( x)
576
557
}
577
558
578
- parallel_unary_float_op ! (
579
- Sigmoid ,
580
- sigmoid,
581
- sigmoid_in_place,
582
- |src, dest| vecmath:: Sigmoid { } . map( src, dest) ,
583
- |src| vecmath:: Sigmoid { } . map_mut( src) ,
584
- |x| vecmath:: Sigmoid { } . scalar_eval( x)
585
- ) ;
559
+ parallel_unary_float_op ! ( Sigmoid , sigmoid, sigmoid_in_place, vecmath:: Sigmoid { } ) ;
586
560
587
561
// Sigmoid Linear Unit (SiLU) function.
588
562
//
@@ -591,14 +565,7 @@ parallel_unary_float_op!(
591
565
//
592
566
// Not an official ONNX operator, but used in popular object detection models.
593
567
// See https://github.com/onnx/onnx/issues/4854.
594
- parallel_unary_float_op ! (
595
- Silu ,
596
- silu,
597
- silu_in_place,
598
- |src, dest| vecmath:: Silu { } . map( src, dest) ,
599
- |src| vecmath:: Silu { } . map_mut( src) ,
600
- |x| vecmath:: Silu { } . scalar_eval( x)
601
- ) ;
568
+ parallel_unary_float_op ! ( Silu , silu, silu_in_place, vecmath:: Silu { } ) ;
602
569
603
570
/// Swish function (<https://en.wikipedia.org/wiki/Swish_function>).
604
571
///
@@ -682,14 +649,7 @@ unary_float_op!(Softplus, softplus, softplus_in_place, |val: f32| {
682
649
val. exp( ) . ln_1p( )
683
650
} ) ;
684
651
unary_float_op ! ( Tan , tan, tan_in_place, |val: f32 | val. tan( ) ) ;
685
- parallel_unary_float_op ! (
686
- Tanh ,
687
- tanh,
688
- tanh_in_place,
689
- |src, dest| vecmath:: Tanh { } . map( src, dest) ,
690
- |src| vecmath:: Tanh { } . map_mut( src) ,
691
- |x| vecmath:: Tanh { } . scalar_eval( x)
692
- ) ;
652
+ parallel_unary_float_op ! ( Tanh , tanh, tanh_in_place, vecmath:: Tanh { } ) ;
693
653
694
654
#[ cfg( test) ]
695
655
mod tests {
0 commit comments