Skip to content

Commit bded7aa

Browse files
committed
fix: check weight is zero
1 parent 662aef3 commit bded7aa

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

tdigest/include/tdigest_impl.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <algorithm>
2424
#include <cmath>
2525
#include <sstream>
26+
#include <type_traits>
2627

2728
#include "common_defs.hpp"
2829
#include "memory_operations.hpp"
@@ -43,6 +44,14 @@ inline void check_not_infinite(T value, const char* name) {
4344
}
4445
}
4546

47+
template<typename T>
48+
inline void check_non_zero(T value, const char* name) {
49+
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
50+
if (value == 0) {
51+
throw std::invalid_argument(std::string(name) + " must not be zero");
52+
}
53+
}
54+
4655
template<typename T, typename A>
4756
tdigest<T, A>::tdigest(uint16_t k, const A& allocator):
4857
tdigest(false, k, std::numeric_limits<T>::infinity(), -std::numeric_limits<T>::infinity(), vector_centroid(allocator), 0, vector_t(allocator))
@@ -437,6 +446,8 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
437446
for (const auto& c: centroids) {
438447
check_not_nan(c.get_mean(), "centroid mean");
439448
check_not_infinite(c.get_mean(), "centroid mean");
449+
check_non_zero(c.get_weight(), "centroid weight");
450+
440451
weight += c.get_weight();
441452
}
442453
for (const auto& value: buffer) {
@@ -508,6 +519,8 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
508519
for (const auto& c: centroids) {
509520
check_not_nan(c.get_mean(), "centroid mean");
510521
check_not_infinite(c.get_mean(), "centroid mean");
522+
check_non_zero(c.get_weight(), "centroid weight");
523+
511524
weight += c.get_weight();
512525
}
513526
for (const auto& value: buffer) {
@@ -542,6 +555,8 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
542555
const auto weight_double = read_big_endian<double>(is);
543556
check_not_nan(weight_double, "centroid weight");
544557
check_not_infinite(weight_double, "centroid weight");
558+
check_non_zero(weight_double, "centroid weight");
559+
545560
const auto mean = read_big_endian<double>(is);
546561
check_not_nan(mean, "centroid mean");
547562
check_not_infinite(mean, "centroid mean");

tdigest/test/tdigest_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ constexpr size_t counts_size = 8;
3333
constexpr size_t min_offset = header_size + counts_size;
3434
constexpr size_t max_offset = min_offset + sizeof(double);
3535
constexpr size_t first_centroid_mean_offset = min_offset + sizeof(double) * 2;
36+
constexpr size_t first_centroid_weight_offset = first_centroid_mean_offset + sizeof(double);
3637
constexpr size_t first_buffered_value_offset = first_centroid_mean_offset;
3738
constexpr size_t single_value_offset = header_size;
3839

@@ -582,4 +583,22 @@ TEST_CASE("deserialize bytes rejects infinity buffered value", "[tdigest]") {
582583
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
583584
}
584585

586+
TEST_CASE("deserialize bytes rejects zero centroid weight", "[tdigest]") {
587+
tdigest_double td(100);
588+
for (int i = 0; i < 10; ++i) td.update(i);
589+
auto bytes = td.serialize();
590+
write_bytes(bytes, first_centroid_weight_offset, static_cast<uint64_t>(0));
591+
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
592+
}
593+
594+
TEST_CASE("deserialize stream rejects zero centroid weight", "[tdigest]") {
595+
tdigest_double td(100);
596+
for (int i = 0; i < 10; ++i) td.update(i);
597+
auto bytes = td.serialize();
598+
std::string data(reinterpret_cast<const char*>(bytes.data()), bytes.size());
599+
write_bytes(data, first_centroid_weight_offset, static_cast<uint64_t>(0));
600+
std::istringstream is(data, std::ios::binary);
601+
REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument);
602+
}
603+
585604
} /* namespace datasketches */

0 commit comments

Comments
 (0)