Skip to content

Commit c28ce78

Browse files
dsharlet authored and xnnpack-bot committed
Add basic HVX reduce kernels.
Kernels that should be reasonably good: - min, max, min_max for all types - sum, sum_squared for int8 and uint8 for k1 > 1 Kernels that are not good and need work: - sum, sum_squared for int8 and uint8 for k1 = 1. These are currently naively implemented with conversions, and wide arithmetic (instead of widening arithmetic). - In general k1 = 1 is not good because we unroll the accumulator by 2x/4x, so we can load whole vectors, which makes the accumulators really large (e.g. 128). This means that we're very likely to hit tail case code. Example inner loop (k1 > 1 uint8 sum, sum_squared is almost identical): ``` .LBB30_116: // %while.body14.i // Parent Loop BB30_110 Depth=1 // Parent Loop BB30_112 Depth=2 // Parent Loop BB30_114 Depth=3 // => This Inner Loop Header: Depth=4 { v27 = vmemu(r5++#1) } { v28 = vmemu(r6++#1) } { v10.w += vrmpy(v27.ub,r9.b) v23 = vmemu(r0++#1) } { v9.w += vrmpy(v28.ub,r9.b) v15 = vmemu(r7++#1) } { v5.w += vrmpy(v23.ub,r9.b) } { v6.w += vrmpy(v15.ub,r9.b) r3 = add(r3,#-128) } { p3 = cmp.gtu(r3,#127) if (p3.new) jump:t .LBB30_116 } ``` PiperOrigin-RevId: 874380564
1 parent af8ea33 commit c28ce78

File tree

9 files changed

+229
-6
lines changed

9 files changed

+229
-6
lines changed

ynnpack/kernels/reduce/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ ynn_cc_library(
4242
"arm_neonfma": ["arm_neonfma.cc"],
4343
"arm_neondot": ["arm_neondot.cc"],
4444
"arm_neon": ["arm_neon.cc"],
45+
"hexagon_hvx": ["hexagon_hvx.cc"],
4546
"x86_ssse3": ["x86_ssse3.cc"],
4647
"x86_sse2": ["x86_sse2.cc"],
4748
"x86_sse41": ["x86_sse41.cc"],
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
#include "ynnpack/base/simd/hexagon_hvx.h"
7+
8+
#include <hexagon_protos.h>
9+
#include <hexagon_types.h>
10+
#include <hvx_hexagon_protos.h>
11+
12+
#include <cstddef>
13+
#include <cstdint>
14+
#include <type_traits>
15+
16+
#include "ynnpack/base/base.h"
17+
#include "ynnpack/base/bfloat16.h"
18+
#include "ynnpack/base/half.h"
19+
#include "ynnpack/base/simd/vec.h"
20+
#include "ynnpack/kernels/reduce/generic.h"
21+
#include "ynnpack/kernels/reduce/min_max_accumulator.h"
22+
#include "ynnpack/kernels/reduce/sum_accumulator.h"
23+
24+
namespace ynn {
25+
26+
namespace simd {
27+
28+
static s32x32 reduce_add(
29+
s32x32 a, u8x128 b, Identity /*map_fn*/,
30+
std::integral_constant<size_t, 4> /*horizontal_factor*/) {
31+
a.v = Q6_Vw_vrmpyacc_VwVubRb(a.v, b.v, 0x01010101);
32+
return a;
33+
}
34+
35+
static s32x32 reduce_add(
36+
s32x32 a, u8x128 b, Square /*map_fn*/,
37+
std::integral_constant<size_t, 4> /*horizontal_factor*/) {
38+
a.v = Q6_Vuw_vrmpyacc_VuwVubVub(a.v, b.v, b.v);
39+
return a;
40+
}
41+
42+
static s32x32 reduce_add(
43+
s32x32 a, s8x128 b, Identity /*map_fn*/,
44+
std::integral_constant<size_t, 4> /*horizontal_factor*/) {
45+
const auto ones = Q6_V_vsplat_R(0x01010101);
46+
a.v = Q6_Vw_vrmpyacc_VwVbVb(a.v, b.v, ones);
47+
return a;
48+
}
49+
50+
static s32x32 reduce_add(
51+
s32x32 a, s8x128 b, Square /*map_fn*/,
52+
std::integral_constant<size_t, 4> /*horizontal_factor*/) {
53+
a.v = Q6_Vw_vrmpyacc_VwVbVb(a.v, b.v, b.v);
54+
return a;
55+
}
56+
57+
} // namespace simd
58+
59+
using simd::bf16x64;
60+
using simd::f16x64;
61+
using simd::f32x32;
62+
using simd::s16x64;
63+
using simd::s32x32;
64+
using simd::s8x128;
65+
using simd::u8x128;
66+
using s32x128 = simd::vec<int32_t, 128>;
67+
using f32x128 = simd::vec<float, 128>;
68+
69+
using bf16x64_rvar = float16_wrapper<bf16x64, s16x64>;
70+
71+
MIN_MAX_KERNEL(min_max_fp32_4x32_hvx, f32x32, f32x32, float, 32);
72+
MIN_MAX_KERNEL(min_max_fp16_4x64_hvx, f16x64, f16x64, half, 64);
73+
MIN_MAX_KERNEL(min_max_bf16_4x64_hvx, bf16x64_rvar, bf16x64_rvar, bfloat16, 64);
74+
MIN_MAX_KERNEL(min_max_uint8_4x128_hvx, u8x128, u8x128, uint8_t, 128);
75+
MIN_MAX_KERNEL(min_max_int8_4x128_hvx, s8x128, s8x128, int8_t, 128);
76+
77+
MIN_MAX_KERNEL(min_fp32_4x32_hvx, f32x32, dummy_t, float, 32);
78+
MIN_MAX_KERNEL(min_fp16_4x64_hvx, f16x64, dummy_t, half, 64);
79+
MIN_MAX_KERNEL(min_bf16_4x64_hvx, bf16x64_rvar, dummy_t, bfloat16, 64);
80+
MIN_MAX_KERNEL(min_uint8_4x128_hvx, u8x128, dummy_t, uint8_t, 128);
81+
MIN_MAX_KERNEL(min_int8_4x128_hvx, s8x128, dummy_t, int8_t, 128);
82+
83+
MIN_MAX_KERNEL(max_fp32_4x32_hvx, dummy_t, f32x32, float, 32);
84+
MIN_MAX_KERNEL(max_fp16_4x64_hvx, dummy_t, f16x64, half, 64);
85+
MIN_MAX_KERNEL(max_bf16_4x64_hvx, dummy_t, bf16x64_rvar, bfloat16, 64);
86+
MIN_MAX_KERNEL(max_uint8_4x128_hvx, dummy_t, u8x128, uint8_t, 128);
87+
MIN_MAX_KERNEL(max_int8_4x128_hvx, dummy_t, s8x128, int8_t, 128);
88+
89+
void sum_uint8_int32_hvx(size_t n, size_t k3, size_t k2, size_t k1,
90+
size_t a_stride_n, size_t a_stride_k3,
91+
size_t a_stride_k2, const void* a, size_t, void* c) {
92+
if (k1 == 1 && a_stride_n == sizeof(uint8_t)) {
93+
// TODO(b/482435301): This case is poorly optimized. It naively converts to
94+
// int32 and does a 32-bit add. We should be using a widening op, and
95+
// storing the accumulators interleaved until `sum_rows`.
96+
stream_reduce<sum_accumulator_k1_1<s32x128>, uint8_t, int32_t>(
97+
n, k3, k2, a_stride_k3, a_stride_k2,
98+
reinterpret_cast<const uint8_t*>(a),
99+
/*C_stride_m=*/0, reinterpret_cast<int32_t*>(c));
100+
} else {
101+
tiled_reduce<sum_accumulator_x32<s32x32, 128, Identity>, uint8_t, int32_t>(
102+
n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
103+
reinterpret_cast<const uint8_t*>(a), /*C_stride_m=*/0,
104+
reinterpret_cast<int32_t*>(c));
105+
}
106+
}
107+
108+
void sum_squared_uint8_int32_hvx(size_t n, size_t k3, size_t k2, size_t k1,
109+
size_t a_stride_n, size_t a_stride_k3,
110+
size_t a_stride_k2, const void* a, size_t,
111+
void* c) {
112+
if (k1 == 1 && a_stride_n == sizeof(uint8_t)) {
113+
// TODO(b/482435301): This case is poorly optimized. It naively converts to
114+
// int32 and does a 32-bit add. We should be using a widening op, and
115+
// storing the accumulators interleaved until `sum_rows`.
116+
stream_reduce<sum_accumulator_k1_1<s32x128, Square>, uint8_t, int32_t>(
117+
n, k3, k2, a_stride_k3, a_stride_k2,
118+
reinterpret_cast<const uint8_t*>(a),
119+
/*C_stride_m=*/0, reinterpret_cast<int32_t*>(c));
120+
} else {
121+
tiled_reduce<sum_accumulator_x32<s32x32, 128, Square>, uint8_t, int32_t>(
122+
n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
123+
reinterpret_cast<const uint8_t*>(a), /*C_stride_m=*/0,
124+
reinterpret_cast<int32_t*>(c));
125+
}
126+
}
127+
128+
void sum_int8_int32_hvx(size_t n, size_t k3, size_t k2, size_t k1,
129+
size_t a_stride_n, size_t a_stride_k3,
130+
size_t a_stride_k2, const void* a, size_t, void* c) {
131+
if (k1 == 1 && a_stride_n == sizeof(int8_t)) {
132+
// TODO(b/482435301): This case is poorly optimized. It naively converts to
133+
// int32 and does a 32-bit add. We should be using a widening op, and
134+
// storing the accumulators interleaved until `sum_rows`.
135+
stream_reduce<sum_accumulator_k1_1<s32x128>, int8_t, int32_t>(
136+
n, k3, k2, a_stride_k3, a_stride_k2, reinterpret_cast<const int8_t*>(a),
137+
/*C_stride_m=*/0, reinterpret_cast<int32_t*>(c));
138+
} else {
139+
tiled_reduce<sum_accumulator_x32<s32x32, 128, Identity>, int8_t, int32_t>(
140+
n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
141+
reinterpret_cast<const int8_t*>(a), /*C_stride_m=*/0,
142+
reinterpret_cast<int32_t*>(c));
143+
}
144+
}
145+
146+
void sum_squared_int8_int32_hvx(size_t n, size_t k3, size_t k2, size_t k1,
147+
size_t a_stride_n, size_t a_stride_k3,
148+
size_t a_stride_k2, const void* a, size_t,
149+
void* c) {
150+
if (k1 == 1 && a_stride_n == sizeof(int8_t)) {
151+
// TODO(b/482435301): This case is poorly optimized. It naively converts to
152+
// int32 and does a 32-bit add. We should be using a widening op, and
153+
// storing the accumulators interleaved until `sum_rows`.
154+
stream_reduce<sum_accumulator_k1_1<s32x128, Square>, int8_t, int32_t>(
155+
n, k3, k2, a_stride_k3, a_stride_k2, reinterpret_cast<const int8_t*>(a),
156+
/*C_stride_m=*/0, reinterpret_cast<int32_t*>(c));
157+
} else {
158+
tiled_reduce<sum_accumulator_x32<s32x32, 128, Square>, int8_t, int32_t>(
159+
n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
160+
reinterpret_cast<const int8_t*>(a), /*C_stride_m=*/0,
161+
reinterpret_cast<int32_t*>(c));
162+
}
163+
}
164+
165+
} // namespace ynn

ynnpack/kernels/reduce/max.inc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, max_int8_4x16_neon, int8_t, int8_t)
88
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, max_uint8_4x16_neon, uint8_t, uint8_t)
99
#endif
1010

11+
#ifdef YNN_ARCH_HEXAGON_HVX
12+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, max_fp32_4x32_hvx, float, float)
13+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, max_bf16_4x64_hvx, bfloat16, bfloat16)
14+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, max_fp16_4x64_hvx, half, half)
15+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, max_int8_4x128_hvx, int8_t, int8_t)
16+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, max_uint8_4x128_hvx, uint8_t, uint8_t)
17+
#endif // YNN_ARCH_HEXAGON_HVX
18+
1119
#ifdef YNN_ARCH_X86_AVX512
1220
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, max_bf16_4x32_avx512bw, bfloat16, bfloat16)
1321
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, max_fp16_4x32_avx512bw, half, half)

ynnpack/kernels/reduce/min.inc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, min_int8_4x16_neon, int8_t, int8_t)
88
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, min_uint8_4x16_neon, uint8_t, uint8_t)
99
#endif
1010

11+
#ifdef YNN_ARCH_HEXAGON_HVX
12+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_fp32_4x32_hvx, float, float)
13+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_bf16_4x64_hvx, bfloat16, bfloat16)
14+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_fp16_4x64_hvx, half, half)
15+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_int8_4x128_hvx, int8_t, int8_t)
16+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_uint8_4x128_hvx, uint8_t, uint8_t)
17+
#endif // YNN_ARCH_HEXAGON_HVX
18+
1119
#ifdef YNN_ARCH_X86_AVX512
1220
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, min_bf16_4x32_avx512bw, bfloat16, bfloat16)
1321
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, min_fp16_4x32_avx512bw, bfloat16, bfloat16)

ynnpack/kernels/reduce/min_max.inc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, min_max_int8_4x16_neon, int8_t, int8_t)
88
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, min_max_uint8_4x16_neon, uint8_t, uint8_t)
99
#endif
1010

11+
#ifdef YNN_ARCH_HEXAGON_HVX
12+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_max_fp32_4x32_hvx, float, float)
13+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_max_bf16_4x64_hvx, bfloat16, bfloat16)
14+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_max_fp16_4x64_hvx, half, half)
15+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_max_int8_4x128_hvx, int8_t, int8_t)
16+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, min_max_uint8_4x128_hvx, uint8_t, uint8_t)
17+
#endif // YNN_ARCH_HEXAGON_HVX
18+
1119
#ifdef YNN_ARCH_X86_AVX512
1220
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, min_max_bf16_4x32_avx512bw, bfloat16, bfloat16)
1321
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, min_max_fp16_4x32_avx512bw, half, half)

ynnpack/kernels/reduce/min_max_accumulator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ struct min_max_accumulator {
136136
}
137137

138138
template <typename AccT>
139-
void accumulate_min(T* __restrict C, size_t n, const AccT* acc) {
139+
void accumulate_min(T* __restrict C, size_t n, const AccT* __restrict acc) {
140140
switch (n) {
141141
case 4:
142142
C[3] = min(C[3], horizontal_min(acc[3]));
@@ -153,7 +153,7 @@ struct min_max_accumulator {
153153
}
154154

155155
template <typename AccT>
156-
void accumulate_max(T* __restrict C, size_t n, const AccT* acc) {
156+
void accumulate_max(T* __restrict C, size_t n, const AccT* __restrict acc) {
157157
switch (n) {
158158
case 4:
159159
C[3] = max(C[3], horizontal_max(acc[3]));

ynnpack/kernels/reduce/sum.inc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_fp32_neon, float, float)
1515
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_bf16_fp32_neon, bfloat16, float)
1616
#endif // YNN_ARCH_ARM_NEON
1717

18+
#ifdef YNN_ARCH_HEXAGON_HVX
19+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, sum_int8_int32_hvx, int8_t, int32_t)
20+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, sum_uint8_int32_hvx, uint8_t, int32_t)
21+
#endif // YNN_ARCH_HEXAGON_HVX
22+
1823
#ifdef YNN_ARCH_X86_AVX512BF16
1924
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_bf16_fp32_avx512bf16, bfloat16, float)
2025
#endif // YNN_ARCH_X86_AVX512BF16

ynnpack/kernels/reduce/sum_accumulator.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ YNN_ALWAYS_INLINE auto sum_rows(const AccT* acc,
6969
auto v_1 = (extract<0>(acc[1], cols) + extract<1>(acc[1], cols)) +
7070
(extract<2>(acc[1], cols) + extract<3>(acc[1], cols));
7171

72+
// TODO(dsharlet): This returns a vector of 4 values, when it should return
73+
// a vector of 2 values.
7274
auto zero = decltype(v_0)(0);
7375
auto t = transpose<typename AccT::value_type>({{v_0, v_1, zero, zero}});
7476
return (t[0] + t[1]) + (t[2] + t[3]);
@@ -98,6 +100,25 @@ YNN_ALWAYS_INLINE auto sum_rows(const AccT* acc,
98100
return (t[0] + t[1]) + (t[2] + t[3]);
99101
}
100102

103+
#ifndef YNN_ARCH_X86
104+
// This is not numerically consistent, don't let it be used on x86.
105+
template <typename AccT, size_t K, size_t N>
106+
YNN_ALWAYS_INLINE auto sum_rows(const AccT* __restrict acc,
107+
std::integral_constant<size_t, K> /*K*/,
108+
std::integral_constant<size_t, N> /*N*/) {
109+
using scalar = typename AccT::value_type;
110+
scalar result[N];
111+
YNN_UNROLL
112+
for (size_t i = 0; i < N; ++i) {
113+
result[i] = simd::horizontal_sum(acc[i]);
114+
}
115+
// TODO(dsharlet): This returns a vector of 4 values to meet the assumptions
116+
// of the callers below. It should return a vector of N values.
117+
static_assert(N <= 4);
118+
return simd::load(result, N, simd::vec<scalar, 4>{});
119+
}
120+
#endif // YNN_ARCH_X86
121+
101122
template <typename AccT, size_t K_, typename MapFn = Identity, size_t N_ = 4>
102123
struct sum_accumulator_x32 {
103124
static constexpr std::integral_constant<size_t, N_> N = {};
@@ -123,16 +144,18 @@ struct sum_accumulator_x32 {
123144
NT n, KT k) {
124145
const simd::vec<AT, K> zero(0);
125146
auto a_0 = load(offset_bytes(A, 0 * A_stride_n), k, zero);
126-
auto a_1 = 1 < n ? load(offset_bytes(A, 1 * A_stride_n), k, zero) : zero;
127147
acc[0] = reduce_add(acc[0], a_0, map_fn, horizontal_factor);
128-
acc[1] = reduce_add(acc[1], a_1, map_fn, horizontal_factor);
129-
130-
if constexpr (N == 4) {
148+
if constexpr (N >= 2) {
149+
auto a_1 = 1 < n ? load(offset_bytes(A, 1 * A_stride_n), k, zero) : zero;
150+
acc[1] = reduce_add(acc[1], a_1, map_fn, horizontal_factor);
151+
}
152+
if constexpr (N >= 4) {
131153
auto a_2 = 2 < n ? load(offset_bytes(A, 2 * A_stride_n), k, zero) : zero;
132154
auto a_3 = 3 < n ? load(offset_bytes(A, 3 * A_stride_n), k, zero) : zero;
133155
acc[2] = reduce_add(acc[2], a_2, map_fn, horizontal_factor);
134156
acc[3] = reduce_add(acc[3], a_3, map_fn, horizontal_factor);
135157
}
158+
static_assert(N <= 4, "");
136159
}
137160

138161
template <typename T, typename NT>

ynnpack/kernels/reduce/sum_squared.inc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_fp32_neon, float, float)
1919
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_bf16_fp32_neon, bfloat16, float)
2020
#endif // YNN_ARCH_ARM_NEON
2121

22+
#ifdef YNN_ARCH_HEXAGON_HVX
23+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, sum_squared_int8_int32_hvx, int8_t, int32_t)
24+
YNN_UNARY_REDUCE_KERNEL(arch_flag::hvx, sum_squared_uint8_int32_hvx, uint8_t, int32_t)
25+
#endif // YNN_ARCH_HEXAGON_HVX
26+
2227
#ifdef YNN_ARCH_X86_AVX512BF16
2328
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_squared_bf16_fp32_avx512bf16, bfloat16, float)
2429
#endif // YNN_ARCH_X86_AVX512BF16

0 commit comments

Comments
 (0)