Skip to content

Commit 1c2cd43

Browse files
author
Rafał Hibner
committed
Selectable scaler
1 parent c155e99 commit 1c2cd43

File tree

6 files changed

+119
-52
lines changed

6 files changed

+119
-52
lines changed

cpp/src/arrow/acero/hash_aggregate_test.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,9 @@ TEST_P(GroupBy, TDigest) {
12711271
auto keep_nulls_min_count =
12721272
std::make_shared<TDigestOptions>(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
12731273
/*skip_nulls=*/false, /*min_count=*/3);
1274+
auto scaler_0 = std::make_shared<TDigestOptions>(
1275+
/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
1276+
/*skip_nulls=*/true, /*min_count=*/3, /*scaler=*/TDigestOptions::K0);
12741277
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
12751278
GroupByTest(
12761279
{
@@ -1280,6 +1283,7 @@ TEST_P(GroupBy, TDigest) {
12801283
batch->GetColumnByName("argument"),
12811284
batch->GetColumnByName("argument"),
12821285
batch->GetColumnByName("argument"),
1286+
batch->GetColumnByName("argument"),
12831287
},
12841288
{
12851289
batch->GetColumnByName("key"),
@@ -1292,6 +1296,7 @@ TEST_P(GroupBy, TDigest) {
12921296
{"hash_tdigest", keep_nulls},
12931297
{"hash_tdigest", min_count},
12941298
{"hash_tdigest", keep_nulls_min_count},
1299+
{"hash_tdigest", scaler_0},
12951300
},
12961301
false));
12971302

@@ -1304,13 +1309,14 @@ TEST_P(GroupBy, TDigest) {
13041309
field("hash_tdigest", fixed_size_list(float64(), 1)),
13051310
field("hash_tdigest", fixed_size_list(float64(), 1)),
13061311
field("hash_tdigest", fixed_size_list(float64(), 1)),
1312+
field("hash_tdigest", fixed_size_list(float64(), 1)),
13071313
}),
13081314
R"([
1309-
[1, [1.0], [1.0, 3.0, 3.0], [1.0, 3.0, 3.0], [null], [null], [null]],
1310-
[2, [0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0], [0.0], [0.0] ],
1311-
[3, [null], [null, null, null], [null, null, null], [null], [null], [null]],
1312-
[4, [1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [null], [1.0], [null]],
1313-
[null, [1.0], [1.0, 4.0, 4.0], [1.0, 4.0, 4.0], [1.0], [null], [null]]
1315+
[1, [1.0], [1.0, 3.0, 3.0], [1.0, 3.0, 3.0], [null], [null], [null], [null]],
1316+
[2, [0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0], [0.0], [0.0], [0.0]],
1317+
[3, [null], [null, null, null], [null, null, null], [null], [null], [null], [null]],
1318+
[4, [1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [null], [1.0], [null], [1.0]],
1319+
[null, [1.0], [1.0, 4.0, 4.0], [1.0, 4.0, 4.0], [1.0], [null], [null], [null]]
13141320
])"),
13151321
aggregated_and_grouped,
13161322
/*verbose=*/true);

cpp/src/arrow/compute/api_aggregate.cc

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ struct EnumTraits<compute::QuantileOptions::Interpolation>
6969
}
7070
};
7171

72+
template <>
73+
struct EnumTraits<compute::TDigestOptions::Scaler>
74+
: BasicEnumTraits<compute::TDigestOptions::Scaler, compute::TDigestOptions::K0,
75+
compute::TDigestOptions::K1> {
76+
static std::string name() { return "TDigestOptions::Scaler"; }
77+
static std::string value_name(compute::TDigestOptions::Scaler value) {
78+
switch (value) {
79+
case compute::TDigestOptions::K0:
80+
return "K0";
81+
case compute::TDigestOptions::K1:
82+
return "K1";
83+
}
84+
return "<INVALID>";
85+
}
86+
};
87+
7288
template <>
7389
struct EnumTraits<compute::PivotWiderOptions::UnexpectedKeyBehavior>
7490
: BasicEnumTraits<compute::PivotWiderOptions::UnexpectedKeyBehavior,
@@ -123,7 +139,8 @@ static auto kTDigestOptionsType = GetFunctionOptionsType<TDigestOptions>(
123139
DataMember("q", &TDigestOptions::q), DataMember("delta", &TDigestOptions::delta),
124140
DataMember("buffer_size", &TDigestOptions::buffer_size),
125141
DataMember("skip_nulls", &TDigestOptions::skip_nulls),
126-
DataMember("min_count", &TDigestOptions::min_count));
142+
DataMember("min_count", &TDigestOptions::min_count),
143+
DataMember("scaler", &TDigestOptions::scaler));
127144
static auto kPivotOptionsType = GetFunctionOptionsType<PivotWiderOptions>(
128145
DataMember("key_names", &PivotWiderOptions::key_names),
129146
DataMember("unexpected_key_behavior", &PivotWiderOptions::unexpected_key_behavior));
@@ -179,21 +196,24 @@ QuantileOptions::QuantileOptions(std::vector<double> q, enum Interpolation inter
179196
constexpr char QuantileOptions::kTypeName[];
180197

181198
TDigestOptions::TDigestOptions(double q, uint32_t delta, uint32_t buffer_size,
182-
bool skip_nulls, uint32_t min_count)
199+
bool skip_nulls, uint32_t min_count, enum Scaler scaler)
183200
: FunctionOptions(internal::kTDigestOptionsType),
184201
q{q},
185202
delta{delta},
186203
buffer_size{buffer_size},
187204
skip_nulls{skip_nulls},
188-
min_count{min_count} {}
205+
min_count{min_count},
206+
scaler{scaler} {}
189207
TDigestOptions::TDigestOptions(std::vector<double> q, uint32_t delta,
190-
uint32_t buffer_size, bool skip_nulls, uint32_t min_count)
208+
uint32_t buffer_size, bool skip_nulls, uint32_t min_count,
209+
enum Scaler scaler)
191210
: FunctionOptions(internal::kTDigestOptionsType),
192211
q{std::move(q)},
193212
delta{delta},
194213
buffer_size{buffer_size},
195214
skip_nulls{skip_nulls},
196-
min_count{min_count} {}
215+
min_count{min_count},
216+
scaler{scaler} {}
197217
constexpr char TDigestOptions::kTypeName[];
198218

199219
PivotWiderOptions::PivotWiderOptions(std::vector<std::string> key_names,

cpp/src/arrow/compute/api_aggregate.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,17 @@ class ARROW_EXPORT QuantileOptions : public FunctionOptions {
172172
/// By default, returns the median value.
173173
class ARROW_EXPORT TDigestOptions : public FunctionOptions {
174174
public:
175+
enum Scaler {
176+
K0 = 0,
177+
K1,
178+
};
179+
175180
explicit TDigestOptions(double q = 0.5, uint32_t delta = 100,
176181
uint32_t buffer_size = 500, bool skip_nulls = true,
177-
uint32_t min_count = 0);
182+
uint32_t min_count = 0, enum Scaler scaler = K0);
178183
explicit TDigestOptions(std::vector<double> q, uint32_t delta = 100,
179184
uint32_t buffer_size = 500, bool skip_nulls = true,
180-
uint32_t min_count = 0);
185+
uint32_t min_count = 0, enum Scaler scaler = K0);
181186
static constexpr char const kTypeName[] = "TDigestOptions";
182187
static TDigestOptions Defaults() { return TDigestOptions{}; }
183188

@@ -192,6 +197,8 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions {
192197
bool skip_nulls;
193198
/// If less than this many non-null values are observed, emit null.
194199
uint32_t min_count;
200+
/// Scale function used to bound centroid sizes: K0 (linear) or K1 (arcsine-based).
201+
enum Scaler scaler;
195202
};
196203

197204
/// \brief Control Pivot kernel behavior

cpp/src/arrow/compute/kernels/aggregate_tdigest.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ namespace internal {
2929
namespace {
3030

3131
using arrow::internal::TDigest;
32+
using arrow::internal::TDigestScalerK0;
33+
using arrow::internal::TDigestScalerK1;
3234
using arrow::internal::VisitSetBitRunsVoid;
3335

3436
template <typename ArrowType>
@@ -37,9 +39,10 @@ struct TDigestImpl : public ScalarAggregator {
3739
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
3840
using CType = typename TypeTraits<ArrowType>::CType;
3941

40-
TDigestImpl(const TDigestOptions& options, const DataType& in_type)
42+
TDigestImpl(const TDigestOptions& options, const DataType& in_type,
43+
std::unique_ptr<TDigest::Scaler> scaler)
4144
: options{options},
42-
tdigest{options.delta, options.buffer_size},
45+
tdigest{std::move(scaler), options.buffer_size},
4346
count{0},
4447
decimal_scale{0},
4548
all_valid{true} {
@@ -149,16 +152,28 @@ struct TDigestInitState {
149152

150153
template <typename Type>
151154
enable_if_number<Type, Status> Visit(const Type&) {
152-
state.reset(new TDigestImpl<Type>(options, in_type));
155+
ARROW_ASSIGN_OR_RAISE(auto scaler, MakeScaler());
156+
state.reset(new TDigestImpl<Type>(options, in_type, std::move(scaler)));
153157
return Status::OK();
154158
}
155159

156160
template <typename Type>
157161
enable_if_decimal<Type, Status> Visit(const Type&) {
158-
state.reset(new TDigestImpl<Type>(options, in_type));
162+
ARROW_ASSIGN_OR_RAISE(auto scaler, MakeScaler());
163+
state.reset(new TDigestImpl<Type>(options, in_type, std::move(scaler)));
159164
return Status::OK();
160165
}
161166

167+
// Builds the scale-function object selected by options.scaler, constructed
// with options.delta. Yields a NotImplemented status when options.scaler holds
// a value that is neither K0 nor K1.
Result<std::unique_ptr<TDigest::Scaler>> MakeScaler() {
  std::unique_ptr<TDigest::Scaler> selected;
  switch (options.scaler) {
    case TDigestOptions::K0:
      selected = std::make_unique<TDigestScalerK0>(options.delta);
      break;
    case TDigestOptions::K1:
      selected = std::make_unique<TDigestScalerK1>(options.delta);
      break;
  }
  // No case matched: the enum carries an out-of-range value.
  if (selected == nullptr) {
    return Status::NotImplemented("Invalid TDigest scaler");
  }
  return std::move(selected);
}
176+
162177
Result<std::unique_ptr<KernelState>> Create() {
163178
RETURN_NOT_OK(VisitTypeInline(in_type, this));
164179
return std::move(state);

cpp/src/arrow/util/tdigest.cc

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -50,31 +50,13 @@ struct Centroid {
5050
}
5151
};
5252

53-
// scale function K0: linear function, as baseline
54-
struct ScalerK0 {
55-
explicit ScalerK0(uint32_t delta) : delta_norm(delta / 2.0) {}
56-
57-
double K(double q) const { return delta_norm * q; }
58-
double Q(double k) const { return k / delta_norm; }
59-
60-
const double delta_norm;
61-
};
62-
63-
// scale function K1
64-
struct ScalerK1 {
65-
explicit ScalerK1(uint32_t delta) : delta_norm(delta / (2.0 * M_PI)) {}
66-
67-
double K(double q) const { return delta_norm * std::asin(2 * q - 1); }
68-
double Q(double k) const { return (std::sin(k / delta_norm) + 1) / 2; }
69-
70-
const double delta_norm;
71-
};
72-
7353
// implements t-digest merging algorithm
74-
template <class T = ScalerK1>
75-
class TDigestMerger : private T {
54+
class TDigestMerger {
7655
public:
77-
explicit TDigestMerger(uint32_t delta) : T(delta) { Reset(0, nullptr); }
56+
explicit TDigestMerger(std::unique_ptr<TDigest::Scaler> scaler)
57+
: scaler_{std::move(scaler)} {
58+
Reset(0, nullptr);
59+
}
7860

7961
void Reset(double total_weight, std::vector<Centroid>* tdigest) {
8062
total_weight_ = total_weight;
@@ -94,7 +76,7 @@ class TDigestMerger : private T {
9476
td.back().Merge(centroid);
9577
} else {
9678
const double quantile = weight_so_far_ / total_weight_;
97-
const double next_weight_limit = total_weight_ * this->Q(this->K(quantile) + 1);
79+
const double next_weight_limit = total_weight_ * scaler_->QK1(quantile);
9880
// weight limit should be strictly increasing, until the last centroid
9981
if (next_weight_limit <= weight_limit_) {
10082
weight_limit_ = total_weight_;
@@ -108,10 +90,10 @@ class TDigestMerger : private T {
10890

10991
// validate k-size of a tdigest
11092
Status Validate(const std::vector<Centroid>& tdigest, double total_weight) const {
111-
double q_prev = 0, k_prev = this->K(0);
93+
double q_prev = 0, k_prev = scaler_->K(0);
11294
for (size_t i = 0; i < tdigest.size(); ++i) {
11395
const double q = q_prev + tdigest[i].weight / total_weight;
114-
const double k = this->K(q);
96+
const double k = scaler_->K(q);
11597
if (tdigest[i].weight != 1 && (k - k_prev) > 1.001) {
11698
return Status::Invalid("oversized centroid: ", k - k_prev);
11799
}
@@ -121,7 +103,10 @@ class TDigestMerger : private T {
121103
return Status::OK();
122104
}
123105

106+
uint32_t delta() const { return scaler_->delta_; }
107+
124108
private:
109+
std::unique_ptr<TDigest::Scaler> scaler_;
125110
double total_weight_; // total weight of this tdigest
126111
double weight_so_far_; // accumulated weight till current bin
127112
double weight_limit_; // max accumulated weight to move to next bin
@@ -132,10 +117,9 @@ class TDigestMerger : private T {
132117

133118
class TDigest::TDigestImpl {
134119
public:
135-
explicit TDigestImpl(uint32_t delta)
136-
: delta_(delta > 10 ? delta : 10), merger_(delta_) {
137-
tdigests_[0].reserve(delta_);
138-
tdigests_[1].reserve(delta_);
120+
explicit TDigestImpl(std::unique_ptr<Scaler> scaler) : merger_(std::move(scaler)) {
121+
tdigests_[0].reserve(merger_.delta());
122+
tdigests_[1].reserve(merger_.delta());
139123
Reset();
140124
}
141125

@@ -169,7 +153,8 @@ class TDigest::TDigestImpl {
169153
return Status::Invalid("tdigest total weight mismatch");
170154
}
171155
// check if buffer expanded
172-
if (tdigests_[0].capacity() > delta_ || tdigests_[1].capacity() > delta_) {
156+
if (tdigests_[0].capacity() > merger_.delta() ||
157+
tdigests_[1].capacity() > merger_.delta()) {
173158
return Status::Invalid("oversized tdigest buffer");
174159
}
175160
// check k-size
@@ -342,10 +327,7 @@ class TDigest::TDigestImpl {
342327
double total_weight() const { return total_weight_; }
343328

344329
private:
345-
// must be declared before merger_, see constructor initialization list
346-
const uint32_t delta_;
347-
348-
TDigestMerger<> merger_;
330+
TDigestMerger merger_;
349331
double total_weight_;
350332
double min_, max_;
351333

@@ -355,7 +337,11 @@ class TDigest::TDigestImpl {
355337
int current_;
356338
};
357339

358-
TDigest::TDigest(uint32_t delta, uint32_t buffer_size) : impl_(new TDigestImpl(delta)) {
340+
// Convenience constructor: delegates to the scaler-taking constructor with a
// K1 scale function built from `delta`.
TDigest::TDigest(uint32_t delta, uint32_t buffer_size)
    : TDigest(std::make_unique<TDigestScalerK1>(delta), buffer_size) {}

// Takes ownership of `scaler` and hands it to the implementation object.
// `scaler` must not be null: TDigestImpl reads the scaler's delta in its
// constructor to size its centroid buffers.
// `buffer_size` is the capacity reserved up front for the input buffer.
TDigest::TDigest(std::unique_ptr<Scaler> scaler, uint32_t buffer_size)
    : impl_(std::make_unique<TDigestImpl>(std::move(scaler))) {
  input_.reserve(buffer_size);
  Reset();
}

cpp/src/arrow/util/tdigest_internal.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,17 @@ namespace internal {
3838

3939
class ARROW_EXPORT TDigest {
4040
public:
41+
struct ARROW_EXPORT Scaler {
42+
explicit Scaler(const uint32_t delta) : delta_(delta) {}
43+
virtual ~Scaler() {}
44+
virtual double K(double q) const = 0;
45+
// reduce virtual calls
46+
virtual double QK1(double q) const = 0;
47+
const uint32_t delta_;
48+
};
49+
4150
explicit TDigest(uint32_t delta = 100, uint32_t buffer_size = 500);
51+
explicit TDigest(std::unique_ptr<Scaler> scaler, uint32_t buffer_size = 500);
4252
~TDigest();
4353
TDigest(TDigest&&);
4454
TDigest& operator=(TDigest&&);
@@ -100,5 +110,28 @@ class ARROW_EXPORT TDigest {
100110
std::unique_ptr<TDigestImpl> impl_;
101111
};
102112

113+
// scale function K0: linear function, as baseline
114+
struct ARROW_EXPORT TDigestScalerK0 : public TDigest::Scaler {
115+
explicit TDigestScalerK0(uint32_t delta) : Scaler(delta), delta_norm(delta / 2.0) {}
116+
117+
double K(double q) const override { return delta_norm * q; }
118+
double Q(double k) const { return k / delta_norm; }
119+
double QK1(double q) const override { return Q(K(q) + 1); }
120+
121+
const double delta_norm;
122+
};
123+
124+
// scale function K1
125+
struct ARROW_EXPORT TDigestScalerK1 : public TDigest::Scaler {
126+
explicit TDigestScalerK1(uint32_t delta)
127+
: Scaler(delta), delta_norm(delta / (2.0 * M_PI)) {}
128+
129+
double K(double q) const override { return delta_norm * std::asin(2 * q - 1); }
130+
double Q(double k) const { return (std::sin(k / delta_norm) + 1) / 2; }
131+
double QK1(double q) const override { return Q(K(q) + 1); }
132+
133+
const double delta_norm;
134+
};
135+
103136
} // namespace internal
104137
} // namespace arrow

0 commit comments

Comments
 (0)