Skip to content

Commit a4b40fa

Browse files
authored
Experimental (#115)
* Fixed some minor issues with NEON
* Added masked_add for NEON
1 parent 732d2eb commit a4b40fa

4 files changed

Lines changed: 41 additions & 1 deletion

File tree

primitive_data/primitives/calc.yaml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,19 @@ definitions:
413413
ctype: ["float", "double"]
414414
lscpu_flags: ["avx512f", "avx512vl"]
415415
implementation: "return _mm_mask_add_{{ intrin_tp_full[ctype] }}(vec_a, tsl::to_integral<Vec>(mask), vec_a, vec_b);"
416+
#ARM NEON
417+
- target_extension: "neon"
418+
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t"]
419+
lscpu_flags: ['neon']
420+
implementation: |
421+
return vaddq_{{ intrin_tp_full[ctype] }}(vec_a, tsl::binary_and<Vec>(mask, vec_b)); // mask & vec_b is 0 wherever the mask bits are clear, so those lanes add 0 and keep vec_a
422+
- target_extension: "neon"
423+
ctype: ["int8_t", "int16_t", "int32_t", "int64_t", "float", "double"]
424+
lscpu_flags: ['neon']
425+
implementation: |
426+
using T = typename Vec::offset_base_type;
427+
using OffsetExt = typename Vec::template transform_extension<T>;
428+
return vaddq_{{ intrin_tp_full[ctype] }}(vec_a, tsl::reinterpret<OffsetExt, Vec>(tsl::binary_and<OffsetExt>(mask, tsl::reinterpret<Vec, OffsetExt>(vec_b))));
416429
#SCALAR
417430
- target_extension: "scalar"
418431
ctype: ["uint8_t", "int8_t", "uint16_t", "int16_t", "uint32_t", "int32_t", "uint64_t", "int64_t", "float", "double"]
@@ -554,6 +567,19 @@ definitions:
554567
ctype: ["float", "double"]
555568
lscpu_flags: ["avx512f", "avx512vl"]
556569
implementation: "return _mm_mask_add_{{ intrin_tp_full[ctype] }}(vec_a, mask, vec_a, vec_b);"
570+
#ARM NEON
571+
- target_extension: "neon"
572+
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t"]
573+
lscpu_flags: ['neon']
574+
implementation: |
575+
return vaddq_{{ intrin_tp_full[ctype] }}(vec_a, tsl::binary_and<Vec>(tsl::to_mask<Vec>(mask), vec_b)); // vec_b is AND-ed with to_mask<Vec>(mask) before the add (presumably a per-lane all-ones/all-zeros vector - confirm), so cleared lanes add 0
576+
- target_extension: "neon"
577+
ctype: ["int8_t", "int16_t", "int32_t", "int64_t", "float", "double"]
578+
lscpu_flags: ['neon']
579+
implementation: |
580+
using T = typename Vec::offset_base_type;
581+
using OffsetExt = typename Vec::template transform_extension<T>;
582+
return vaddq_{{ intrin_tp_full[ctype] }}(vec_a, tsl::reinterpret<OffsetExt, Vec>(tsl::binary_and<OffsetExt>(tsl::to_mask<Vec>(mask), tsl::reinterpret<Vec, OffsetExt>(vec_b))));
557583
#SCALAR
558584
- target_extension: "scalar"
559585
ctype: ["uint8_t", "int8_t", "uint16_t", "int16_t", "uint32_t", "int32_t", "uint64_t", "int64_t", "float", "double"]

primitive_data/primitives/convert.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,16 @@ definitions:
163163
additional_simd_template_base_type: ["int8_t", "uint8_t", "int16_t", "uint16_t", "int32_t", "uint32_t", "int64_t", "uint64_t"]
164164
lscpu_flags: ["neon"]
165165
implementation: "return vreinterpretq_{{ intrin_tp_full[additional_simd_template_base_type] }}_{{ intrin_tp_full[ctype] }}(data);"
166+
- target_extension: "neon"
167+
ctype: ["int8_t", "int16_t", "int32_t", "int64_t"]
168+
additional_simd_template_base_type: ["uint8_t", "uint16_t", "uint32_t", "uint64_t"]
169+
lscpu_flags: ["neon"]
170+
implementation: "return vreinterpretq_{{ intrin_tp_full[additional_simd_template_base_type] }}_{{ intrin_tp_full[ctype] }}(data);"
171+
- target_extension: "neon"
172+
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t"]
173+
additional_simd_template_base_type: ["int8_t", "int16_t", "int32_t", "int64_t"]
174+
lscpu_flags: ["neon"]
175+
implementation: "return vreinterpretq_{{ intrin_tp_full[additional_simd_template_base_type] }}_{{ intrin_tp_full[ctype] }}(data);"
166176
#INTEL - FPGA
167177
- target_extension: ["oneAPIfpga", "oneAPIfpgaRTL"]
168178
ctype: ["float", "double"]

primitive_data/primitives/ls.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,10 @@ definitions:
569569
ctype: ["float", "double"]
570570
lscpu_flags: ["sse2"]
571571
implementation: "return _mm_setzero_{{ intrin_tp_full[ctype] }}();"
572+
- target_extension: "neon"
573+
ctype: ["uint8_t", "int8_t", "uint16_t", "int16_t", "uint32_t", "int32_t", "uint64_t", "int64_t", "float", "double"]
574+
lscpu_flags: ["neon"]
575+
implementation: "return vdupq_n_{{ intrin_tp_full[ctype] }}(0);"
572576
#FPGA
573577
- target_extension: ["oneAPIfpga", "oneAPIfpgaRTL"]
574578
ctype: ["uint8_t", "int8_t", "uint16_t", "int16_t", "uint32_t", "int32_t", "float", "uint64_t", "int64_t", "double"]

primitive_data/primitives/mask.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ definitions:
206206
for(int i = 0; i < Vec::vector_element_count(); i++){
207207
((mask >> i) & 0b1) ? result[i] = static_cast<T>(-1) : result[i] = 0;
208208
}
209-
return reinterpret_cast<Vec::mask_type>(tsl::loadu<OffsetExt>(result));
209+
return tsl::loadu<OffsetExt>(result);
210210
...
211211
---
212212
primitive_name: "mask_binary_not"

0 commit comments

Comments
 (0)