2323#include < algorithm>
2424#include < cmath>
2525#include < sstream>
26+ #include < type_traits>
2627
2728#include " common_defs.hpp"
2829#include " memory_operations.hpp"
2930
3031namespace datasketches {
3132
/// Rejects a value that is NaN.
///
/// @param value the quantity to validate; passed to std::isnan
/// @param name  label for the quantity, used as the prefix of the error message
/// @throws std::invalid_argument with the text "<name> must not be NaN" when
///         std::isnan(value) is true
template <typename T>
inline void check_not_nan(T value, const char* name) {
  // Fast path: anything that is a number passes through untouched.
  if (!std::isnan(value)) return;
  std::string message(name);
  message += " must not be NaN";
  throw std::invalid_argument(message);
}
39+
/// Rejects a value that is positive or negative infinity.
///
/// @param value the quantity to validate; passed to std::isinf
/// @param name  label for the quantity, used as the prefix of the error message
/// @throws std::invalid_argument with the text "<name> must not be infinite"
///         when std::isinf(value) is true
template <typename T>
inline void check_not_infinite(T value, const char* name) {
  // Fast path: finite values (and NaN, which std::isinf does not flag) pass.
  if (!std::isinf(value)) return;
  std::string message(name);
  message += " must not be infinite";
  throw std::invalid_argument(message);
}
46+
/// Rejects a value that compares equal to zero.
///
/// Only arithmetic types are accepted; any other T fails at compile time.
///
/// @param value the quantity to validate
/// @param name  label for the quantity, used as the prefix of the error message
/// @throws std::invalid_argument with the text "<name> must not be zero" when
///         value == 0
template <typename T>
inline void check_non_zero(T value, const char* name) {
  static_assert(std::is_arithmetic<T>::value, " T must be an arithmetic type");
  // Fast path: any non-zero value is accepted (sign is not inspected here).
  if (!(value == 0)) return;
  std::string message(name);
  message += " must not be zero";
  throw std::invalid_argument(message);
}
54+
3255template <typename T, typename A>
3356tdigest<T, A>::tdigest(uint16_t k, const A& allocator):
3457tdigest (false , k, std::numeric_limits<T>::infinity(), -std::numeric_limits<T>::infinity(), vector_centroid(allocator), 0, vector_t(allocator))
@@ -37,6 +60,7 @@ tdigest(false, k, std::numeric_limits<T>::infinity(), -std::numeric_limits<T>::i
3760template <typename T, typename A>
3861void tdigest<T, A>::update(T value) {
3962 if (std::isnan (value)) return ;
63+ if (std::isinf (value)) return ;
4064 if (buffer_.size () == centroids_capacity_ * BUFFER_MULTIPLIER) compress ();
4165 buffer_.push_back (value);
4266 min_ = std::min (min_, value);
@@ -400,6 +424,8 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
400424 const bool reverse_merge = flags_byte & (1 << flags::REVERSE_MERGE);
401425 if (is_single_value) {
402426 const T value = read<T>(is);
427+ check_not_nan (value, " single_value" );
428+ check_not_infinite (value, " single_value" );
403429 return tdigest (reverse_merge, k, value, value, vector_centroid (1 , centroid (value, 1 ), allocator), 1 , vector_t (allocator));
404430 }
405431
@@ -408,12 +434,26 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
408434
409435 const T min = read<T>(is);
410436 const T max = read<T>(is);
437+ check_not_nan (min, " min" );
438+ check_not_infinite (min, " min" );
439+ check_not_nan (max, " max" );
440+ check_not_infinite (max, " max" );
411441 vector_centroid centroids (num_centroids, centroid (0 , 0 ), allocator);
412442 if (num_centroids > 0 ) read (is, centroids.data (), num_centroids * sizeof (centroid));
413443 vector_t buffer (num_buffered, 0 , allocator);
414444 if (num_buffered > 0 ) read (is, buffer.data (), num_buffered * sizeof (T));
415445 uint64_t weight = 0 ;
416- for (const auto & c: centroids) weight += c.get_weight ();
446+ for (const auto & c: centroids) {
447+ check_not_nan (c.get_mean (), " centroid mean" );
448+ check_not_infinite (c.get_mean (), " centroid mean" );
449+ check_non_zero (c.get_weight (), " centroid weight" );
450+
451+ weight += c.get_weight ();
452+ }
453+ for (const auto & value: buffer) {
454+ check_not_nan (value, " buffered_value" );
455+ check_not_infinite (value, " buffered_value" );
456+ }
417457 return tdigest (reverse_merge, k, min, max, std::move (centroids), weight, std::move (buffer));
418458}
419459
@@ -451,6 +491,8 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
451491 ensure_minimum_memory (end_ptr - ptr, sizeof (T));
452492 T value;
453493 ptr += copy_from_mem (ptr, value);
494+ check_not_nan (value, " single_value" );
495+ check_not_infinite (value, " single_value" );
454496 return tdigest (reverse_merge, k, value, value, vector_centroid (1 , centroid (value, 1 ), allocator), 1 , vector_t (allocator));
455497 }
456498
@@ -465,12 +507,26 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
465507 ptr += copy_from_mem (ptr, min);
466508 T max;
467509 ptr += copy_from_mem (ptr, max);
510+ check_not_nan (min, " min" );
511+ check_not_infinite (min, " min" );
512+ check_not_nan (max, " max" );
513+ check_not_infinite (max, " max" );
468514 vector_centroid centroids (num_centroids, centroid (0 , 0 ), allocator);
469515 if (num_centroids > 0 ) ptr += copy_from_mem (ptr, centroids.data (), num_centroids * sizeof (centroid));
470516 vector_t buffer (num_buffered, 0 , allocator);
471517 if (num_buffered > 0 ) copy_from_mem (ptr, buffer.data (), num_buffered * sizeof (T));
472518 uint64_t weight = 0 ;
473- for (const auto & c: centroids) weight += c.get_weight ();
519+ for (const auto & c: centroids) {
520+ check_not_nan (c.get_mean (), " centroid mean" );
521+ check_not_infinite (c.get_mean (), " centroid mean" );
522+ check_non_zero (c.get_weight (), " centroid weight" );
523+
524+ weight += c.get_weight ();
525+ }
526+ for (const auto & value: buffer) {
527+ check_not_nan (value, " buffered_value" );
528+ check_not_infinite (value, " buffered_value" );
529+ }
474530 return tdigest (reverse_merge, k, min, max, std::move (centroids), weight, std::move (buffer));
475531}
476532
@@ -487,13 +543,24 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
487543 if (type == COMPAT_DOUBLE) { // compatibility with asBytes()
488544 const auto min = read_big_endian<double >(is);
489545 const auto max = read_big_endian<double >(is);
546+ check_not_nan (min, " min" );
547+ check_not_infinite (min, " min" );
548+ check_not_nan (max, " max" );
549+ check_not_infinite (max, " max" );
490550 const auto k = static_cast <uint16_t >(read_big_endian<double >(is));
491551 const auto num_centroids = read_big_endian<uint32_t >(is);
492552 vector_centroid centroids (num_centroids, centroid (0 , 0 ), allocator);
493553 uint64_t total_weight = 0 ;
494554 for (auto & c: centroids) {
495- const W weight = static_cast <W>(read_big_endian<double >(is));
555+ const auto weight_double = read_big_endian<double >(is);
556+ check_not_nan (weight_double, " centroid weight" );
557+ check_not_infinite (weight_double, " centroid weight" );
558+ check_non_zero (weight_double, " centroid weight" );
559+
496560 const auto mean = read_big_endian<double >(is);
561+ check_not_nan (mean, " centroid mean" );
562+ check_not_infinite (mean, " centroid mean" );
563+ const W weight = static_cast <W>(weight_double);
497564 c = centroid (mean, weight);
498565 total_weight += weight;
499566 }
@@ -502,6 +569,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
502569 // COMPAT_FLOAT: compatibility with asSmallBytes()
503570 const auto min = read_big_endian<double >(is); // reference implementation uses doubles for min and max
504571 const auto max = read_big_endian<double >(is);
572+ check_not_nan (min, " min" );
573+ check_not_infinite (min, " min" );
574+ check_not_nan (max, " max" );
575+ check_not_infinite (max, " max" );
505576 const auto k = static_cast <uint16_t >(read_big_endian<float >(is));
506577 // reference implementation stores capacities of the array of centroids and the buffer as shorts
507578 // they can be derived from k in the constructor
@@ -510,8 +581,13 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
510581 vector_centroid centroids (num_centroids, centroid (0 , 0 ), allocator);
511582 uint64_t total_weight = 0 ;
512583 for (auto & c: centroids) {
513- const W weight = static_cast <W>(read_big_endian<float >(is));
584+ const auto weight_float = read_big_endian<float >(is);
585+ check_not_nan (weight_float, " centroid weight" );
586+ check_not_infinite (weight_float, " centroid weight" );
514587 const auto mean = read_big_endian<float >(is);
588+ check_not_nan (mean, " centroid mean" );
589+ check_not_infinite (mean, " centroid mean" );
590+ const W weight = static_cast <W>(weight_float);
515591 c = centroid (mean, weight);
516592 total_weight += weight;
517593 }
@@ -538,6 +614,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
538614 double max;
539615 ptr += copy_from_mem (ptr, max);
540616 max = byteswap (max);
617+ check_not_nan (min, " min" );
618+ check_not_infinite (min, " min" );
619+ check_not_nan (max, " max" );
620+ check_not_infinite (max, " max" );
541621 double k_double;
542622 ptr += copy_from_mem (ptr, k_double);
543623 const uint16_t k = static_cast <uint16_t >(byteswap (k_double));
@@ -554,6 +634,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
554634 double mean;
555635 ptr += copy_from_mem (ptr, mean);
556636 mean = byteswap (mean);
637+ check_not_nan (weight, " centroid weight" );
638+ check_not_infinite (weight, " centroid weight" );
639+ check_not_nan (mean, " centroid mean" );
640+ check_not_infinite (mean, " centroid mean" );
557641 c = centroid (mean, static_cast <W>(weight));
558642 total_weight += static_cast <uint64_t >(weight);
559643 }
@@ -567,6 +651,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
567651 double max;
568652 ptr += copy_from_mem (ptr, max);
569653 max = byteswap (max);
654+ check_not_nan (min, " min" );
655+ check_not_infinite (min, " min" );
656+ check_not_nan (max, " max" );
657+ check_not_infinite (max, " max" );
570658 float k_float;
571659 ptr += copy_from_mem (ptr, k_float);
572660 const uint16_t k = static_cast <uint16_t >(byteswap (k_float));
@@ -586,6 +674,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
586674 float mean;
587675 ptr += copy_from_mem (ptr, mean);
588676 mean = byteswap (mean);
677+ check_not_nan (weight, " centroid weight" );
678+ check_not_infinite (weight, " centroid weight" );
679+ check_not_nan (mean, " centroid mean" );
680+ check_not_infinite (mean, " centroid mean" );
589681 c = centroid (mean, static_cast <W>(weight));
590682 total_weight += static_cast <uint64_t >(weight);
591683 }
0 commit comments