Skip to content

Commit 7487397

Browse files
committed
refactor(tdigest): replace template Rank with separate Rank and RevRank methods
1 parent b6074af commit 7487397

4 files changed

Lines changed: 206 additions & 133 deletions

File tree

src/commands/cmd_tdigest.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,14 @@ class TDigestRankCommand : public Commander {
202202
TDigest tdigest(srv->storage, conn->GetNamespace());
203203
std::vector<int> result;
204204
result.reserve(origin_inputs_.size());
205-
if (const auto s = tdigest.Rank<Reverse>(ctx, key_name_, unique_inputs_, result); !s.ok()) {
205+
const auto s = [&]() {
206+
if constexpr (Reverse) {
207+
return tdigest.RevRank(ctx, key_name_, unique_inputs_, result);
208+
} else {
209+
return tdigest.Rank(ctx, key_name_, unique_inputs_, result);
210+
}
211+
}();
212+
if (!s.ok()) {
206213
if (s.IsNotFound()) {
207214
return {Status::RedisExecErr, errKeyNotFound};
208215
}

src/types/redis_tdigest.cc

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,126 @@
4747

4848
namespace redis {
4949

50+
namespace {
51+
// Returns the iterator at which a traversal of `centroids` starts:
// cbegin() for forward traversal, crbegin() when Reverse is true.
template <bool Reverse, typename Container>
decltype(auto) GetCbeginIter(const Container& centroids) {
  if constexpr (!Reverse) {
    return centroids.cbegin();
  } else {
    return centroids.crbegin();
  }
}
59+
60+
// Returns the past-the-end iterator of a traversal of `centroids`:
// cend() for forward traversal, crend() when Reverse is true.
template <bool Reverse, typename Container>
decltype(auto) GetCendIter(const Container& centroids) {
  if constexpr (!Reverse) {
    return centroids.cend();
  } else {
    return centroids.crend();
  }
}
68+
69+
template <bool Reverse>
70+
rocksdb::Status RankImpl(TDigest* self, engine::Context& ctx, const Slice& digest_name,
71+
const std::vector<double>& inputs, std::vector<int>& result) {
72+
auto ns_key = self->AppendNamespacePrefix(digest_name);
73+
TDigestMetadata metadata;
74+
{
75+
LockGuard guard(self->storage_->GetLockManager(), ns_key);
76+
77+
if (auto status = self->getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) {
78+
return status;
79+
}
80+
81+
if (metadata.total_observations == 0) {
82+
result.resize(inputs.size(), -2);
83+
return rocksdb::Status::OK();
84+
}
85+
86+
if (auto status = self->mergeNodes(ctx, ns_key, &metadata); !status.ok()) {
87+
return status;
88+
}
89+
}
90+
91+
std::vector<Centroid> centroids;
92+
if (auto status = self->dumpCentroids(ctx, ns_key, metadata, &centroids); !status.ok()) {
93+
return status;
94+
}
95+
96+
auto dump_centroids = DummyCentroids<Reverse>(metadata, centroids);
97+
if (auto status = TDigestRank<Reverse>(dump_centroids, inputs, result); !status) {
98+
return rocksdb::Status::InvalidArgument(status.Msg());
99+
}
100+
return rocksdb::Status::OK();
101+
}
102+
} // namespace
103+
104+
// TODO: It should be replaced by a iteration of the rocksdb iterator
105+
template <bool Reverse>
106+
class DummyCentroids {
107+
public:
108+
DummyCentroids(const TDigestMetadata& meta_data, const std::vector<Centroid>& centroids)
109+
: meta_data_(meta_data), centroids_(centroids) {}
110+
class Iterator {
111+
public:
112+
using IterType = std::conditional_t<Reverse, std::vector<Centroid>::const_reverse_iterator,
113+
std::vector<Centroid>::const_iterator>;
114+
Iterator(IterType iter, const std::vector<Centroid>& centroids) : iter_(iter), centroids_(centroids) {}
115+
std::unique_ptr<Iterator> Clone() const {
116+
if (iter_ != GetCendIter<Reverse>(centroids_)) {
117+
return std::make_unique<Iterator>(
118+
std::next(GetCbeginIter<Reverse>(centroids_), std::distance(GetCbeginIter<Reverse>(centroids_), iter_)),
119+
centroids_);
120+
}
121+
return std::make_unique<Iterator>(GetCendIter<Reverse>(centroids_), centroids_);
122+
}
123+
bool Next() {
124+
if (Valid()) {
125+
std::advance(iter_, 1);
126+
}
127+
return iter_ != GetCendIter<Reverse>(centroids_);
128+
}
129+
130+
// The Prev function can only be called for item is not cend,
131+
// because we must guarantee the iterator to be inside the valid range before iteration.
132+
bool Prev() {
133+
if (Valid() && iter_ != GetCendIter<Reverse>(centroids_)) {
134+
std::advance(iter_, -1);
135+
}
136+
return Valid();
137+
}
138+
bool Valid() const { return iter_ != GetCendIter<Reverse>(centroids_); }
139+
StatusOr<Centroid> GetCentroid() const {
140+
if (iter_ == GetCendIter<Reverse>(centroids_)) {
141+
return {::Status::NotOK, "invalid iterator during decoding tdigest centroid"};
142+
}
143+
return *iter_;
144+
}
145+
146+
private:
147+
IterType iter_;
148+
const std::vector<Centroid>& centroids_;
149+
};
150+
151+
std::unique_ptr<Iterator> Begin() const {
152+
return std::make_unique<Iterator>(GetCbeginIter<Reverse>(centroids_), centroids_);
153+
}
154+
std::unique_ptr<Iterator> End() const {
155+
if (centroids_.empty()) {
156+
return std::make_unique<Iterator>(GetCendIter<Reverse>(centroids_), centroids_);
157+
}
158+
return std::make_unique<Iterator>(std::prev(GetCendIter<Reverse>(centroids_)), centroids_);
159+
}
160+
double TotalWeight() const { return static_cast<double>(meta_data_.total_weight); }
161+
double Min() const { return meta_data_.minimum; }
162+
double Max() const { return meta_data_.maximum; }
163+
uint64_t Size() const { return meta_data_.merged_nodes; }
164+
165+
private:
166+
const TDigestMetadata& meta_data_;
167+
const std::vector<Centroid>& centroids_;
168+
};
169+
50170
uint32_t constexpr kMaxElements = 1 * 1024; // 1k doubles
51171

52172
rocksdb::Status TDigest::Create(engine::Context& ctx, const Slice& digest_name, const TDigestCreateOptions& options,
@@ -578,4 +698,70 @@ std::string TDigest::internalSegmentGuardPrefixKey(const TDigestMetadata& metada
578698
PutFixed8(&prefix_key, static_cast<uint8_t>(seg));
579699
return InternalKey(ns_key, prefix_key, metadata.version, storage_->IsSlotIdEncoded()).Encode();
580700
}
701+
702+
// Computes the rank of every value in `inputs` against the digest stored under `digest_name`,
// walking the centroids in forward order (DummyCentroids<false> / TDigestRank<false>).
//
// Returns the failing status of the metadata lookup, node merge or centroid dump;
// InvalidArgument carrying the TDigestRank message if the rank computation fails; OK otherwise.
// When the digest has no observations, `result` is resized to inputs.size(), with -2 filling
// any newly added slots.
rocksdb::Status TDigest::Rank(engine::Context& ctx, const Slice& digest_name, const std::vector<double>& inputs,
                              std::vector<int>& result) {
  auto ns_key = AppendNamespacePrefix(digest_name);
  TDigestMetadata meta;

  {
    // The lock covers the metadata read and the node merge only; it is released before the dump.
    LockGuard lock(storage_->GetLockManager(), ns_key);

    auto load_status = getMetaDataByNsKey(ctx, ns_key, &meta);
    if (!load_status.ok()) {
      return load_status;
    }

    const bool digest_is_empty = meta.total_observations == 0;
    if (digest_is_empty) {
      result.resize(inputs.size(), -2);
      return rocksdb::Status::OK();
    }

    auto merge_status = mergeNodes(ctx, ns_key, &meta);
    if (!merge_status.ok()) {
      return merge_status;
    }
  }

  std::vector<Centroid> centroids;
  auto dump_status = dumpCentroids(ctx, ns_key, meta, &centroids);
  if (!dump_status.ok()) {
    return dump_status;
  }

  auto view = DummyCentroids<false>(meta, centroids);
  auto rank_status = TDigestRank<false>(view, inputs, result);
  if (!rank_status) {
    return rocksdb::Status::InvalidArgument(rank_status.Msg());
  }
  return rocksdb::Status::OK();
}
734+
735+
// Computes the reverse rank of every value in `inputs` against the digest stored under
// `digest_name`, walking the centroids in reverse order (DummyCentroids<true> / TDigestRank<true>).
//
// Returns the failing status of the metadata lookup, node merge or centroid dump;
// InvalidArgument carrying the TDigestRank message if the rank computation fails; OK otherwise.
// When the digest has no observations, `result` is resized to inputs.size(), with -2 filling
// any newly added slots.
rocksdb::Status TDigest::RevRank(engine::Context& ctx, const Slice& digest_name, const std::vector<double>& inputs,
                                 std::vector<int>& result) {
  auto ns_key = AppendNamespacePrefix(digest_name);
  TDigestMetadata meta;

  {
    // The lock covers the metadata read and the node merge only; it is released before the dump.
    LockGuard lock(storage_->GetLockManager(), ns_key);

    auto load_status = getMetaDataByNsKey(ctx, ns_key, &meta);
    if (!load_status.ok()) {
      return load_status;
    }

    const bool digest_is_empty = meta.total_observations == 0;
    if (digest_is_empty) {
      result.resize(inputs.size(), -2);
      return rocksdb::Status::OK();
    }

    auto merge_status = mergeNodes(ctx, ns_key, &meta);
    if (!merge_status.ok()) {
      return merge_status;
    }
  }

  std::vector<Centroid> centroids;
  auto dump_status = dumpCentroids(ctx, ns_key, meta, &centroids);
  if (!dump_status.ok()) {
    return dump_status;
  }

  auto view = DummyCentroids<true>(meta, centroids);
  auto rank_status = TDigestRank<true>(view, inputs, result);
  if (!rank_status) {
    return rocksdb::Status::InvalidArgument(rank_status.Msg());
  }
  return rocksdb::Status::OK();
}
581767
} // namespace redis

src/types/redis_tdigest.h

Lines changed: 2 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -34,92 +34,6 @@
3434

3535
namespace redis {
3636

37-
namespace detail {
38-
template <bool Reverse, typename Container>
39-
inline decltype(auto) GetCbeginIter(const Container& centroids) {
40-
if constexpr (Reverse) {
41-
return centroids.crbegin();
42-
} else {
43-
return centroids.cbegin();
44-
}
45-
}
46-
47-
template <bool Reverse, typename Container>
48-
inline decltype(auto) GetCendIter(const Container& centroids) {
49-
if constexpr (Reverse) {
50-
return centroids.crend();
51-
} else {
52-
return centroids.cend();
53-
}
54-
}
55-
} // namespace detail
56-
57-
// TODO: It should be replaced by a iteration of the rocksdb iterator
58-
template <bool Reverse>
59-
class DummyCentroids {
60-
public:
61-
DummyCentroids(const TDigestMetadata& meta_data, const std::vector<Centroid>& centroids)
62-
: meta_data_(meta_data), centroids_(centroids) {}
63-
class Iterator {
64-
public:
65-
using IterType = std::conditional_t<Reverse, std::vector<Centroid>::const_reverse_iterator,
66-
std::vector<Centroid>::const_iterator>;
67-
Iterator(IterType iter, const std::vector<Centroid>& centroids) : iter_(iter), centroids_(centroids) {}
68-
std::unique_ptr<Iterator> Clone() const {
69-
if (iter_ != detail::GetCendIter<Reverse>(centroids_)) {
70-
return std::make_unique<Iterator>(std::next(detail::GetCbeginIter<Reverse>(centroids_),
71-
std::distance(detail::GetCbeginIter<Reverse>(centroids_), iter_)),
72-
centroids_);
73-
}
74-
return std::make_unique<Iterator>(detail::GetCendIter<Reverse>(centroids_), centroids_);
75-
}
76-
bool Next() {
77-
if (Valid()) {
78-
std::advance(iter_, 1);
79-
}
80-
return iter_ != detail::GetCendIter<Reverse>(centroids_);
81-
}
82-
83-
// The Prev function can only be called for item is not cend,
84-
// because we must guarantee the iterator to be inside the valid range before iteration.
85-
bool Prev() {
86-
if (Valid() && iter_ != detail::GetCendIter<Reverse>(centroids_)) {
87-
std::advance(iter_, -1);
88-
}
89-
return Valid();
90-
}
91-
bool Valid() const { return iter_ != detail::GetCendIter<Reverse>(centroids_); }
92-
StatusOr<Centroid> GetCentroid() const {
93-
if (iter_ == detail::GetCendIter<Reverse>(centroids_)) {
94-
return {::Status::NotOK, "invalid iterator during decoding tdigest centroid"};
95-
}
96-
return *iter_;
97-
}
98-
99-
private:
100-
IterType iter_;
101-
const std::vector<Centroid>& centroids_;
102-
};
103-
104-
std::unique_ptr<Iterator> Begin() const {
105-
return std::make_unique<Iterator>(detail::GetCbeginIter<Reverse>(centroids_), centroids_);
106-
}
107-
std::unique_ptr<Iterator> End() const {
108-
if (centroids_.empty()) {
109-
return std::make_unique<Iterator>(detail::GetCendIter<Reverse>(centroids_), centroids_);
110-
}
111-
return std::make_unique<Iterator>(std::prev(detail::GetCendIter<Reverse>(centroids_)), centroids_);
112-
}
113-
double TotalWeight() const { return static_cast<double>(meta_data_.total_weight); }
114-
double Min() const { return meta_data_.minimum; }
115-
double Max() const { return meta_data_.maximum; }
116-
uint64_t Size() const { return meta_data_.merged_nodes; }
117-
118-
private:
119-
const TDigestMetadata& meta_data_;
120-
const std::vector<Centroid>& centroids_;
121-
};
122-
12337
inline constexpr uint32_t kTDigestMaxCompression = 1000; // limit the compression to 1k
12438

12539
struct CentroidWithKey {
@@ -164,9 +78,10 @@ class TDigest : public SubKeyScanner {
16478

16579
rocksdb::Status Merge(engine::Context& ctx, const Slice& dest_digest, const std::vector<std::string>& source_digests,
16680
const TDigestMergeOptions& options);
167-
template <bool Reverse>
16881
rocksdb::Status Rank(engine::Context& ctx, const Slice& digest_name, const std::vector<double>& inputs,
16982
std::vector<int>& result);
83+
rocksdb::Status RevRank(engine::Context& ctx, const Slice& digest_name, const std::vector<double>& inputs,
84+
std::vector<int>& result);
17085
rocksdb::Status GetMetaData(engine::Context& context, const Slice& digest_name, TDigestMetadata* metadata);
17186

17287
private:
@@ -219,39 +134,4 @@ class TDigest : public SubKeyScanner {
219134
rocksdb::Status decodeCentroidFromKeyValue(const rocksdb::Slice& key, const rocksdb::Slice& value,
220135
Centroid* centroid) const;
221136
};
222-
223-
template <bool Reverse>
224-
rocksdb::Status TDigest::Rank(engine::Context& ctx, const Slice& digest_name, const std::vector<double>& inputs,
225-
std::vector<int>& result) {
226-
auto ns_key = AppendNamespacePrefix(digest_name);
227-
TDigestMetadata metadata;
228-
{
229-
LockGuard guard(storage_->GetLockManager(), ns_key);
230-
231-
if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) {
232-
return status;
233-
}
234-
235-
if (metadata.total_observations == 0) {
236-
result.resize(inputs.size(), -2);
237-
return rocksdb::Status::OK();
238-
}
239-
240-
if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) {
241-
return status;
242-
}
243-
}
244-
245-
std::vector<Centroid> centroids;
246-
if (auto status = dumpCentroids(ctx, ns_key, metadata, &centroids); !status.ok()) {
247-
return status;
248-
}
249-
250-
auto dump_centroids = DummyCentroids<Reverse>(metadata, centroids);
251-
if (auto status = TDigestRank<Reverse>(dump_centroids, inputs, result); !status) {
252-
return rocksdb::Status::InvalidArgument(status.Msg());
253-
}
254-
return rocksdb::Status::OK();
255-
}
256-
257137
} // namespace redis

0 commit comments

Comments
 (0)