Skip to content

Commit 2c1e071

Browse files
goldvitalycopybara-github
authored andcommitted
Add DataSliceImpl::TransformValues for faster implementatation of the operators working per array.
Before removed values similar functionality was integrated into the `DataSliceBuilder`. With removed values `DataSliceBuilder` is very overloaded and complicatted. Adding such functionality in it will be quite complicated and hard to maintain. PiperOrigin-RevId: 717957212 Change-Id: I3572160477d494af630382537d0b52a0359d957e
1 parent 951a792 commit 2c1e071

File tree

3 files changed

+460
-48
lines changed

3 files changed

+460
-48
lines changed

koladata/internal/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ cc_library(
406406
":types_buffer",
407407
"@com_google_absl//absl/base:core_headers",
408408
"@com_google_absl//absl/container:flat_hash_map",
409+
"@com_google_absl//absl/container:flat_hash_set",
409410
"@com_google_absl//absl/container:inlined_vector",
410411
"@com_google_absl//absl/log",
411412
"@com_google_absl//absl/log:check",
@@ -420,6 +421,7 @@ cc_library(
420421
"@com_google_arolla//arolla/memory",
421422
"@com_google_arolla//arolla/qtype",
422423
"@com_google_arolla//arolla/util",
424+
"@com_google_arolla//arolla/util:status_backport",
423425
],
424426
)
425427

@@ -435,6 +437,7 @@ cc_test(
435437
":types_buffer",
436438
"@com_google_absl//absl/status",
437439
"@com_google_absl//absl/status:status_matchers",
440+
"@com_google_absl//absl/status:statusor",
438441
"@com_google_absl//absl/strings",
439442
"@com_google_arolla//arolla/dense_array",
440443
"@com_google_arolla//arolla/dense_array/qtype",

koladata/internal/data_slice.h

Lines changed: 130 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
#include <cstddef>
2020
#include <cstdint>
2121
#include <memory>
22+
#include <optional>
2223
#include <string>
2324
#include <type_traits>
2425
#include <utility>
2526
#include <variant>
2627
#include <vector>
2728

29+
#include "absl/container/flat_hash_set.h"
2830
#include "absl/container/inlined_vector.h"
2931
#include "absl/log/check.h"
3032
#include "absl/status/status.h"
@@ -47,8 +49,10 @@
4749
#include "arolla/util/iterator.h"
4850
#include "arolla/util/refcount_ptr.h"
4951
#include "arolla/util/repr.h"
52+
#include "arolla/util/status.h"
5053
#include "arolla/util/text.h"
5154
#include "arolla/util/unit.h"
55+
#include "arolla/util/status_macros_backport.h"
5256

5357
namespace koladata::internal {
5458

@@ -197,6 +201,27 @@ class DataSliceImpl {
197201
}
198202
}
199203

204+
// Transforms values of the DataSliceImpl.
205+
//
206+
// `result_size` is the size of the result DataSliceImpl.
207+
//
208+
// `allocation_ids` is the set of allocation ids that are allowed to be
209+
// present in the result DataSliceImpl. If `std::nullopt`, any allocation id
210+
// is allowed.
211+
//
212+
// `transform` is a functor that takes a DenseArray of values of the
213+
// DataSliceImpl and returns a DenseArray of values of the result
214+
// DataSliceImpl.
215+
//
216+
// All DenseArrays in the result DataSliceImpl must have the same size and
217+
// non-intersecting presence ids.
218+
//
219+
// All resulted DenseArrays must have different types.
220+
template <class DenseArrayTransformer>
221+
absl::StatusOr<DataSliceImpl> TransformValues(
222+
size_t result_size, std::optional<AllocationIdSet> allocation_ids,
223+
DenseArrayTransformer&& transform) const;
224+
200225
// Returns DataItem with given offset.
201226
DataItem operator[](int64_t offset) const;
202227

@@ -262,6 +287,15 @@ class DataSliceImpl {
262287
private:
263288
friend class SliceBuilder;
264289

290+
template <class T>
291+
void AddAllocIds(const arolla::DenseArray<T>& array) {
292+
if constexpr (std::is_same_v<T, ObjectId>) {
293+
AllocationIdSet& id_set = internal_->allocation_ids;
294+
array.ForEachPresent(
295+
[&](int64_t id, ObjectId obj) { id_set.Insert(AllocationId(obj)); });
296+
}
297+
}
298+
265299
using Variant = std::variant<arolla::DenseArray<ObjectId>, //
266300
arolla::DenseArray<int32_t>, //
267301
arolla::DenseArray<float>, //
@@ -314,24 +348,40 @@ class DataSliceImpl {
314348

315349
namespace data_slice_impl {
316350

317-
template <class T, class... Ts>
318-
bool VerifyNonIntersectingIds(const arolla::DenseArray<T>& main_values,
319-
const arolla::DenseArray<Ts>&... values) {
320-
size_t bitmap_size = arolla::bitmap::BitmapSize(main_values.size());
321-
std::vector<arolla::bitmap::Word> present_ids(bitmap_size);
322-
auto process_ids = [&](const auto& array) {
323-
for (size_t i = 0; i < bitmap_size; ++i) {
351+
class NonIntersectingIdsChecker {
352+
public:
353+
explicit NonIntersectingIdsChecker(size_t size)
354+
: size_(size),
355+
bitmap_size_(arolla::bitmap::BitmapSize(size)),
356+
present_ids_(bitmap_size_) {}
357+
358+
template <class T>
359+
bool Verify(const arolla::DenseArray<T>& array) {
360+
if (array.size() != size_) {
361+
return false;
362+
}
363+
for (size_t i = 0; i < bitmap_size_; ++i) {
324364
arolla::bitmap::Word word = arolla::bitmap::GetWordWithOffset(
325365
array.bitmap, i, array.bitmap_bit_offset);
326-
if (present_ids[i] & word) {
327-
return true;
366+
if (present_ids_[i] & word) {
367+
return false;
328368
}
329-
present_ids[i] |= word;
369+
present_ids_[i] |= word;
330370
}
331-
return false;
332-
};
333-
process_ids(main_values);
334-
return !(process_ids(values) || ...);
371+
return true;
372+
}
373+
374+
private:
375+
size_t size_;
376+
size_t bitmap_size_;
377+
std::vector<arolla::bitmap::Word> present_ids_;
378+
};
379+
380+
template <class T, class... Ts>
381+
bool VerifyNonIntersectingIds(const arolla::DenseArray<T>& main_values,
382+
const arolla::DenseArray<Ts>&... values) {
383+
NonIntersectingIdsChecker checker(main_values.size());
384+
return checker.Verify(main_values) && (checker.Verify(values) && ...);
335385
}
336386

337387
template <class T>
@@ -369,6 +419,70 @@ constexpr bool AreAllTypesDistinct(std::type_identity<T>,
369419

370420
} // namespace data_slice_impl
371421

422+
template <class DenseArrayTransformer>
423+
absl::StatusOr<DataSliceImpl> DataSliceImpl::TransformValues(
424+
size_t result_size, std::optional<AllocationIdSet> allocation_ids,
425+
DenseArrayTransformer&& transform) const {
426+
const auto& values = internal_->values;
427+
DataSliceImpl res;
428+
429+
auto& res_impl = *res.internal_;
430+
res_impl.size = result_size;
431+
432+
if (allocation_ids.has_value()) {
433+
res_impl.allocation_ids = std::move(*allocation_ids);
434+
}
435+
436+
#ifndef NDEBUG
437+
absl::flat_hash_set<arolla::QTypePtr> qtypes;
438+
qtypes.reserve(values.size());
439+
data_slice_impl::NonIntersectingIdsChecker checker(result_size);
440+
#endif
441+
442+
auto add_array = [&]<class T>(arolla::DenseArray<T> transformed_array) {
443+
#ifndef NDEBUG
444+
DCHECK(qtypes.insert(arolla::GetQType<T>()).second)
445+
<< "duplicated type: " << arolla::GetQType<T>();
446+
DCHECK_EQ(transformed_array.size(), result_size);
447+
DCHECK_EQ(transformed_array.bitmap_bit_offset, 0);
448+
DCHECK(checker.Verify(transformed_array)) << "ids are intersecting";
449+
#endif
450+
if (!allocation_ids.has_value()) {
451+
res.AddAllocIds(transformed_array);
452+
}
453+
if (!transformed_array.IsAllMissing()) {
454+
res_impl.dtype = res_impl.values.empty() ? arolla::GetQType<T>()
455+
: arolla::GetNothingQType();
456+
res_impl.values.push_back(std::move(transformed_array));
457+
}
458+
};
459+
460+
res_impl.values.reserve(values.size());
461+
462+
for (const Variant& vals : values) {
463+
RETURN_IF_ERROR(std::visit(
464+
[&](const auto& array) -> absl::Status {
465+
auto s = transform(array);
466+
if constexpr (arolla::IsStatusOrT<decltype(s)>::value) {
467+
if (!s.ok()) {
468+
return s.status();
469+
}
470+
add_array(*std::move(s));
471+
} else {
472+
add_array(std::move(s));
473+
}
474+
return absl::OkStatus();
475+
},
476+
vals));
477+
}
478+
479+
if (res_impl.values.size() > 1) {
480+
InitTypesBuffer(res_impl);
481+
}
482+
DCHECK(res.VerifyAllocIdsConsistency());
483+
return res;
484+
}
485+
372486
template <class T, class... Ts>
373487
void DataSliceImpl::CreateImpl(DataSliceImpl& res,
374488
arolla::DenseArray<T> main_values,
@@ -408,15 +522,8 @@ template <class T, class... Ts>
408522
DataSliceImpl DataSliceImpl::Create(arolla::DenseArray<T> main_values,
409523
arolla::DenseArray<Ts>... values) {
410524
DataSliceImpl res;
411-
auto add_alloc_ids = [&res](const auto& arr) {
412-
if constexpr (std::is_same_v<decltype(arr), const ObjectIdArray&>) {
413-
AllocationIdSet& id_set = res.internal_->allocation_ids;
414-
arr.ForEachPresent(
415-
[&](int64_t id, ObjectId obj) { id_set.Insert(AllocationId(obj)); });
416-
}
417-
};
418-
add_alloc_ids(main_values);
419-
(add_alloc_ids(values), ...);
525+
res.AddAllocIds(main_values);
526+
(res.AddAllocIds(values), ...);
420527
CreateImpl(res, std::move(main_values), std::move(values)...);
421528
return res;
422529
}

0 commit comments

Comments
 (0)