Skip to content

Commit 7a2f9e3

Browse files
authored
Merge pull request apache#471 from proost/fix-tdigest-inf-params
fix: rejecting inf as value
2 parents 3c07e75 + 1979834 commit 7a2f9e3

File tree

3 files changed

+229
-4
lines changed

3 files changed

+229
-4
lines changed

tdigest/include/tdigest.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class tdigest {
108108

109109
/**
110110
* Update this t-Digest with the given value
111+
* NaN and infinity values are ignored
111112
* @param value to update the t-Digest with
112113
*/
113114
void update(T value);
@@ -153,6 +154,7 @@ class tdigest {
153154
* Compute approximate normalized rank of the given value.
154155
*
155156
* <p>If the sketch is empty this throws std::runtime_error.
157+
* <p>NaN value throw std::invalid_argument.
156158
*
157159
* @param value to be ranked
158160
* @return normalized rank (from 0 to 1 inclusive)

tdigest/include/tdigest_impl.hpp

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,35 @@
2323
#include <algorithm>
2424
#include <cmath>
2525
#include <sstream>
26+
#include <type_traits>
2627

2728
#include "common_defs.hpp"
2829
#include "memory_operations.hpp"
2930

3031
namespace datasketches {
3132

33+
template<typename T>
34+
inline void check_not_nan(T value, const char* name) {
35+
if (std::isnan(value)) {
36+
throw std::invalid_argument(std::string(name) + " must not be NaN");
37+
}
38+
}
39+
40+
template<typename T>
41+
inline void check_not_infinite(T value, const char* name) {
42+
if (std::isinf(value)) {
43+
throw std::invalid_argument(std::string(name) + " must not be infinite");
44+
}
45+
}
46+
47+
template<typename T>
48+
inline void check_non_zero(T value, const char* name) {
49+
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
50+
if (value == 0) {
51+
throw std::invalid_argument(std::string(name) + " must not be zero");
52+
}
53+
}
54+
3255
template<typename T, typename A>
3356
tdigest<T, A>::tdigest(uint16_t k, const A& allocator):
3457
tdigest(false, k, std::numeric_limits<T>::infinity(), -std::numeric_limits<T>::infinity(), vector_centroid(allocator), 0, vector_t(allocator))
@@ -37,6 +60,7 @@ tdigest(false, k, std::numeric_limits<T>::infinity(), -std::numeric_limits<T>::i
3760
template<typename T, typename A>
3861
void tdigest<T, A>::update(T value) {
3962
if (std::isnan(value)) return;
63+
if (std::isinf(value)) return;
4064
if (buffer_.size() == centroids_capacity_ * BUFFER_MULTIPLIER) compress();
4165
buffer_.push_back(value);
4266
min_ = std::min(min_, value);
@@ -400,6 +424,8 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
400424
const bool reverse_merge = flags_byte & (1 << flags::REVERSE_MERGE);
401425
if (is_single_value) {
402426
const T value = read<T>(is);
427+
check_not_nan(value, "single_value");
428+
check_not_infinite(value, "single_value");
403429
return tdigest(reverse_merge, k, value, value, vector_centroid(1, centroid(value, 1), allocator), 1, vector_t(allocator));
404430
}
405431

@@ -408,12 +434,26 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
408434

409435
const T min = read<T>(is);
410436
const T max = read<T>(is);
437+
check_not_nan(min, "min");
438+
check_not_infinite(min, "min");
439+
check_not_nan(max, "max");
440+
check_not_infinite(max, "max");
411441
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
412442
if (num_centroids > 0) read(is, centroids.data(), num_centroids * sizeof(centroid));
413443
vector_t buffer(num_buffered, 0, allocator);
414444
if (num_buffered > 0) read(is, buffer.data(), num_buffered * sizeof(T));
415445
uint64_t weight = 0;
416-
for (const auto& c: centroids) weight += c.get_weight();
446+
for (const auto& c: centroids) {
447+
check_not_nan(c.get_mean(), "centroid mean");
448+
check_not_infinite(c.get_mean(), "centroid mean");
449+
check_non_zero(c.get_weight(), "centroid weight");
450+
451+
weight += c.get_weight();
452+
}
453+
for (const auto& value: buffer) {
454+
check_not_nan(value, "buffered_value");
455+
check_not_infinite(value, "buffered_value");
456+
}
417457
return tdigest(reverse_merge, k, min, max, std::move(centroids), weight, std::move(buffer));
418458
}
419459

@@ -451,6 +491,8 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
451491
ensure_minimum_memory(end_ptr - ptr, sizeof(T));
452492
T value;
453493
ptr += copy_from_mem(ptr, value);
494+
check_not_nan(value, "single_value");
495+
check_not_infinite(value, "single_value");
454496
return tdigest(reverse_merge, k, value, value, vector_centroid(1, centroid(value, 1), allocator), 1, vector_t(allocator));
455497
}
456498

@@ -465,12 +507,26 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
465507
ptr += copy_from_mem(ptr, min);
466508
T max;
467509
ptr += copy_from_mem(ptr, max);
510+
check_not_nan(min, "min");
511+
check_not_infinite(min, "min");
512+
check_not_nan(max, "max");
513+
check_not_infinite(max, "max");
468514
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
469515
if (num_centroids > 0) ptr += copy_from_mem(ptr, centroids.data(), num_centroids * sizeof(centroid));
470516
vector_t buffer(num_buffered, 0, allocator);
471517
if (num_buffered > 0) copy_from_mem(ptr, buffer.data(), num_buffered * sizeof(T));
472518
uint64_t weight = 0;
473-
for (const auto& c: centroids) weight += c.get_weight();
519+
for (const auto& c: centroids) {
520+
check_not_nan(c.get_mean(), "centroid mean");
521+
check_not_infinite(c.get_mean(), "centroid mean");
522+
check_non_zero(c.get_weight(), "centroid weight");
523+
524+
weight += c.get_weight();
525+
}
526+
for (const auto& value: buffer) {
527+
check_not_nan(value, "buffered_value");
528+
check_not_infinite(value, "buffered_value");
529+
}
474530
return tdigest(reverse_merge, k, min, max, std::move(centroids), weight, std::move(buffer));
475531
}
476532

@@ -487,13 +543,24 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
487543
if (type == COMPAT_DOUBLE) { // compatibility with asBytes()
488544
const auto min = read_big_endian<double>(is);
489545
const auto max = read_big_endian<double>(is);
546+
check_not_nan(min, "min");
547+
check_not_infinite(min, "min");
548+
check_not_nan(max, "max");
549+
check_not_infinite(max, "max");
490550
const auto k = static_cast<uint16_t>(read_big_endian<double>(is));
491551
const auto num_centroids = read_big_endian<uint32_t>(is);
492552
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
493553
uint64_t total_weight = 0;
494554
for (auto& c: centroids) {
495-
const W weight = static_cast<W>(read_big_endian<double>(is));
555+
const auto weight_double = read_big_endian<double>(is);
556+
check_not_nan(weight_double, "centroid weight");
557+
check_not_infinite(weight_double, "centroid weight");
558+
check_non_zero(weight_double, "centroid weight");
559+
496560
const auto mean = read_big_endian<double>(is);
561+
check_not_nan(mean, "centroid mean");
562+
check_not_infinite(mean, "centroid mean");
563+
const W weight = static_cast<W>(weight_double);
497564
c = centroid(mean, weight);
498565
total_weight += weight;
499566
}
@@ -502,6 +569,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
502569
// COMPAT_FLOAT: compatibility with asSmallBytes()
503570
const auto min = read_big_endian<double>(is); // reference implementation uses doubles for min and max
504571
const auto max = read_big_endian<double>(is);
572+
check_not_nan(min, "min");
573+
check_not_infinite(min, "min");
574+
check_not_nan(max, "max");
575+
check_not_infinite(max, "max");
505576
const auto k = static_cast<uint16_t>(read_big_endian<float>(is));
506577
// reference implementation stores capacities of the array of centroids and the buffer as shorts
507578
// they can be derived from k in the constructor
@@ -510,8 +581,13 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
510581
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
511582
uint64_t total_weight = 0;
512583
for (auto& c: centroids) {
513-
const W weight = static_cast<W>(read_big_endian<float>(is));
584+
const auto weight_float = read_big_endian<float>(is);
585+
check_not_nan(weight_float, "centroid weight");
586+
check_not_infinite(weight_float, "centroid weight");
514587
const auto mean = read_big_endian<float>(is);
588+
check_not_nan(mean, "centroid mean");
589+
check_not_infinite(mean, "centroid mean");
590+
const W weight = static_cast<W>(weight_float);
515591
c = centroid(mean, weight);
516592
total_weight += weight;
517593
}
@@ -538,6 +614,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
538614
double max;
539615
ptr += copy_from_mem(ptr, max);
540616
max = byteswap(max);
617+
check_not_nan(min, "min");
618+
check_not_infinite(min, "min");
619+
check_not_nan(max, "max");
620+
check_not_infinite(max, "max");
541621
double k_double;
542622
ptr += copy_from_mem(ptr, k_double);
543623
const uint16_t k = static_cast<uint16_t>(byteswap(k_double));
@@ -554,6 +634,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
554634
double mean;
555635
ptr += copy_from_mem(ptr, mean);
556636
mean = byteswap(mean);
637+
check_not_nan(weight, "centroid weight");
638+
check_not_infinite(weight, "centroid weight");
639+
check_not_nan(mean, "centroid mean");
640+
check_not_infinite(mean, "centroid mean");
557641
c = centroid(mean, static_cast<W>(weight));
558642
total_weight += static_cast<uint64_t>(weight);
559643
}
@@ -567,6 +651,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
567651
double max;
568652
ptr += copy_from_mem(ptr, max);
569653
max = byteswap(max);
654+
check_not_nan(min, "min");
655+
check_not_infinite(min, "min");
656+
check_not_nan(max, "max");
657+
check_not_infinite(max, "max");
570658
float k_float;
571659
ptr += copy_from_mem(ptr, k_float);
572660
const uint16_t k = static_cast<uint16_t>(byteswap(k_float));
@@ -586,6 +674,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
586674
float mean;
587675
ptr += copy_from_mem(ptr, mean);
588676
mean = byteswap(mean);
677+
check_not_nan(weight, "centroid weight");
678+
check_not_infinite(weight, "centroid weight");
679+
check_not_nan(mean, "centroid mean");
680+
check_not_infinite(mean, "centroid mean");
589681
c = centroid(mean, static_cast<W>(weight));
590682
total_weight += static_cast<uint64_t>(weight);
591683
}

tdigest/test/tdigest_test.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,36 @@
1818
*/
1919

2020
#include <catch2/catch.hpp>
21+
#include <cstring>
2122
#include <iostream>
2223
#include <fstream>
24+
#include <sstream>
2325

2426
#include "tdigest.hpp"
2527

2628
namespace datasketches {
2729

30+
namespace {
31+
constexpr size_t header_size = 8;
32+
constexpr size_t counts_size = 8;
33+
constexpr size_t min_offset = header_size + counts_size;
34+
constexpr size_t max_offset = min_offset + sizeof(double);
35+
constexpr size_t first_centroid_mean_offset = min_offset + sizeof(double) * 2;
36+
constexpr size_t first_centroid_weight_offset = first_centroid_mean_offset + sizeof(double);
37+
constexpr size_t first_buffered_value_offset = first_centroid_mean_offset;
38+
constexpr size_t single_value_offset = header_size;
39+
40+
template <typename T>
41+
void write_bytes(std::vector<uint8_t>& bytes, size_t offset, T value) {
42+
std::memcpy(bytes.data() + offset, &value, sizeof(T));
43+
}
44+
45+
template <typename T>
46+
void write_bytes(std::string& data, size_t offset, T value) {
47+
std::memcpy(&data[offset], &value, sizeof(T));
48+
}
49+
} // namespace
50+
2851
TEST_CASE("empty", "[tdigest]") {
2952
tdigest_double td(10);
3053
// std::cout << td.to_string();
@@ -470,4 +493,112 @@ TEST_CASE("iterate centroids", "[tdigest]") {
470493
REQUIRE(td.get_total_weight() == total_weight);
471494
}
472495

496+
TEST_CASE("update rejects positive infinity", "[tdigest]") {
497+
tdigest_double td(100);
498+
td.update(1.0);
499+
td.update(2.0);
500+
td.update(std::numeric_limits<double>::infinity());
501+
REQUIRE(td.get_total_weight() == 2);
502+
REQUIRE(td.get_max_value() == 2.0);
503+
}
504+
505+
TEST_CASE("update rejects negative infinity", "[tdigest]") {
506+
tdigest_double td(100);
507+
td.update(1.0);
508+
td.update(2.0);
509+
td.update(-std::numeric_limits<double>::infinity());
510+
REQUIRE(td.get_total_weight() == 2);
511+
REQUIRE(td.get_min_value() == 1.0);
512+
}
513+
514+
TEST_CASE("deserialize bytes rejects NaN single value", "[tdigest]") {
515+
tdigest_double td(100);
516+
td.update(1.0);
517+
auto bytes = td.serialize();
518+
write_bytes(bytes, single_value_offset, std::numeric_limits<double>::quiet_NaN());
519+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
520+
}
521+
522+
TEST_CASE("deserialize stream rejects infinity min", "[tdigest]") {
523+
tdigest_double td(100);
524+
td.update(1.0);
525+
td.update(2.0);
526+
td.update(3.0);
527+
auto bytes = td.serialize();
528+
std::string data(reinterpret_cast<const char*>(bytes.data()), bytes.size());
529+
write_bytes(data, min_offset, std::numeric_limits<double>::infinity());
530+
std::istringstream is(data, std::ios::binary);
531+
REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument);
532+
}
533+
534+
TEST_CASE("deserialize bytes rejects NaN centroid mean", "[tdigest]") {
535+
tdigest_double td(100);
536+
for (int i = 0; i < 10; ++i) td.update(i);
537+
auto bytes = td.serialize();
538+
write_bytes(bytes, first_centroid_mean_offset, std::numeric_limits<double>::quiet_NaN());
539+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
540+
}
541+
542+
TEST_CASE("deserialize bytes rejects NaN buffered value", "[tdigest]") {
543+
tdigest_double td(100);
544+
td.update(1.0);
545+
td.update(2.0);
546+
auto bytes = td.serialize(0, true);
547+
write_bytes(bytes, first_buffered_value_offset, std::numeric_limits<double>::quiet_NaN());
548+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
549+
}
550+
551+
TEST_CASE("deserialize bytes rejects infinity single value", "[tdigest]") {
552+
tdigest_double td(100);
553+
td.update(1.0);
554+
auto bytes = td.serialize();
555+
write_bytes(bytes, single_value_offset, std::numeric_limits<double>::infinity());
556+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
557+
}
558+
559+
TEST_CASE("deserialize bytes rejects NaN max", "[tdigest]") {
560+
tdigest_double td(100);
561+
td.update(1.0);
562+
td.update(2.0);
563+
auto bytes = td.serialize();
564+
write_bytes(bytes, max_offset, std::numeric_limits<double>::quiet_NaN());
565+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
566+
}
567+
568+
TEST_CASE("deserialize bytes rejects infinity max", "[tdigest]") {
569+
tdigest_double td(100);
570+
td.update(1.0);
571+
td.update(2.0);
572+
auto bytes = td.serialize();
573+
write_bytes(bytes, max_offset, std::numeric_limits<double>::infinity());
574+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
575+
}
576+
577+
TEST_CASE("deserialize bytes rejects infinity buffered value", "[tdigest]") {
578+
tdigest_double td(100);
579+
td.update(1.0);
580+
td.update(2.0);
581+
auto bytes = td.serialize(0, true);
582+
write_bytes(bytes, first_buffered_value_offset, std::numeric_limits<double>::infinity());
583+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
584+
}
585+
586+
TEST_CASE("deserialize bytes rejects zero centroid weight", "[tdigest]") {
587+
tdigest_double td(100);
588+
for (int i = 0; i < 10; ++i) td.update(i);
589+
auto bytes = td.serialize();
590+
write_bytes(bytes, first_centroid_weight_offset, static_cast<uint64_t>(0));
591+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
592+
}
593+
594+
TEST_CASE("deserialize stream rejects zero centroid weight", "[tdigest]") {
595+
tdigest_double td(100);
596+
for (int i = 0; i < 10; ++i) td.update(i);
597+
auto bytes = td.serialize();
598+
std::string data(reinterpret_cast<const char*>(bytes.data()), bytes.size());
599+
write_bytes(data, first_centroid_weight_offset, static_cast<uint64_t>(0));
600+
std::istringstream is(data, std::ios::binary);
601+
REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument);
602+
}
603+
473604
} /* namespace datasketches */

0 commit comments

Comments
 (0)