|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
| 18 | +#include "arrow/array/builder_nested.h" |
| 19 | +#include "arrow/array/builder_primitive.h" |
18 | 20 | #include "arrow/compute/api_aggregate.h" |
19 | 21 | #include "arrow/compute/kernels/aggregate_internal.h" |
20 | 22 | #include "arrow/compute/kernels/common_internal.h" |
@@ -80,23 +82,33 @@ struct TDigestQuantileFinalizer : public TDigestBaseImpl { |
80 | 82 | q(std::move(q)), |
81 | 83 | min_count(min_count) {} |
82 | 84 |
|
| 85 | + bool isNull() { |
| 86 | + return this->tdigest.is_empty() || !this->all_valid || this->count < min_count; |
| 87 | + } |
| 88 | + |
| 89 | + double Quantile(size_t i) { return this->tdigest.Quantile(this->q[i]); } |
| 90 | + void Reset() { |
| 91 | + this->tdigest.Reset(); |
| 92 | + this->count = 0; |
| 93 | + } |
| 94 | + |
83 | 95 | Status Finalize(KernelContext* ctx, Datum* out) override { |
84 | | - const int64_t out_length = q.size(); |
| 96 | + const size_t out_length = q.size(); |
85 | 97 | auto out_data = ArrayData::Make(float64(), out_length, 0); |
86 | 98 | out_data->buffers.resize(2, nullptr); |
87 | 99 | ARROW_ASSIGN_OR_RAISE(out_data->buffers[1], |
88 | 100 | ctx->Allocate(out_length * sizeof(double))); |
89 | 101 | double* out_buffer = out_data->template GetMutableValues<double>(1); |
90 | 102 |
|
91 | | - if (this->tdigest.is_empty() || !this->all_valid || this->count < min_count) { |
| 103 | + if (isNull()) { |
92 | 104 | ARROW_ASSIGN_OR_RAISE(out_data->buffers[0], ctx->AllocateBitmap(out_length)); |
93 | 105 | std::memset(out_data->buffers[0]->mutable_data(), 0x00, |
94 | 106 | out_data->buffers[0]->size()); |
95 | 107 | std::fill(out_buffer, out_buffer + out_length, 0.0); |
96 | 108 | out_data->null_count = out_length; |
97 | 109 | } else { |
98 | | - for (int64_t i = 0; i < out_length; ++i) { |
99 | | - out_buffer[i] = this->tdigest.Quantile(this->q[i]); |
| 110 | + for (size_t i = 0; i < out_length; ++i) { |
| 111 | + out_buffer[i] = Quantile(i); |
100 | 112 | } |
101 | 113 | } |
102 | 114 | *out = Datum(std::move(out_data)); |
@@ -323,6 +335,65 @@ struct TDigestQuantileImpl |
323 | 335 | std::move(scaler), options.delta) {} |
324 | 336 | }; |
325 | 337 |
|
| 338 | +struct TDigestQuantileScalarImpl : public TDigestQuantileImpl { |
| 339 | + explicit TDigestQuantileScalarImpl(const TDigestQuantileOptions& options, |
| 340 | + std::unique_ptr<TDigest::Scaler> scaler) |
| 341 | + : TDigestQuantileImpl(options, std::move(scaler)) {} |
| 342 | + |
| 343 | + static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx, |
| 344 | + const KernelInitArgs& args) { |
| 345 | + auto options = static_cast<const TDigestQuantileOptions&>(*args.options); |
| 346 | + ARROW_ASSIGN_OR_RAISE(auto scaler, |
| 347 | + TDigestBaseImpl::MakeScaler(options.scaler, options.delta)); |
| 348 | + return std::make_unique<TDigestQuantileScalarImpl>(options, std::move(scaler)); |
| 349 | + } |
| 350 | + |
| 351 | + static Result<TypeHolder> ResolveOutput(KernelContext* ctx, |
| 352 | + const std::vector<TypeHolder>& types) { |
| 353 | + auto state = checked_cast<TDigestQuantileScalarImpl*>(ctx->state()); |
| 354 | + return state->OutputType(); |
| 355 | + } |
| 356 | + |
| 357 | + size_t OutputSize() const { return this->q.size(); } |
| 358 | + |
| 359 | + TypeHolder OutputType() const { |
| 360 | + return fixed_size_list(field("item", float64()), OutputSize()); |
| 361 | + } |
| 362 | + |
| 363 | + static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { |
| 364 | + auto state = checked_cast<TDigestQuantileScalarImpl*>(ctx->state()); |
| 365 | + auto value_builder = std::make_shared<DoubleBuilder>(ctx->memory_pool()); |
| 366 | + const auto output_size = state->OutputSize(); |
| 367 | + FixedSizeListBuilder fsl_builder( |
| 368 | + ctx->memory_pool(), checked_pointer_cast<arrow::ArrayBuilder>(value_builder), |
| 369 | + output_size); |
| 370 | + |
| 371 | + std::shared_ptr<Array> array = MakeArray(batch[0].array.ToArrayData()); |
| 372 | + for (int i = 0; i < array->length(); ++i) { |
| 373 | + if (array->IsValid(i)) { |
| 374 | + ARROW_RETURN_NOT_OK(fsl_builder.Append()); |
| 375 | + ARROW_ASSIGN_OR_RAISE(auto scalar, array->GetScalar(i)); |
| 376 | + state->Reset(); |
| 377 | + ARROW_RETURN_NOT_OK(state->Consume(scalar.get())); |
| 378 | + |
| 379 | + if (state->isNull()) { |
| 380 | + ARROW_RETURN_NOT_OK(value_builder->AppendNulls(output_size)); |
| 381 | + } else { |
| 382 | + for (size_t i = 0; i < output_size; ++i) { |
| 383 | + ARROW_RETURN_NOT_OK(value_builder->Append(state->Quantile(i))); |
| 384 | + } |
| 385 | + } |
| 386 | + } else { |
| 387 | + ARROW_RETURN_NOT_OK(fsl_builder.AppendNull()); |
| 388 | + } |
| 389 | + } |
| 390 | + std::shared_ptr<arrow::Array> out_array; |
| 391 | + ARROW_RETURN_NOT_OK(fsl_builder.Finish(&out_array)); |
| 392 | + out->value = std::move(out_array->data()); |
| 393 | + return Status::OK(); |
| 394 | + } |
| 395 | +}; |
| 396 | + |
326 | 397 | template <template <typename> typename TDigestImpl_T, typename TDigestOptions_T> |
327 | 398 | struct TDigestInitState { |
328 | 399 | std::unique_ptr<KernelState> state; |
@@ -538,13 +609,27 @@ std::shared_ptr<ScalarAggregateFunction> AddTDigestReduceAggKernels() { |
538 | 609 | } |
539 | 610 |
|
540 | 611 | std::shared_ptr<ScalarAggregateFunction> AddTDigestQuantileAggKernels() { |
541 | | - static auto default_tdigest_options = TDigestMapOptions::Defaults(); |
| 612 | + static auto default_tdigest_options = TDigestQuantileOptions::Defaults(); |
542 | 613 | auto func = std::make_shared<ScalarAggregateFunction>( |
543 | 614 | "tdigest_quantile", Arity::Unary(), tdigest_quantile_doc, &default_tdigest_options); |
544 | 615 | AddTDigestQuantileKernels(TDigestQuantileInit, func.get()); |
545 | 616 | return func; |
546 | 617 | } |
547 | 618 |
|
| 619 | +std::shared_ptr<ScalarFunction> AddTDigestQuantileScalarKernels() { |
| 620 | + static auto default_tdigest_options = TDigestQuantileOptions::Defaults(); |
| 621 | + auto func = |
| 622 | + std::make_shared<ScalarFunction>("tdigest_quantile_element_wise", Arity::Unary(), |
| 623 | + tdigest_quantile_doc, &default_tdigest_options); |
| 624 | + auto output = OutputType{TDigestQuantileScalarImpl::ResolveOutput}; |
| 625 | + ScalarKernel kernel({InputType(TDigestCentroidTypeMatcher::getMatcher())}, output, |
| 626 | + TDigestQuantileScalarImpl::Exec, TDigestQuantileScalarImpl::Init); |
| 627 | + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; |
| 628 | + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; |
| 629 | + DCHECK_OK(func->AddKernel(kernel)); |
| 630 | + return func; |
| 631 | +} |
| 632 | + |
548 | 633 | std::shared_ptr<ScalarAggregateFunction> AddApproximateMedianAggKernels( |
549 | 634 | const ScalarAggregateFunction* tdigest_func) { |
550 | 635 | static ScalarAggregateOptions default_scalar_aggregate_options; |
@@ -593,6 +678,8 @@ void RegisterScalarAggregateTDigest(FunctionRegistry* registry) { |
593 | 678 | DCHECK_OK(registry->AddFunction(tdigest_merge)); |
594 | 679 | auto tdigest_quantile = AddTDigestQuantileAggKernels(); |
595 | 680 | DCHECK_OK(registry->AddFunction(tdigest_quantile)); |
| 681 | + auto tdigest_quantile_scalar = AddTDigestQuantileScalarKernels(); |
| 682 | + DCHECK_OK(registry->AddFunction(tdigest_quantile_scalar)); |
596 | 683 |
|
597 | 684 | auto approx_median = AddApproximateMedianAggKernels(tdigest.get()); |
598 | 685 | DCHECK_OK(registry->AddFunction(approx_median)); |
|
0 commit comments