Skip to content

Commit 77118ca

Browse files
author
Rafał Hibner
committed
Remove redundant count from tdigest
1 parent 5f17c46 commit 77118ca

File tree

4 files changed

+57
-72
lines changed

4 files changed

+57
-72
lines changed

cpp/src/arrow/compute/kernels/aggregate_tdigest.cc

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using arrow::internal::VisitSetBitRunsVoid;
3636

3737
struct TDigestBaseImpl : public ScalarAggregator {
3838
explicit TDigestBaseImpl(std::shared_ptr<TDigest::Scaler> scaler, uint32_t buffer_size)
39-
: tdigest{std::move(scaler), buffer_size}, count{0}, all_valid{true} {}
39+
: tdigest{std::move(scaler), buffer_size}, all_valid{true} {}
4040

4141
Status MergeFrom(KernelContext*, KernelState&& src) override {
4242
const auto& other = checked_cast<const TDigestBaseImpl&>(src);
@@ -45,7 +45,6 @@ struct TDigestBaseImpl : public ScalarAggregator {
4545
return Status::OK();
4646
}
4747
this->tdigest.Merge(other.tdigest);
48-
this->count += other.count;
4948
return Status::OK();
5049
}
5150

@@ -61,20 +60,16 @@ struct TDigestBaseImpl : public ScalarAggregator {
6160
}
6261

6362
TDigest tdigest;
64-
uint64_t count;
6563
bool all_valid;
6664
static const std::shared_ptr<DataType>& out_type() {
67-
static auto out_type = struct_({
68-
field("centroids",
69-
list(field("item",
70-
struct_({field("mean", float64(), false),
71-
field("weight", float64(), false)}),
72-
false)),
73-
false),
74-
field("min", float64(), true),
75-
field("max", float64(), true),
76-
field("count", uint64(), false),
77-
});
65+
static auto out_type =
66+
struct_({field("centroids",
67+
list(field("item",
68+
struct_({field("mean", float64(), false),
69+
field("weight", float64(), false)}),
70+
false)),
71+
false),
72+
field("min", float64(), true), field("max", float64(), true)});
7873
return out_type;
7974
}
8075
};
@@ -88,14 +83,12 @@ struct TDigestQuantileFinalizer : public TDigestBaseImpl {
8883
min_count(min_count) {}
8984

9085
bool isNull() {
91-
return this->tdigest.is_empty() || !this->all_valid || this->count < min_count;
86+
return this->tdigest.is_empty() || !this->all_valid ||
87+
this->tdigest.TotalWeight() < (double)min_count;
9288
}
9389

9490
double Quantile(size_t i) { return this->tdigest.Quantile(this->q[i]); }
95-
void Reset() {
96-
this->tdigest.Reset();
97-
this->count = 0;
98-
}
91+
void Reset() { this->tdigest.Reset(); }
9992

10093
Status Finalize(KernelContext* ctx, Datum* out) override {
10194
const size_t out_length = q.size();
@@ -160,17 +153,15 @@ struct TDigestCentroidFinalizer : public TDigestBaseImpl {
160153
struct_({field("mean", float64(), false),
161154
field("weight", float64(), false)}),
162155
false)));
163-
auto count = std::make_shared<UInt64Scalar>(this->count);
164156
std::shared_ptr<Scalar> min, max;
165-
if (this->count) {
157+
if (!this->tdigest.is_empty()) {
166158
min = std::make_shared<DoubleScalar>(this->tdigest.Min());
167159
max = std::make_shared<DoubleScalar>(this->tdigest.Max());
168160
} else {
169161
min = max = MakeNullScalar(float64());
170162
}
171163
*out = std::make_shared<StructScalar>(
172-
std::vector<std::shared_ptr<Scalar>>{centroids_scalar, min, max, count},
173-
out_type());
164+
std::vector<std::shared_ptr<Scalar>>{centroids_scalar, min, max}, out_type());
174165
}
175166

176167
return Status::OK();
@@ -213,7 +204,6 @@ struct TDigestInputConsumerImpl : public TDigestFinalizer_T {
213204
const CType* values = data.GetValues<CType>(1);
214205

215206
if (data.length > data.GetNullCount()) {
216-
this->count += data.length - data.GetNullCount();
217207
VisitSetBitRunsVoid(data.buffers[0].data, data.offset, data.length,
218208
[&](int64_t pos, int64_t len) {
219209
for (int64_t i = 0; i < len; ++i) {
@@ -224,7 +214,6 @@ struct TDigestInputConsumerImpl : public TDigestFinalizer_T {
224214
} else {
225215
const CType value = UnboxScalar<ArrowType>::Unbox(*batch[0].scalar);
226216
if (batch[0].scalar->is_valid) {
227-
this->count += 1;
228217
for (int64_t i = 0; i < batch.length; i++) {
229218
this->tdigest.NanAdd(ToDouble(value));
230219
}
@@ -253,15 +242,12 @@ struct TDigestCentroidConsumerImpl : public TDigestFinalizer_T {
253242
checked_cast<const ListScalar*>(input_struct_scalar->value[1].get())->value;
254243
auto min = checked_cast<const DoubleScalar*>(input_struct_scalar->value[1].get());
255244
auto max = checked_cast<const DoubleScalar*>(input_struct_scalar->value[2].get());
256-
auto count = checked_cast<const UInt64Scalar*>(input_struct_scalar->value[3].get());
257245
auto mean_double_array = checked_cast<const DoubleArray*>(mean_array.get());
258246
auto weight_double_array = checked_cast<const DoubleArray*>(weight_array.get());
259247
DCHECK_EQ(mean_double_array->length(), weight_double_array->length());
260-
auto count_uint64 = count->value;
261-
if (count_uint64) {
248+
if (mean_double_array->length() > 0) {
262249
DCHECK(min->is_valid);
263250
DCHECK(max->is_valid);
264-
this->count += count_uint64;
265251
this->tdigest.SetMinMax(min->value, max->value);
266252
for (int64_t i = 0; i < mean_double_array->length(); i++) {
267253
this->tdigest.NanAdd(mean_double_array->Value(i), weight_double_array->Value(i));

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4566,8 +4566,7 @@ TEST(TestTDigestMapKernel, Options) {
45664566
field("weight", float64(), false)}),
45674567
false)),
45684568
false),
4569-
field("min", float64(), true), field("max", float64(), true),
4570-
field("count", uint64(), false)});
4569+
field("min", float64(), true), field("max", float64(), true)});
45714570
TDigestMapOptions keep_nulls(/*delta=*/5, /*buffer_size=*/500,
45724571
/*skip_nulls=*/false,
45734572
/*scaler=*/TDigestMapOptions::Scaler::K0);
@@ -4580,32 +4579,32 @@ TEST(TestTDigestMapKernel, Options) {
45804579
ResultWith(ScalarFromJSON(
45814580
output_type,
45824581
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4583-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}")));
4582+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}")));
45844583
EXPECT_THAT(
45854584
TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, 4.0, 5.0]"), keep_nulls),
45864585
ResultWith(ScalarFromJSON(
45874586
output_type,
45884587
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4589-
"{\"mean\":5.0,\"weight\":1}],\"min\":1.0,\"max\":5.0,\"count\":5}")));
4588+
"{\"mean\":5.0,\"weight\":1}],\"min\":1.0,\"max\":5.0}")));
45904589
EXPECT_THAT(TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, 4.0]"), keep_nulls),
45914590
ResultWith(ScalarFromJSON(
45924591
output_type,
45934592
"{\"centroids\":[{\"mean\":1.0,\"weight\":1}, "
45944593
"{\"mean\":2.0,\"weight\":1}, {\"mean\":3.0,\"weight\":1}, "
4595-
"{\"mean\":4.0,\"weight\":1}],\"min\":1.0,\"max\":4.0,\"count\":4}")));
4594+
"{\"mean\":4.0,\"weight\":1}],\"min\":1.0,\"max\":4.0}")));
45964595

45974596
EXPECT_THAT(
45984597
TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0]"), keep_nulls),
45994598
ResultWith(ScalarFromJSON(
46004599
output_type,
46014600
"{\"centroids\":[{\"mean\":1.0,\"weight\":1}, {\"mean\":2.0,\"weight\":1}, "
4602-
"{\"mean\":3.0,\"weight\":1}],\"min\":1.0,\"max\":3.0,\"count\":3}")));
4601+
"{\"mean\":3.0,\"weight\":1}],\"min\":1.0,\"max\":3.0}")));
46034602
EXPECT_THAT(TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, 3.0, null]"), keep_nulls),
46044603
ResultWith(ScalarFromJSON(output_type, "null")));
46054604
EXPECT_THAT(TDigestMap(ScalarFromJSON(input_type, "1.0"), keep_nulls),
46064605
ResultWith(ScalarFromJSON(output_type,
46074606
"{\"centroids\":[{\"mean\":1.0,\"weight\":1}],"
4608-
"\"min\":1.0,\"max\":1.0,\"count\":1}")));
4607+
"\"min\":1.0,\"max\":1.0}")));
46094608
EXPECT_THAT(TDigestMap(ScalarFromJSON(input_type, "null"), keep_nulls),
46104609
ResultWith(ScalarFromJSON(output_type, "null")));
46114610

@@ -4614,20 +4613,19 @@ TEST(TestTDigestMapKernel, Options) {
46144613
ResultWith(ScalarFromJSON(
46154614
output_type,
46164615
"{\"centroids\":[{\"mean\":1.0,\"weight\":1}, {\"mean\":2.0,\"weight\":1}, "
4617-
"{\"mean\":3.0,\"weight\":1}],\"min\":1.0,\"max\":3.0,\"count\":3}")));
4616+
"{\"mean\":3.0,\"weight\":1}],\"min\":1.0,\"max\":3.0}")));
46184617
EXPECT_THAT(TDigestMap(ArrayFromJSON(input_type, "[1.0, 2.0, null]"), skip_nulls),
46194618
ResultWith(ScalarFromJSON(
46204619
output_type,
46214620
"{\"centroids\":[{\"mean\":1.0,\"weight\":1}, "
4622-
"{\"mean\":2.0,\"weight\":1}],\"min\":1.0,\"max\":2.0,\"count\":2}")));
4621+
"{\"mean\":2.0,\"weight\":1}],\"min\":1.0,\"max\":2.0}")));
46234622
EXPECT_THAT(TDigestMap(ScalarFromJSON(input_type, "1.0"), skip_nulls),
46244623
ResultWith(ScalarFromJSON(output_type,
46254624
"{\"centroids\":[{\"mean\":1.0,\"weight\":1}],"
4626-
"\"min\":1.0,\"max\":1.0,\"count\":1}")));
4627-
EXPECT_THAT(
4628-
TDigestMap(ScalarFromJSON(input_type, "null"), skip_nulls),
4629-
ResultWith(ScalarFromJSON(
4630-
output_type, "{\"centroids\":[],\"min\":null,\"max\":null,\"count\":0}")));
4625+
"\"min\":1.0,\"max\":1.0}")));
4626+
EXPECT_THAT(TDigestMap(ScalarFromJSON(input_type, "null"), skip_nulls),
4627+
ResultWith(ScalarFromJSON(output_type,
4628+
"{\"centroids\":[],\"min\":null,\"max\":null}")));
46314629
}
46324630

46334631
TEST(TestTDigestReduceKernel, Basic) {
@@ -4637,66 +4635,65 @@ TEST(TestTDigestReduceKernel, Basic) {
46374635
field("weight", float64(), false)}),
46384636
false)),
46394637
false),
4640-
field("min", float64(), true), field("max", float64(), true),
4641-
field("count", uint64(), false)});
4638+
field("min", float64(), true), field("max", float64(), true)});
46424639
TDigestReduceOptions options(/*delta=*/5, /*scaler=*/TDigestMapOptions::Scaler::K0);
46434640
EXPECT_THAT(
46444641
TDigestReduce(
46454642
ArrayFromJSON(
46464643
type,
46474644
"["
46484645
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4649-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6},"
4646+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0},"
46504647
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4651-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}"
4648+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}"
46524649
"]"),
46534650
options),
46544651
ResultWith(ScalarFromJSON(
46554652
type,
46564653
"{\"centroids\":[{\"mean\":1.5,\"weight\":4}, {\"mean\":3.5,\"weight\":4}, "
4657-
"{\"mean\":5.5,\"weight\":4}],\"min\":1.0,\"max\":6.0,\"count\":12}")));
4654+
"{\"mean\":5.5,\"weight\":4}],\"min\":1.0,\"max\":6.0}")));
46584655

46594656
EXPECT_THAT(
46604657
TDigestReduce(
46614658
ScalarFromJSON(
46624659
type,
46634660
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4664-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}"),
4661+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}"),
46654662
options),
46664663
ResultWith(ScalarFromJSON(
46674664
type,
46684665
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4669-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}")));
4666+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}")));
46704667

46714668
EXPECT_THAT(
46724669
TDigestReduce(
46734670
ArrayFromJSON(
46744671
type,
46754672
"["
4676-
"{\"centroids\":[],\"min\":null,\"max\":null,\"count\":0},"
4673+
"{\"centroids\":[],\"min\":null,\"max\":null},"
46774674
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4678-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}"
4675+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}"
46794676
"]"),
46804677
options),
46814678
ResultWith(ScalarFromJSON(
46824679
type,
46834680
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4684-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}")));
4681+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}")));
46854682

46864683
EXPECT_THAT(
46874684
TDigestReduce(
46884685
ArrayFromJSON(
46894686
type,
46904687
"["
46914688
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4692-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6},"
4693-
"{\"centroids\":[],\"min\":null,\"max\":null,\"count\":0}"
4689+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0},"
4690+
"{\"centroids\":[],\"min\":null,\"max\":null}"
46944691
"]"),
46954692
options),
46964693
ResultWith(ScalarFromJSON(
46974694
type,
46984695
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4699-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}")));
4696+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}")));
47004697
}
47014698

47024699
TEST(TestTDigestQuantileKernel, Basic) {
@@ -4707,18 +4704,17 @@ TEST(TestTDigestQuantileKernel, Basic) {
47074704
field("weight", float64(), false)}),
47084705
false)),
47094706
false),
4710-
field("min", float64(), true), field("max", float64(), true),
4711-
field("count", uint64(), false)});
4707+
field("min", float64(), true), field("max", float64(), true)});
47124708

47134709
auto output_type = float64();
47144710

47154711
auto input_array = ArrayFromJSON(
47164712
input_type,
47174713
"["
47184714
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4719-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6},"
4715+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0},"
47204716
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4721-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}"
4717+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}"
47224718
"]");
47234719

47244720
TDigestQuantileOptions multiple(/*q=*/{0.1, 0.5, 0.9}, /*delta=*/5, /*min_count=*/12);
@@ -4738,15 +4734,14 @@ TEST(TestTDigestQuantileKernel, Scalar) {
47384734
field("weight", float64(), false)}),
47394735
false)),
47404736
false),
4741-
field("min", float64(), true), field("max", float64(), true),
4742-
field("count", uint64(), false)});
4737+
field("min", float64(), true), field("max", float64(), true)});
47434738

47444739
auto output_type = float64();
47454740

47464741
auto input_scalar = ScalarFromJSON(
47474742
input_type,
47484743
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4749-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}");
4744+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}");
47504745

47514746
TDigestQuantileOptions multiple(/*q=*/{0.1, 0.5, 0.9}, /*delta=*/5, /*min_count=*/6);
47524747
TDigestQuantileOptions min_count(/*q=*/0.5, /*delta=*/5, /*min_count=*/7);
@@ -4765,19 +4760,18 @@ TEST(TestTDigestQuantileKernel, ElementWise) {
47654760
field("weight", float64(), false)}),
47664761
false)),
47674762
false),
4768-
field("min", float64(), true), field("max", float64(), true),
4769-
field("count", uint64(), false)});
4763+
field("min", float64(), true), field("max", float64(), true)});
47704764

47714765
auto output_type_multiple = fixed_size_list(float64(), 3);
47724766
auto output_type_single = fixed_size_list(float64(), 1);
47734767

47744768
auto input_array = ArrayFromJSON(
47754769
input_type,
47764770
"["
4771+
"{\"centroids\":[{\"mean\":1.5,\"weight\":3}, {\"mean\":3.5,\"weight\":3}, "
4772+
"{\"mean\":5.5,\"weight\":3}],\"min\":1.0,\"max\":6.0},"
47774773
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4778-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":7},"
4779-
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4780-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}"
4774+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}"
47814775
"]");
47824776

47834777
TDigestQuantileOptions multiple(/*q=*/{0.1, 0.5, 0.9}, /*delta=*/5, /*min_count=*/5);
@@ -4805,18 +4799,17 @@ TEST(TestTDigestMapReduceQuantileKernel, Basic) {
48054799
field("weight", float64(), false)}),
48064800
false)),
48074801
false),
4808-
field("min", float64(), true), field("max", float64(), true),
4809-
field("count", uint64(), false)});
4802+
field("min", float64(), true), field("max", float64(), true)});
48104803

48114804
auto output_type = float64();
48124805

48134806
auto input_array = ArrayFromJSON(
48144807
input_type,
48154808
"["
48164809
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4817-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6},"
4810+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0},"
48184811
"{\"centroids\":[{\"mean\":1.5,\"weight\":2}, {\"mean\":3.5,\"weight\":2}, "
4819-
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0,\"count\":6}"
4812+
"{\"mean\":5.5,\"weight\":2}],\"min\":1.0,\"max\":6.0}"
48204813
"]");
48214814

48224815
TDigestQuantileOptions multiple(/*q=*/{0.1, 0.5, 0.9}, /*delta=*/5, /*min_count=*/12);

cpp/src/arrow/util/tdigest.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,11 @@ double TDigest::Mean() const {
430430
return impl_->Mean();
431431
}
432432

433+
double TDigest::TotalWeight() const {
434+
MergeInput();
435+
return impl_->total_weight();
436+
}
437+
433438
bool TDigest::is_empty() const {
434439
return input_.size() == 0 && impl_->total_weight() == 0;
435440
}

cpp/src/arrow/util/tdigest_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class ARROW_EXPORT TDigest {
112112
double Min() const { return Quantile(0); }
113113
double Max() const { return Quantile(1); }
114114
double Mean() const;
115+
double TotalWeight() const;
115116

116117
// check if this tdigest contains no valid data points
117118
bool is_empty() const;

0 commit comments

Comments
 (0)