Skip to content

Commit 046147f

Browse files
author
Rafał Hibner
committed
Instantiate tdigest per exec
1 parent 982cc07 commit 046147f

File tree

3 files changed

+26
-20
lines changed

3 files changed

+26
-20
lines changed

cpp/src/arrow/compute/kernels/aggregate_tdigest.cc

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ using arrow::internal::TDigestScalerK1;
3535
using arrow::internal::VisitSetBitRunsVoid;
3636

3737
struct TDigestBaseImpl : public ScalarAggregator {
38-
explicit TDigestBaseImpl(std::unique_ptr<TDigest::Scaler> scaler, uint32_t buffer_size)
38+
explicit TDigestBaseImpl(std::shared_ptr<TDigest::Scaler> scaler, uint32_t buffer_size)
3939
: tdigest{std::move(scaler), buffer_size}, count{0}, all_valid{true} {
4040
out_type = struct_({
4141
field("mean", list(field("item", float64(), false)), false),
@@ -57,7 +57,7 @@ struct TDigestBaseImpl : public ScalarAggregator {
5757
return Status::OK();
5858
}
5959

60-
static Result<std::unique_ptr<TDigest::Scaler>> MakeScaler(
60+
static Result<std::shared_ptr<TDigest::Scaler>> MakeScaler(
6161
TDigestOptions::Scaler scaler, uint32_t delta) {
6262
switch (scaler) {
6363
case TDigestOptions::K0:
@@ -288,7 +288,7 @@ struct TDigestImpl
288288
// using TDigestBaseImpl::tdigest;
289289

290290
explicit TDigestImpl(const TDigestOptions& options, const DataType& in_type,
291-
std::unique_ptr<TDigest::Scaler> scaler)
291+
std::shared_ptr<TDigest::Scaler> scaler)
292292
: TDigestInputConsumerImpl<ArrowType, TDigestQuantileFinalizer>(
293293
// TDigestInputConsumerImpl
294294
options.skip_nulls, in_type,
@@ -302,7 +302,7 @@ template <typename ArrowType>
302302
struct TDigestMapImpl
303303
: public TDigestInputConsumerImpl<ArrowType, TDigestCentroidFinalizer> {
304304
explicit TDigestMapImpl(const TDigestMapOptions& options, const DataType& in_type,
305-
std::unique_ptr<TDigest::Scaler> scaler)
305+
std::shared_ptr<TDigest::Scaler> scaler)
306306
: TDigestInputConsumerImpl<ArrowType, TDigestCentroidFinalizer>(
307307

308308
// TDigestInputConsumerImpl
@@ -314,7 +314,7 @@ struct TDigestMapImpl
314314

315315
struct TDigestReduceImpl : public TDigestCentroidConsumerImpl<TDigestCentroidFinalizer> {
316316
explicit TDigestReduceImpl(const TDigestReduceOptions& options,
317-
std::unique_ptr<TDigest::Scaler> scaler)
317+
std::shared_ptr<TDigest::Scaler> scaler)
318318
: TDigestCentroidConsumerImpl<TDigestCentroidFinalizer>(
319319
// TDigestCentroidConsumerImpl
320320
// TDigestCentroidFinalizer
@@ -325,7 +325,7 @@ struct TDigestReduceImpl : public TDigestCentroidConsumerImpl<TDigestCentroidFin
325325
struct TDigestQuantileImpl
326326
: public TDigestCentroidConsumerImpl<TDigestQuantileFinalizer> {
327327
explicit TDigestQuantileImpl(const TDigestQuantileOptions& options,
328-
std::unique_ptr<TDigest::Scaler> scaler)
328+
std::shared_ptr<TDigest::Scaler> scaler)
329329
: TDigestCentroidConsumerImpl<TDigestQuantileFinalizer>(
330330

331331
// TDigestCentroidConsumerImpl
@@ -335,10 +335,10 @@ struct TDigestQuantileImpl
335335
std::move(scaler), options.delta) {}
336336
};
337337

338-
struct TDigestQuantileScalarImpl : public TDigestQuantileImpl {
338+
struct TDigestQuantileScalarImpl : public KernelState {
339339
explicit TDigestQuantileScalarImpl(const TDigestQuantileOptions& options,
340-
std::unique_ptr<TDigest::Scaler> scaler)
341-
: TDigestQuantileImpl(options, std::move(scaler)) {}
340+
std::shared_ptr<TDigest::Scaler> scaler)
341+
: options(options), scaler(std::move(scaler)) {}
342342

343343
static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
344344
const KernelInitArgs& args) {
@@ -354,33 +354,35 @@ struct TDigestQuantileScalarImpl : public TDigestQuantileImpl {
354354
return state->OutputType();
355355
}
356356

357-
size_t OutputSize() const { return this->q.size(); }
357+
size_t OutputSize() const { return options.q.size(); }
358358

359359
TypeHolder OutputType() const {
360360
return fixed_size_list(field("item", float64()), OutputSize());
361361
}
362362

363363
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
364364
auto state = checked_cast<TDigestQuantileScalarImpl*>(ctx->state());
365+
TDigestQuantileImpl tdigest(state->options, state->scaler);
365366
auto value_builder = std::make_shared<DoubleBuilder>(ctx->memory_pool());
366367
const auto output_size = state->OutputSize();
367368
FixedSizeListBuilder fsl_builder(
368369
ctx->memory_pool(), checked_pointer_cast<arrow::ArrayBuilder>(value_builder),
369370
output_size);
370371

371372
std::shared_ptr<Array> array = MakeArray(batch[0].array.ToArrayData());
373+
372374
for (int i = 0; i < array->length(); ++i) {
373375
if (array->IsValid(i)) {
374376
ARROW_RETURN_NOT_OK(fsl_builder.Append());
375377
ARROW_ASSIGN_OR_RAISE(auto scalar, array->GetScalar(i));
376-
state->Reset();
377-
ARROW_RETURN_NOT_OK(state->Consume(scalar.get()));
378+
tdigest.Reset();
379+
ARROW_RETURN_NOT_OK(tdigest.Consume(scalar.get()));
378380

379-
if (state->isNull()) {
381+
if (tdigest.isNull()) {
380382
ARROW_RETURN_NOT_OK(value_builder->AppendNulls(output_size));
381383
} else {
382384
for (size_t i = 0; i < output_size; ++i) {
383-
ARROW_RETURN_NOT_OK(value_builder->Append(state->Quantile(i)));
385+
ARROW_RETURN_NOT_OK(value_builder->Append(tdigest.Quantile(i)));
384386
}
385387
}
386388
} else {
@@ -392,6 +394,10 @@ struct TDigestQuantileScalarImpl : public TDigestQuantileImpl {
392394
out->value = std::move(out_array->data());
393395
return Status::OK();
394396
}
397+
398+
private:
399+
TDigestQuantileOptions options;
400+
std::shared_ptr<TDigest::Scaler> scaler;
395401
};
396402

397403
template <template <typename> typename TDigestImpl_T, typename TDigestOptions_T>

cpp/src/arrow/util/tdigest.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct Centroid {
5858
// implements t-digest merging algorithm
5959
class TDigestMerger {
6060
public:
61-
explicit TDigestMerger(std::unique_ptr<TDigest::Scaler> scaler)
61+
explicit TDigestMerger(std::shared_ptr<TDigest::Scaler> scaler)
6262
: scaler_{std::move(scaler)} {
6363
Reset(0, nullptr);
6464
}
@@ -111,7 +111,7 @@ class TDigestMerger {
111111
uint32_t delta() const { return scaler_->delta_; }
112112

113113
private:
114-
std::unique_ptr<TDigest::Scaler> scaler_;
114+
std::shared_ptr<TDigest::Scaler> scaler_;
115115
double total_weight_; // total weight of this tdigest
116116
double weight_so_far_; // accumulated weight till current bin
117117
double weight_limit_; // max accumulated weight to move to next bin
@@ -122,7 +122,7 @@ class TDigestMerger {
122122

123123
class TDigest::TDigestImpl {
124124
public:
125-
explicit TDigestImpl(std::unique_ptr<Scaler> scaler) : merger_(std::move(scaler)) {
125+
explicit TDigestImpl(std::shared_ptr<Scaler> scaler) : merger_(std::move(scaler)) {
126126
tdigests_[0].reserve(merger_.delta());
127127
tdigests_[1].reserve(merger_.delta());
128128
Reset();
@@ -305,7 +305,7 @@ class TDigest::TDigestImpl {
305305
if (diff > 0) {
306306
if (ci_right == td.size() - 1) {
307307
// index larger than center of last bin
308-
DCHECK_EQ(weight_sum, total_weight_);
308+
DCHECK_LE(std::abs(weight_sum - total_weight_), (weight_sum * 1e-9));
309309
const Centroid* c = &td[ci_right];
310310
DCHECK_GE(c->weight, 2);
311311
return Lerp(c->mean, max_, diff / (c->weight / 2));
@@ -363,7 +363,7 @@ class TDigest::TDigestImpl {
363363
TDigest::TDigest(uint32_t delta, uint32_t buffer_size)
364364
: TDigest(std::make_unique<TDigestScalerK1>(delta), buffer_size) {}
365365

366-
TDigest::TDigest(std::unique_ptr<Scaler> scaler, uint32_t buffer_size)
366+
TDigest::TDigest(std::shared_ptr<Scaler> scaler, uint32_t buffer_size)
367367
: impl_(new TDigestImpl(std::move(scaler))) {
368368
input_.reserve(buffer_size);
369369
Reset();

cpp/src/arrow/util/tdigest_internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class ARROW_EXPORT TDigest {
4949
};
5050

5151
explicit TDigest(uint32_t delta = 100, uint32_t buffer_size = 500);
52-
explicit TDigest(std::unique_ptr<Scaler> scaler, uint32_t buffer_size = 500);
52+
explicit TDigest(std::shared_ptr<Scaler> scaler, uint32_t buffer_size = 500);
5353
~TDigest();
5454
TDigest(TDigest&&);
5555
TDigest& operator=(TDigest&&);

0 commit comments

Comments
 (0)