Skip to content

Commit 982cc07

Browse files
author
Rafał Hibner
committed
Add tdigest_quantile_element_wise
1 parent c2c5a35 commit 982cc07

File tree

4 files changed

+173
-5
lines changed

4 files changed

+173
-5
lines changed

cpp/src/arrow/compute/api_aggregate.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,12 @@ Result<Datum> TDigestQuantile(const Datum& value, const TDigestQuantileOptions&
384384
return CallFunction("tdigest_quantile", {value}, &options, ctx);
385385
}
386386

387+
Result<Datum> TDigestQuantileElementWise(const Datum& value,
388+
const TDigestQuantileOptions& options,
389+
ExecContext* ctx) {
390+
return CallFunction("tdigest_quantile_element_wise", {value}, &options, ctx);
391+
}
392+
387393
Result<Datum> Index(const Datum& value, const IndexOptions& options, ExecContext* ctx) {
388394
return CallFunction("index", {value}, &options, ctx);
389395
}

cpp/src/arrow/compute/api_aggregate.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,21 @@ Result<Datum> TDigestQuantile(
694694
const TDigestQuantileOptions& options = TDigestQuantileOptions::Defaults(),
695695
ExecContext* ctx = NULLPTR);
696696

697+
/// \brief Calculate the approximate quantiles using centroids with T-Digest algorithm
698+
///
699+
/// \param[in] value input centroid sets, expecting Scalar, Array or ChunkedArray of
700+
/// centroid structs \param[in] options see TDigestQuantileOptions for more information
701+
/// \param[in] ctx the function execution context, optional
702+
/// \return resulting struct of mean and weight arrays
703+
///
704+
/// \since 22.0.0
705+
/// \note API not yet finalized
706+
ARROW_EXPORT
707+
Result<Datum> TDigestQuantileElementWise(
708+
const Datum& value,
709+
const TDigestQuantileOptions& options = TDigestQuantileOptions::Defaults(),
710+
ExecContext* ctx = NULLPTR);
711+
697712
/// \brief Find the first index of a value in an array.
698713
///
699714
/// \param[in] value The array to search.

cpp/src/arrow/compute/kernels/aggregate_tdigest.cc

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
#include "arrow/array/builder_nested.h"
19+
#include "arrow/array/builder_primitive.h"
1820
#include "arrow/compute/api_aggregate.h"
1921
#include "arrow/compute/kernels/aggregate_internal.h"
2022
#include "arrow/compute/kernels/common_internal.h"
@@ -80,23 +82,33 @@ struct TDigestQuantileFinalizer : public TDigestBaseImpl {
8082
q(std::move(q)),
8183
min_count(min_count) {}
8284

85+
bool isNull() {
86+
return this->tdigest.is_empty() || !this->all_valid || this->count < min_count;
87+
}
88+
89+
double Quantile(size_t i) { return this->tdigest.Quantile(this->q[i]); }
90+
void Reset() {
91+
this->tdigest.Reset();
92+
this->count = 0;
93+
}
94+
8395
Status Finalize(KernelContext* ctx, Datum* out) override {
84-
const int64_t out_length = q.size();
96+
const size_t out_length = q.size();
8597
auto out_data = ArrayData::Make(float64(), out_length, 0);
8698
out_data->buffers.resize(2, nullptr);
8799
ARROW_ASSIGN_OR_RAISE(out_data->buffers[1],
88100
ctx->Allocate(out_length * sizeof(double)));
89101
double* out_buffer = out_data->template GetMutableValues<double>(1);
90102

91-
if (this->tdigest.is_empty() || !this->all_valid || this->count < min_count) {
103+
if (isNull()) {
92104
ARROW_ASSIGN_OR_RAISE(out_data->buffers[0], ctx->AllocateBitmap(out_length));
93105
std::memset(out_data->buffers[0]->mutable_data(), 0x00,
94106
out_data->buffers[0]->size());
95107
std::fill(out_buffer, out_buffer + out_length, 0.0);
96108
out_data->null_count = out_length;
97109
} else {
98-
for (int64_t i = 0; i < out_length; ++i) {
99-
out_buffer[i] = this->tdigest.Quantile(this->q[i]);
110+
for (size_t i = 0; i < out_length; ++i) {
111+
out_buffer[i] = Quantile(i);
100112
}
101113
}
102114
*out = Datum(std::move(out_data));
@@ -323,6 +335,65 @@ struct TDigestQuantileImpl
323335
std::move(scaler), options.delta) {}
324336
};
325337

338+
struct TDigestQuantileScalarImpl : public TDigestQuantileImpl {
339+
explicit TDigestQuantileScalarImpl(const TDigestQuantileOptions& options,
340+
std::unique_ptr<TDigest::Scaler> scaler)
341+
: TDigestQuantileImpl(options, std::move(scaler)) {}
342+
343+
static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
344+
const KernelInitArgs& args) {
345+
auto options = static_cast<const TDigestQuantileOptions&>(*args.options);
346+
ARROW_ASSIGN_OR_RAISE(auto scaler,
347+
TDigestBaseImpl::MakeScaler(options.scaler, options.delta));
348+
return std::make_unique<TDigestQuantileScalarImpl>(options, std::move(scaler));
349+
}
350+
351+
static Result<TypeHolder> ResolveOutput(KernelContext* ctx,
352+
const std::vector<TypeHolder>& types) {
353+
auto state = checked_cast<TDigestQuantileScalarImpl*>(ctx->state());
354+
return state->OutputType();
355+
}
356+
357+
size_t OutputSize() const { return this->q.size(); }
358+
359+
TypeHolder OutputType() const {
360+
return fixed_size_list(field("item", float64()), OutputSize());
361+
}
362+
363+
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
364+
auto state = checked_cast<TDigestQuantileScalarImpl*>(ctx->state());
365+
auto value_builder = std::make_shared<DoubleBuilder>(ctx->memory_pool());
366+
const auto output_size = state->OutputSize();
367+
FixedSizeListBuilder fsl_builder(
368+
ctx->memory_pool(), checked_pointer_cast<arrow::ArrayBuilder>(value_builder),
369+
output_size);
370+
371+
std::shared_ptr<Array> array = MakeArray(batch[0].array.ToArrayData());
372+
for (int i = 0; i < array->length(); ++i) {
373+
if (array->IsValid(i)) {
374+
ARROW_RETURN_NOT_OK(fsl_builder.Append());
375+
ARROW_ASSIGN_OR_RAISE(auto scalar, array->GetScalar(i));
376+
state->Reset();
377+
ARROW_RETURN_NOT_OK(state->Consume(scalar.get()));
378+
379+
if (state->isNull()) {
380+
ARROW_RETURN_NOT_OK(value_builder->AppendNulls(output_size));
381+
} else {
382+
for (size_t i = 0; i < output_size; ++i) {
383+
ARROW_RETURN_NOT_OK(value_builder->Append(state->Quantile(i)));
384+
}
385+
}
386+
} else {
387+
ARROW_RETURN_NOT_OK(fsl_builder.AppendNull());
388+
}
389+
}
390+
std::shared_ptr<arrow::Array> out_array;
391+
ARROW_RETURN_NOT_OK(fsl_builder.Finish(&out_array));
392+
out->value = std::move(out_array->data());
393+
return Status::OK();
394+
}
395+
};
396+
326397
template <template <typename> typename TDigestImpl_T, typename TDigestOptions_T>
327398
struct TDigestInitState {
328399
std::unique_ptr<KernelState> state;
@@ -538,13 +609,27 @@ std::shared_ptr<ScalarAggregateFunction> AddTDigestReduceAggKernels() {
538609
}
539610

540611
std::shared_ptr<ScalarAggregateFunction> AddTDigestQuantileAggKernels() {
541-
static auto default_tdigest_options = TDigestMapOptions::Defaults();
612+
static auto default_tdigest_options = TDigestQuantileOptions::Defaults();
542613
auto func = std::make_shared<ScalarAggregateFunction>(
543614
"tdigest_quantile", Arity::Unary(), tdigest_quantile_doc, &default_tdigest_options);
544615
AddTDigestQuantileKernels(TDigestQuantileInit, func.get());
545616
return func;
546617
}
547618

619+
std::shared_ptr<ScalarFunction> AddTDigestQuantileScalarKernels() {
620+
static auto default_tdigest_options = TDigestQuantileOptions::Defaults();
621+
auto func =
622+
std::make_shared<ScalarFunction>("tdigest_quantile_element_wise", Arity::Unary(),
623+
tdigest_quantile_doc, &default_tdigest_options);
624+
auto output = OutputType{TDigestQuantileScalarImpl::ResolveOutput};
625+
ScalarKernel kernel({InputType(TDigestCentroidTypeMatcher::getMatcher())}, output,
626+
TDigestQuantileScalarImpl::Exec, TDigestQuantileScalarImpl::Init);
627+
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
628+
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
629+
DCHECK_OK(func->AddKernel(kernel));
630+
return func;
631+
}
632+
548633
std::shared_ptr<ScalarAggregateFunction> AddApproximateMedianAggKernels(
549634
const ScalarAggregateFunction* tdigest_func) {
550635
static ScalarAggregateOptions default_scalar_aggregate_options;
@@ -593,6 +678,8 @@ void RegisterScalarAggregateTDigest(FunctionRegistry* registry) {
593678
DCHECK_OK(registry->AddFunction(tdigest_merge));
594679
auto tdigest_quantile = AddTDigestQuantileAggKernels();
595680
DCHECK_OK(registry->AddFunction(tdigest_quantile));
681+
auto tdigest_quantile_scalar = AddTDigestQuantileScalarKernels();
682+
DCHECK_OK(registry->AddFunction(tdigest_quantile_scalar));
596683

597684
auto approx_median = AddApproximateMedianAggKernels(tdigest.get());
598685
DCHECK_OK(registry->AddFunction(approx_median));

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4238,6 +4238,9 @@ class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel<ArrowType> {
42384238

42394239
ASSERT_OK_AND_ASSIGN(
42404240
out, TDigestQuantile(incremental_centroids, options)); // incremental quantile
4241+
// validate tdigest_quantile_element_wise
4242+
EXPECT_THAT(TDigestQuantileElementWise(incremental_centroids, options),
4243+
ResultWith(FixedSizeListScalar(out.make_array())));
42414244
}
42424245

42434246
const auto& out_array = out.make_array();
@@ -4694,6 +4697,63 @@ TEST(TestTDigestQuantileKernel, Basic) {
46944697
ResultWith(ArrayFromJSON(output_type, "[null]")));
46954698
}
46964699

4700+
TEST(TestTDigestQuantileKernel, Scalar) {
4701+
auto input_type =
4702+
struct_({field("mean", list(field("item", float64(), false)), false),
4703+
field("weight", list(field("item", float64(), false)), false),
4704+
field("min", float64(), true), field("max", float64(), true),
4705+
field("count", uint64(), false)});
4706+
4707+
auto output_type = float64();
4708+
4709+
auto input_scalar = ScalarFromJSON(input_type,
4710+
"{\"mean\":[1.5, 3.5, 5.5],\"weight\":[2, "
4711+
"2, 2],\"min\":1.0,\"max\":6.0,\"count\":6}");
4712+
4713+
TDigestQuantileOptions multiple(/*q=*/{0.1, 0.5, 0.9}, /*delta=*/5, /*min_count=*/6);
4714+
TDigestQuantileOptions min_count(/*q=*/0.5, /*delta=*/5, /*min_count=*/7);
4715+
4716+
EXPECT_THAT(TDigestQuantile(input_scalar, multiple),
4717+
ResultWith(ArrayFromJSON(output_type, "[1, 3.5, 6]")));
4718+
EXPECT_THAT(TDigestQuantile(input_scalar, min_count),
4719+
ResultWith(ArrayFromJSON(output_type, "[null]")));
4720+
}
4721+
4722+
TEST(TestTDigestQuantileKernel, ElementWise) {
4723+
auto input_type =
4724+
struct_({field("mean", list(field("item", float64(), false)), false),
4725+
field("weight", list(field("item", float64(), false)), false),
4726+
field("min", float64(), true), field("max", float64(), true),
4727+
field("count", uint64(), false)});
4728+
4729+
auto output_type_multiple = fixed_size_list(float64(), 3);
4730+
auto output_type_single = fixed_size_list(float64(), 1);
4731+
4732+
auto input_array = ArrayFromJSON(input_type,
4733+
"["
4734+
"{\"mean\":[1.5, 3.5, 5.5],\"weight\":[2, "
4735+
"2, 2],\"min\":1.0,\"max\":6.0,\"count\":7},"
4736+
"{\"mean\":[1.5, 3.5, 5.5],\"weight\":[2, "
4737+
"2, 2],\"min\":1.0,\"max\":6.0,\"count\":6}"
4738+
"]");
4739+
4740+
TDigestQuantileOptions multiple(/*q=*/{0.1, 0.5, 0.9}, /*delta=*/5, /*min_count=*/5);
4741+
TDigestQuantileOptions single(/*q=*/0.5, /*delta=*/5, /*min_count=*/5);
4742+
TDigestQuantileOptions multiple_min_count(/*q=*/{0.1, 0.5, 0.9}, /*delta=*/5,
4743+
/*min_count=*/7);
4744+
TDigestQuantileOptions single_min_count(/*q=*/0.5, /*delta=*/5, /*min_count=*/7);
4745+
EXPECT_THAT(
4746+
TDigestQuantileElementWise(input_array, multiple),
4747+
ResultWith(ArrayFromJSON(output_type_multiple, "[[1, 3.5, 6],[1, 3.5, 6]]")));
4748+
EXPECT_THAT(TDigestQuantileElementWise(input_array, single),
4749+
ResultWith(ArrayFromJSON(output_type_single, "[[3.5],[3.5]]")));
4750+
EXPECT_THAT(
4751+
TDigestQuantileElementWise(input_array, multiple_min_count),
4752+
ResultWith(ArrayFromJSON(output_type_multiple, "[[1, 3.5, 6],[null,null,null]]")));
4753+
EXPECT_THAT(TDigestQuantileElementWise(input_array, single_min_count),
4754+
ResultWith(ArrayFromJSON(output_type_single, "[[3.5],[null]]")));
4755+
}
4756+
46974757
TEST(TestTDigestMapReduceQuantileKernel, Basic) {
46984758
auto input_type =
46994759
struct_({field("mean", list(field("item", float64(), false)), false),

0 commit comments

Comments
 (0)