Skip to content

Commit 5358463

Browse files
authored
Merge pull request #4 from ti-uni-bielefeld/fix_vec_to_emulated_mask_conversion
Fix conversion from Vec to emulated Mask
2 parents 9a897c7 + 0babf04 commit 5358463

File tree

7 files changed

+190
-8
lines changed

7 files changed

+190
-8
lines changed

src/lib/tsimd/mask_impl_emu.H

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,13 @@ class Mask
121121

122122
public:
123123
Mask() = default;
124-
explicit SIMD_INLINE Mask(const Vec<T, SIMD_WIDTH> &x) : mask(x) {}
124+
/// @brief Constructs an emulated mask from a vector.
///
/// Only the most significant bit of each element of @p x decides the mask
/// value of that element; it is shifted into all bits of the element.
explicit SIMD_INLINE Mask(const Vec<T, SIMD_WIDTH> &x)
{
  // signed integer type of the same size as T
  using IntT = typename TypeInfo<T>::IntegerType;
  // reinterpret as integer, shift the most significant bit into all bits,
  // then reinterpret back to the element type of the mask
  mask = reinterpret<T>(srai<sizeof(T) * 8 - 1>(reinterpret<IntT>(x)));
}
125131
SIMD_INLINE Mask(const uint64_t x) : mask(int2bits<T, SIMD_WIDTH>(x)) {}
126132
explicit SIMD_INLINE operator Vec<T, SIMD_WIDTH>() const { return mask; }
127133
SIMD_INLINE operator uint64_t() const { return msb2int<T, SIMD_WIDTH>(mask); }

src/lib/tsimd/types.H

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ struct TypeInfo<Byte>
152152
using NextLargerType = Word;
153153
using UnsignedType = Byte;
154154
using SignedType = SignedByte;
155+
using IntegerType = SignedByte;
155156
};
156157

157158
template <>
@@ -168,6 +169,7 @@ struct TypeInfo<SignedByte>
168169
using NextLargerType = Short;
169170
using UnsignedType = Byte;
170171
using SignedType = SignedByte;
172+
using IntegerType = SignedByte;
171173
};
172174

173175
template <>
@@ -181,6 +183,7 @@ struct TypeInfo<Word>
181183
using NextLargerType = Int; // no larger unsigned type, use Int
182184
using UnsignedType = Word;
183185
using SignedType = Short;
186+
using IntegerType = Short;
184187
};
185188

186189
template <>
@@ -194,6 +197,7 @@ struct TypeInfo<Short>
194197
using NextLargerType = Int;
195198
using UnsignedType = Word;
196199
using SignedType = Short;
200+
using IntegerType = Short;
197201
};
198202

199203
template <>
@@ -207,6 +211,7 @@ struct TypeInfo<Int>
207211
using NextLargerType = Long;
208212
using UnsignedType = uint32_t; // not a SIMD type
209213
using SignedType = Int;
214+
using IntegerType = Int;
210215
};
211216

212217
template <>
@@ -220,6 +225,7 @@ struct TypeInfo<Long>
220225
using NextLargerType = Long; // no larger integer type than Long
221226
using UnsignedType = uint64_t; // not a SIMD type
222227
using SignedType = Long;
228+
using IntegerType = Long;
223229
};
224230

225231
template <>
@@ -232,6 +238,7 @@ struct TypeInfo<Float>
232238
using NextLargerType = Double;
233239
using UnsignedType = Float; // no unsigned float type
234240
using SignedType = Float;
241+
using IntegerType = Int;
235242
};
236243

237244
template <>
@@ -244,6 +251,7 @@ struct TypeInfo<Double>
244251
using NextLargerType = Double; // no larger double type than Double
245252
using UnsignedType = Double; // no unsigned double type
246253
using SignedType = Double;
254+
using IntegerType = Long;
247255
};
248256

249257
} // namespace types
@@ -321,6 +329,9 @@ struct TypeInfo
321329
/// @brief The signed type (e.g. SignedByte for Byte), or the same
322330
/// type if there is no signed type
323331
using SignedType = typename internal::types::TypeInfo<T>::SignedType;
332+
/// @brief The signed integer type of the same size (e.g. SignedByte for Byte,
333+
/// Int for Float)
334+
using IntegerType = typename internal::types::TypeInfo<T>::IntegerType;
324335
};
325336

326337
// ===========================================================================

src/test/autotest/core.H

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,15 @@ struct CmpEqual
248248

249249
template <typename T, size_t SIMD_WIDTH>
250250
static bool cmpVec(SerialVec<T, SIMD_WIDTH> &expected,
251-
SerialVec<T, SIMD_WIDTH> &actual)
251+
SerialVec<T, SIMD_WIDTH> &actual,
252+
bool strictBitsCmp = false)
252253
{
253254
if (memcmp(&expected, &actual, sizeof(SerialVec<T, SIMD_WIDTH>)) == 0) {
254255
return true;
256+
} else if (strictBitsCmp) {
257+
// if strictBitsCmp is true, we only want to compare the bit patterns of
258+
// the SerialVecs, so we return false here
259+
return false;
255260
}
256261
for (size_t i = 0; i < SerialVec<T, SIMD_WIDTH>::elems; ++i) {
257262
if (!cmpScalar(expected[i], actual[i])) { return false; }
@@ -261,11 +266,11 @@ struct CmpEqual
261266

262267
template <typename T, size_t SIMD_WIDTH>
263268
static bool cmpVec(SerialVec<T, SIMD_WIDTH> &expected,
264-
Vec<T, SIMD_WIDTH> &actual)
269+
Vec<T, SIMD_WIDTH> &actual, bool strictBitsCmp = false)
265270
{
266271
SerialVec<T, SIMD_WIDTH> actualSerial =
267272
SerialVec<T, SIMD_WIDTH>::fromVec(actual);
268-
return cmpVec(expected, actualSerial);
273+
return cmpVec(expected, actualSerial, strictBitsCmp);
269274
}
270275
};
271276

src/test/autotest/mask.H

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,117 @@ struct TesterMaskUnaryBool
838838
}
839839
};
840840

841+
// =============================================================================
842+
// test for binary vector functions that also need a mask template parameter
843+
// =============================================================================
844+
845+
template <typename T, size_t SIMD_WIDTH, template <typename, size_t> class FCT,
846+
class CMP = CmpEqual>
847+
struct TesterBinaryWithMask
848+
{
849+
static void test(size_t repeats, const std::string &pattern)
850+
{
851+
std::string name = FCT<T, SIMD_WIDTH>::name();
852+
if (name.find(pattern) == std::string::npos) { return; }
853+
std::string args = "v,v";
854+
printInfo(name, args);
855+
#if !ONLY_TIME_MEASUREMENT
856+
size_t errors = 0, trials = 0;
857+
for (size_t i = 0; i < repeats; i++, trials++) {
858+
SerialVec<T, SIMD_WIDTH> cond;
859+
SerialVec<T, SIMD_WIDTH> a;
860+
FCT<T, SIMD_WIDTH>::randomizeInput(cond);
861+
FCT<T, SIMD_WIDTH>::randomizeInput(a);
862+
SerialVec<T, SIMD_WIDTH> cs =
863+
FCT<T, SIMD_WIDTH>::template apply<SerialVec, SerialMask>(cond, a);
864+
SerialVec<T, SIMD_WIDTH> cp = SerialVec<T, SIMD_WIDTH>::fromVec(
865+
FCT<T, SIMD_WIDTH>::template apply<Vec, Mask>(cond.getVec(),
866+
a.getVec()));
867+
if (!CMP::cmpVec(cs, cp)) {
868+
errors++;
869+
if (errors <= PRINT_ERRORS) {
870+
printError(name, args);
871+
PRINT_VEC(T, cond);
872+
PRINT_VEC(T, a);
873+
PRINT_VEC(T, cs);
874+
PRINT_VEC(T, cp);
875+
}
876+
EXIT;
877+
}
878+
}
879+
printErrorStats(errors, trials);
880+
#endif
881+
882+
Vec<T, SIMD_WIDTH> input1, input2;
883+
struct timespec start = getTimeSpecMonotonic();
884+
for (size_t i = 0; i < repeats / SIMD_TIME_MEASUREMENT_UNROLL; i++) {
885+
for (size_t j = 0; j < SIMD_TIME_MEASUREMENT_UNROLL; j++) {
886+
doNotOptimize(input1);
887+
doNotOptimize(input2);
888+
Vec<T, SIMD_WIDTH> result =
889+
FCT<T, SIMD_WIDTH>::template apply<Vec, Mask>(input1, input2);
890+
doNotOptimize(result);
891+
}
892+
}
893+
struct timespec end = getTimeSpecMonotonic();
894+
long int time = timeSpecDiffNsec(end, start);
895+
printTimeStats(time, repeats);
896+
}
897+
};
898+
899+
// =============================================================================
900+
// test for unary mask functions that return a vector
901+
// =============================================================================
902+
903+
template <typename T, size_t SIMD_WIDTH, template <typename, size_t> class FCT,
904+
class CMP = CmpEqual>
905+
struct TesterUnaryMaskToVec
906+
{
907+
static void test(size_t repeats, const std::string &pattern)
908+
{
909+
std::string name = FCT<T, SIMD_WIDTH>::name();
910+
if (name.find(pattern) == std::string::npos) { return; }
911+
std::string args = "m";
912+
printInfo(name, args);
913+
#if !ONLY_TIME_MEASUREMENT
914+
size_t errors = 0, trials = 0;
915+
for (size_t i = 0; i < repeats; i++, trials++) {
916+
SerialMask<T, SIMD_WIDTH> k;
917+
FCT<T, SIMD_WIDTH>::randomizeInput(k);
918+
SerialVec<T, SIMD_WIDTH> cs =
919+
FCT<T, SIMD_WIDTH>::template apply<SerialVec, SerialMask>(k);
920+
SerialVec<T, SIMD_WIDTH> cp = SerialVec<T, SIMD_WIDTH>::fromVec(
921+
FCT<T, SIMD_WIDTH>::template apply<Vec, Mask>(k.getMask()));
922+
if (!CMP::cmpVec(cs, cp, true)) {
923+
errors++;
924+
if (errors <= PRINT_ERRORS) {
925+
printError(name, args);
926+
PRINT_SERIAL_MASK(T, SIMD_WIDTH, k);
927+
PRINT_VEC(T, cs);
928+
PRINT_VEC(T, cp);
929+
}
930+
EXIT;
931+
}
932+
}
933+
printErrorStats(errors, trials);
934+
#endif
935+
936+
Mask<T, SIMD_WIDTH> inputMask;
937+
struct timespec start = getTimeSpecMonotonic();
938+
for (size_t i = 0; i < repeats / SIMD_TIME_MEASUREMENT_UNROLL; i++) {
939+
for (size_t j = 0; j < SIMD_TIME_MEASUREMENT_UNROLL; j++) {
940+
doNotOptimize(inputMask);
941+
Vec<T, SIMD_WIDTH> result =
942+
FCT<T, SIMD_WIDTH>::template apply<Vec, Mask>(inputMask);
943+
doNotOptimize(result);
944+
}
945+
}
946+
struct timespec end = getTimeSpecMonotonic();
947+
long int time = timeSpecDiffNsec(end, start);
948+
printTimeStats(time, repeats);
949+
}
950+
};
951+
841952
// =============================================================================
842953
// test for unary mask functions
843954
// =============================================================================

src/test/autotest/serial_mask.H

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,24 @@ struct SerialMask
137137
SerialMask(const typename _UInt<bits>::type &_mask) : mask(_mask) {}
138138
SerialMask(const SerialVec<T, SIMD_WIDTH> &x)
139139
{
140+
const SerialVec<UInt<sizeof(T) * 8>, SIMD_WIDTH> &xUInt =
141+
reinterpret<UInt<sizeof(T) * 8>>(x);
140142
mask = 0;
141-
// avoids comparison with TypeInfo::trueval, which fails for
142-
// Float (NaN)
143-
SerialVec<T, SIMD_WIDTH> negX = bit_not(x);
144143
for (size_t i = 0; i < x.elements; i++) {
145-
if (negX[i] == T(0)) { mask = mask | (UInt<bits>(1) << i); }
144+
if ((xUInt[i] >> (sizeof(T) * 8 - 1)) != 0) {
145+
mask = mask | (UInt<bits>(1) << i);
146+
}
146147
}
147148
}
149+
/// @brief Expands the mask into a SerialVec: element i is
/// TypeInfo<T>::trueval() if bit i of the mask is set, and T(0) otherwise.
explicit operator SerialVec<T, SIMD_WIDTH>() const
{
  SerialVec<T, SIMD_WIDTH> expanded;
  for (size_t idx = 0; idx < bits; ++idx) {
    // test bit idx of the mask
    if ((mask & (UInt<bits>(1) << idx)) != 0) {
      expanded[idx] = TypeInfo<T>::trueval();
    } else {
      expanded[idx] = T(0);
    }
  }
  return expanded;
}
148158

149159
operator UInt<bits>() const { return mask; }
150160
operator typename _UInt<bits>::type() const { return mask; }

src/test/autotest/testM.C

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ int main(int argc, char *argv[])
5151

5252
printf("pattern \"%s\", repeats1 = %zu\n", pattern.c_str(), repeats1);
5353

54+
TestAll<TesterUnaryMaskToVec, SW, Mask_toVec>::test(repeats1, pattern);
55+
TestAll<TesterBinaryWithMask, SW, Mask_ifelsezeroFromVec>::test(repeats1,
56+
pattern);
5457
TestAll<TesterMaskConditionBinary, SW, Mask_ifelse>::test(repeats1, pattern);
5558
TestAll<TesterMaskConditionUnary, SW, Mask_ifelsezero>::test(repeats1,
5659
pattern);

src/test/autotest/wrappers_mask.H

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@
4141
namespace simd {
4242
namespace auto_test {
4343

44+
template <typename T, size_t SIMD_WIDTH>
45+
struct Mask_toVec
46+
{
47+
static std::string name() { return t2s<T, SIMD_WIDTH>("Vec"); }
48+
static void randomizeInput(SerialMask<T, SIMD_WIDTH> &mask)
49+
{
50+
mask.randomize();
51+
}
52+
template <template <typename, size_t> class Vec,
53+
template <typename, size_t> class Mask>
54+
static Vec<T, SIMD_WIDTH> apply(const Mask<T, SIMD_WIDTH> &a)
55+
{
56+
return Vec<T, SIMD_WIDTH>(a);
57+
};
58+
};
59+
4460
template <typename T, size_t SIMD_WIDTH>
4561
struct Mask_ifelse
4662
{
@@ -78,6 +94,26 @@ struct Mask_ifelsezero
7894
}
7995
};
8096

97+
// the same as Mask_ifelsezero, but with the condition as a vector, to test
98+
// conversion from vector to mask
99+
template <typename T, size_t SIMD_WIDTH>
100+
struct Mask_ifelsezeroFromVec
101+
{
102+
static std::string name() { return t2s<T, SIMD_WIDTH>("mask_ifelsezero"); }
103+
static void randomizeInput(SerialMask<T, SIMD_WIDTH> &mask)
104+
{
105+
mask.randomize();
106+
}
107+
static void randomizeInput(SerialVec<T, SIMD_WIDTH> &vec) { vec.randomize(); }
108+
template <template <typename, size_t> class Vec,
109+
template <typename, size_t> class Mask>
110+
static Vec<T, SIMD_WIDTH> apply(const Vec<T, SIMD_WIDTH> &cond,
111+
const Vec<T, SIMD_WIDTH> &trueVal)
112+
{
113+
return mask_ifelsezero(Mask<T, SIMD_WIDTH>(cond), trueVal);
114+
}
115+
};
116+
81117
template <typename Tout, typename T, size_t SIMD_WIDTH>
82118
struct Mask_cvts
83119
{

0 commit comments

Comments
 (0)