
Commit e538cd7 (2 parents: 66b77e6 + 2759788)

Merge pull request #3486 from stweil/tfloat

Add TFloat data type for neural network

11 files changed: +464 -39 lines

Diff for: src/arch/dotproductavx.cpp (+23)

@@ -29,6 +29,28 @@ namespace tesseract {
 
 // Computes and returns the dot product of the n-vectors u and v.
 // Uses Intel AVX intrinsics to access the SIMD instruction set.
+#if defined(FAST_FLOAT)
+float DotProductAVX(const float *u, const float *v, int n) {
+  const unsigned quot = n / 8;
+  const unsigned rem = n % 8;
+  __m256 t0 = _mm256_setzero_ps();
+  for (unsigned k = 0; k < quot; k++) {
+    __m256 f0 = _mm256_loadu_ps(u);
+    __m256 f1 = _mm256_loadu_ps(v);
+    f0 = _mm256_mul_ps(f0, f1);
+    t0 = _mm256_add_ps(t0, f0);
+    u += 8;
+    v += 8;
+  }
+  alignas(32) float tmp[8];
+  _mm256_store_ps(tmp, t0);
+  float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7];
+  for (unsigned k = 0; k < rem; k++) {
+    result += *u++ * *v++;
+  }
+  return result;
+}
+#else
 double DotProductAVX(const double *u, const double *v, int n) {
   const unsigned quot = n / 8;
   const unsigned rem = n % 8;
@@ -57,6 +79,7 @@ double DotProductAVX(const double *u, const double *v, int n) {
   }
   return result;
 }
+#endif
 
 } // namespace tesseract.
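The FAST_FLOAT variant follows the usual vectorized dot-product pattern: multiply and accumulate eight floats per __m256 register, sum the eight lanes at the end, then finish the tail with a scalar loop. A minimal check harness (not part of the commit; the relative tolerance is an assumption, since vector and scalar summation orders round differently) might look like:

#include <cmath>
#include <cstdio>
#include <vector>

namespace tesseract {
// Declaration matching the FAST_FLOAT overload added above.
float DotProductAVX(const float *u, const float *v, int n);
}

int main() {
  // Length deliberately not a multiple of 8 to exercise the remainder loop.
  std::vector<float> u(1003), v(1003);
  for (size_t i = 0; i < u.size(); i++) {
    u[i] = 0.001f * i;
    v[i] = 1.0f - 0.001f * i;
  }
  float scalar = 0.0f;
  for (size_t i = 0; i < u.size(); i++) {
    scalar += u[i] * v[i];
  }
  float simd = tesseract::DotProductAVX(u.data(), v.data(),
                                        static_cast<int>(u.size()));
  std::printf("scalar=%f simd=%f\n", scalar, simd);
  // Tolerance is an assumption: the two summation orders differ.
  return std::fabs(scalar - simd) <= 1e-3f * std::fabs(scalar) ? 0 : 1;
}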

Diff for: src/arch/dotproductfma.cpp (+29)

@@ -29,6 +29,34 @@ namespace tesseract {
 
 // Computes and returns the dot product of the n-vectors u and v.
 // Uses Intel FMA intrinsics to access the SIMD instruction set.
+#if defined(FAST_FLOAT)
+float DotProductFMA(const float *u, const float *v, int n) {
+  const unsigned quot = n / 16;
+  const unsigned rem = n % 16;
+  __m256 t0 = _mm256_setzero_ps();
+  __m256 t1 = _mm256_setzero_ps();
+  for (unsigned k = 0; k < quot; k++) {
+    __m256 f0 = _mm256_loadu_ps(u);
+    __m256 f1 = _mm256_loadu_ps(v);
+    t0 = _mm256_fmadd_ps(f0, f1, t0);
+    u += 8;
+    v += 8;
+    __m256 f2 = _mm256_loadu_ps(u);
+    __m256 f3 = _mm256_loadu_ps(v);
+    t1 = _mm256_fmadd_ps(f2, f3, t1);
+    u += 8;
+    v += 8;
+  }
+  t0 = _mm256_hadd_ps(t0, t1);
+  alignas(32) float tmp[8];
+  _mm256_store_ps(tmp, t0);
+  float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7];
+  for (unsigned k = 0; k < rem; k++) {
+    result += *u++ * *v++;
+  }
+  return result;
+}
+#else
 double DotProductFMA(const double *u, const double *v, int n) {
   const unsigned quot = n / 8;
   const unsigned rem = n % 8;
@@ -55,6 +83,7 @@ double DotProductFMA(const double *u, const double *v, int n) {
   }
   return result;
 }
+#endif
 
 } // namespace tesseract.
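Two independent accumulators (t0 and t1) let consecutive _mm256_fmadd_ps operations overlap instead of serializing on a single register, which is why this kernel consumes 16 floats per iteration. Note that _mm256_hadd_ps only forms pairwise sums within each 128-bit lane, so the code still stores to memory and adds all eight lanes. An alternative in-register reduction (a hypothetical helper, not in the commit) would be:

#include <immintrin.h>

// Hypothetical helper, not in the commit: reduce an __m256 to one float.
static inline float HorizontalSum256(__m256 x) {
  __m128 lo = _mm256_castps256_ps128(x);    // lower 4 lanes
  __m128 hi = _mm256_extractf128_ps(x, 1);  // upper 4 lanes
  __m128 s = _mm_add_ps(lo, hi);            // 4 partial sums
  s = _mm_hadd_ps(s, s);                    // 2 partial sums
  s = _mm_hadd_ps(s, s);                    // total in lane 0
  return _mm_cvtss_f32(s);
}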

Diff for: src/arch/dotproductsse.cpp (+63 -1)

@@ -30,6 +30,66 @@ namespace tesseract {
 
 // Computes and returns the dot product of the n-vectors u and v.
 // Uses Intel SSE intrinsics to access the SIMD instruction set.
+#if defined(FAST_FLOAT)
+float DotProductSSE(const float *u, const float *v, int n) {
+  int max_offset = n - 4;
+  int offset = 0;
+  // Accumulate a set of 4 sums in sum, by loading pairs of 4 values from u and
+  // v, and multiplying them together in parallel.
+  __m128 sum = _mm_setzero_ps();
+  if (offset <= max_offset) {
+    offset = 4;
+    // Aligned load is reputedly faster but requires 16 byte aligned input.
+    if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 &&
+        (reinterpret_cast<uintptr_t>(v) & 15) == 0) {
+      // Use aligned load.
+      sum = _mm_load_ps(u);
+      __m128 floats2 = _mm_load_ps(v);
+      // Multiply.
+      sum = _mm_mul_ps(sum, floats2);
+      while (offset <= max_offset) {
+        __m128 floats1 = _mm_load_ps(u + offset);
+        floats2 = _mm_load_ps(v + offset);
+        floats1 = _mm_mul_ps(floats1, floats2);
+        sum = _mm_add_ps(sum, floats1);
+        offset += 4;
+      }
+    } else {
+      // Use unaligned load.
+      sum = _mm_loadu_ps(u);
+      __m128 floats2 = _mm_loadu_ps(v);
+      // Multiply.
+      sum = _mm_mul_ps(sum, floats2);
+      while (offset <= max_offset) {
+        __m128 floats1 = _mm_loadu_ps(u + offset);
+        floats2 = _mm_loadu_ps(v + offset);
+        floats1 = _mm_mul_ps(floats1, floats2);
+        sum = _mm_add_ps(sum, floats1);
+        offset += 4;
+      }
+    }
+  }
+  // Add the 4 sums in sum horizontally.
+#if 0
+  alignas(32) float tmp[4];
+  _mm_store_ps(tmp, sum);
+  float result = tmp[0] + tmp[1] + tmp[2] + tmp[3];
+#else
+  __m128 zero = _mm_setzero_ps();
+  // https://www.felixcloutier.com/x86/haddps
+  sum = _mm_hadd_ps(sum, zero);
+  sum = _mm_hadd_ps(sum, zero);
+  // Extract the low result.
+  float result = _mm_cvtss_f32(sum);
+#endif
+  // Add on any left-over products.
+  while (offset < n) {
+    result += u[offset] * v[offset];
+    ++offset;
+  }
+  return result;
+}
+#else
 double DotProductSSE(const double *u, const double *v, int n) {
   int max_offset = n - 2;
   int offset = 0;
@@ -39,7 +99,8 @@ double DotProductSSE(const double *u, const double *v, int n) {
   if (offset <= max_offset) {
     offset = 2;
     // Aligned load is reputedly faster but requires 16 byte aligned input.
-    if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 && (reinterpret_cast<uintptr_t>(v) & 15) == 0) {
+    if ((reinterpret_cast<uintptr_t>(u) & 15) == 0 &&
+        (reinterpret_cast<uintptr_t>(v) & 15) == 0) {
       // Use aligned load.
       sum = _mm_load_pd(u);
       __m128d floats2 = _mm_load_pd(v);
@@ -78,6 +139,7 @@ double DotProductSSE(const double *u, const double *v, int n) {
   }
   return result;
 }
+#endif
 
 } // namespace tesseract.
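Unlike the AVX and FMA kernels, the SSE version branches on 16-byte alignment so it can use the faster aligned loads when both inputs permit, and it reduces the four lanes with two _mm_hadd_ps steps against a zero vector. Only one overload of each kernel exists per build: FAST_FLOAT picks the float signature at compile time, and the caller picks a kernel at run time from CPU features. A simplified dispatch sketch for the default double build (hypothetical; __builtin_cpu_supports is GCC/Clang-specific, and Tesseract's actual selection logic lives elsewhere in the tree):

namespace tesseract {
// The double overloads above (default build, FAST_FLOAT undefined).
double DotProductSSE(const double *u, const double *v, int n);
double DotProductAVX(const double *u, const double *v, int n);
double DotProductFMA(const double *u, const double *v, int n);
}

using DotProductFn = double (*)(const double *, const double *, int);

// Hypothetical run-time selection, preferring the widest instruction set.
static DotProductFn ChooseDotProduct() {
  if (__builtin_cpu_supports("fma")) return tesseract::DotProductFMA;
  if (__builtin_cpu_supports("avx")) return tesseract::DotProductAVX;
  if (__builtin_cpu_supports("sse4.1")) return tesseract::DotProductSSE;
  return nullptr;  // caller falls back to a plain C++ loop
}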

Diff for: src/arch/intsimdmatrix.h (+1 -1)

@@ -115,7 +115,7 @@ struct TESS_API IntSimdMatrix {
   static const IntSimdMatrix *intSimdMatrix;
   // Only available with NEON.
   static const IntSimdMatrix intSimdMatrixNEON;
-  // Only available with AVX2 / SSE.
+  // Only available with AVX2 / AVX / FMA / SSE.
   static const IntSimdMatrix intSimdMatrixAVX2;
   static const IntSimdMatrix intSimdMatrixSSE;
 };
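This header change only updates a comment; the substantive piece, per the PR title, is a TFloat type that switches the network's floating-point width at build time. Judging from the FAST_FLOAT guards throughout these diffs, the alias presumably looks like the following (a sketch; the exact header and placement in the tree are assumptions):

// Sketch of the build-time switch implied by the FAST_FLOAT guards above.
#if defined(FAST_FLOAT)
using TFloat = float;   // 4-byte floats: smaller models, faster dot products
#else
using TFloat = double;  // default: 8-byte doubles
#endif

With such an alias, network code can be written once against TFloat and compiled for either precision.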
