Skip to content

Commit ea5f14f

Browse files
Koladata Teamcopybara-github
authored andcommitted
add tree-shaped data with primitives to deep ops benchmarks
PiperOrigin-RevId: 856257494 Change-Id: I984b2435abf5157a4dcdafdfa9d3e7d77a9ebcef
1 parent 20fc5cc commit ea5f14f

File tree

3 files changed

+134
-8
lines changed

3 files changed

+134
-8
lines changed

koladata/internal/op_utils/BUILD

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,6 @@ cc_test(
395395
srcs = ["deep_clone_test.cc"],
396396
deps = [
397397
":deep_clone",
398-
":presence_and",
399398
"//koladata:test_utils",
400399
"//koladata/internal:data_bag",
401400
"//koladata/internal:data_item",
@@ -406,10 +405,8 @@ cc_test(
406405
"//koladata/internal:uuid_object",
407406
"//koladata/internal/testing:deep_op_utils",
408407
"//koladata/internal/testing:matchers",
409-
"@com_google_absl//absl/random",
410408
"@com_google_absl//absl/types:span",
411409
"@com_google_arolla//arolla/dense_array",
412-
"@com_google_arolla//arolla/memory",
413410
"@com_google_arolla//arolla/qtype",
414411
"@com_google_arolla//arolla/util",
415412
"@com_google_googletest//:gtest_main",
@@ -755,11 +752,13 @@ cc_test(
755752
deps = [
756753
":deep_op_benchmarks_util",
757754
":extract",
755+
"//koladata/internal:casting",
758756
"//koladata/internal:data_bag",
759757
"//koladata/internal:data_item",
760758
"//koladata/internal:data_slice",
761759
"//koladata/internal:dtype",
762760
"@com_google_absl//absl/status",
761+
"@com_google_absl//absl/status:statusor",
763762
"@com_google_arolla//arolla/qtype",
764763
"@com_google_arolla//arolla/util/testing",
765764
"@com_google_benchmark//:benchmark_main",
@@ -1274,6 +1273,7 @@ cc_library(
12741273
"@com_google_absl//absl/random",
12751274
"@com_google_absl//absl/random:distributions",
12761275
"@com_google_absl//absl/strings",
1276+
"@com_google_absl//absl/strings:str_format",
12771277
"@com_google_arolla//arolla/dense_array",
12781278
"@com_google_arolla//arolla/memory",
12791279
"@com_google_arolla//arolla/qtype",

koladata/internal/op_utils/deep_op_benchmarks_util.h

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
#ifndef KOLADATA_INTERNAL_OP_UTILS_DEEP_OP_BENCHMARKS_UTIL_H_
1616
#define KOLADATA_INTERNAL_OP_UTILS_DEEP_OP_BENCHMARKS_UTIL_H_
1717

18+
#include <cmath>
1819
#include <cstdint>
1920
#include <functional>
21+
#include <optional>
2022
#include <string>
2123
#include <utility>
2224
#include <vector>
@@ -27,6 +29,7 @@
2729
#include "absl/random/distributions.h"
2830
#include "absl/random/random.h"
2931
#include "absl/strings/str_cat.h"
32+
#include "absl/strings/str_format.h"
3033
#include "arolla/dense_array/dense_array.h"
3134
#include "arolla/memory/optional_value.h"
3235
#include "arolla/qtype/base_types.h"
@@ -71,10 +74,32 @@ constexpr auto kLayersBenchmarkFn = [](auto* b) {
7174
->Args({20, 100, 10, 10});
7275
};
7376

77+
constexpr auto kTreesBenchmarkFn = [](auto* b) {
78+
// Number of layers, number of non-leaf childs per object, number of leaves
79+
// per object, size of the slice.
80+
b->Args({2, 10, 10, 1})
81+
->Args({2, 10, 10, 10})
82+
->Args({2, 10, 10, 100})
83+
->Args({2, 1000, 1000, 1})
84+
->Args({2, 1000, 1000, 10})
85+
->Args({5, 10, 10, 1})
86+
->Args({5, 10, 10, 10})
87+
->Args({5, 10, 10, 100})
88+
->Args({18, 2, 2, 1})
89+
->Args({18, 2, 2, 10})
90+
->Args({18, 2, 2, 100});
91+
};
92+
7493
using RunBenchmarksFn =
7594
std::function<void(benchmark::State&, DataSliceImpl&, DataItem&,
7695
DataBagImplPtr&, DataBagImpl::FallbackSpan)>;
7796

97+
using RunCastingBenchmarksFn = std::function<void(
98+
benchmark::State&, DataSliceImpl& ds_impl, DataItem& ds_schema,
99+
DataBagImplPtr& ds_databag, const DataBagImpl::FallbackSpan ds_fallbacks,
100+
DataItem& new_schema, DataBagImplPtr& new_schema_databag,
101+
const DataBagImpl::FallbackSpan new_schema_fallbacks)>;
102+
78103
inline DataSliceImpl ShuffleObjectsSlice(const DataSliceImpl& ds,
79104
absl::BitGen& gen) {
80105
// Generate random permutation.
@@ -108,7 +133,8 @@ inline DataSliceImpl ApplyRandomMask(const DataSliceImpl& ds,
108133
return result;
109134
}
110135

111-
inline void BM_DisjointChains(benchmark::State& state, RunBenchmarksFn run_fn) {
136+
inline void BM_DisjointChains(benchmark::State& state,
137+
const RunBenchmarksFn& run_fn) {
112138
int64_t schema_depth = state.range(0);
113139
int64_t ds_size = state.range(1);
114140

@@ -131,7 +157,7 @@ inline void BM_DisjointChains(benchmark::State& state, RunBenchmarksFn run_fn) {
131157
}
132158

133159
inline void BM_DisjointChainsObjects(benchmark::State& state,
134-
RunBenchmarksFn run_fn) {
160+
const RunBenchmarksFn& run_fn) {
135161
int64_t schema_depth = state.range(0);
136162
int64_t ds_size = state.range(1);
137163

@@ -162,7 +188,7 @@ inline void BM_DisjointChainsObjects(benchmark::State& state,
162188
run_fn(state, root_ds, kObjectSchema, db, {});
163189
}
164190

165-
inline void BM_DAG(benchmark::State& state, RunBenchmarksFn run_fn) {
191+
inline void BM_DAG(benchmark::State& state, const RunBenchmarksFn& run_fn) {
166192
int64_t schema_depth = state.range(0);
167193
int64_t attr_count = state.range(1);
168194
int64_t presence_rate = state.range(2);
@@ -191,7 +217,8 @@ inline void BM_DAG(benchmark::State& state, RunBenchmarksFn run_fn) {
191217
run_fn(state, root_ds, root_schema, db, {});
192218
}
193219

194-
inline void BM_DAGObjects(benchmark::State& state, RunBenchmarksFn run_fn) {
220+
inline void BM_DAGObjects(benchmark::State& state,
221+
const RunBenchmarksFn& run_fn) {
195222
int64_t schema_depth = state.range(0);
196223
int64_t attr_count = state.range(1);
197224
int64_t presence_rate = state.range(2);
@@ -231,6 +258,73 @@ inline void BM_DAGObjects(benchmark::State& state, RunBenchmarksFn run_fn) {
231258
run_fn(state, root_ds, kObjectSchema, db, {});
232259
}
233260

261+
inline void BM_TreeShapedIntToFloat(benchmark::State& state,
262+
const RunCastingBenchmarksFn& run_fn) {
263+
int64_t schema_depth = state.range(0);
264+
int64_t schema_attr_count = state.range(1);
265+
int64_t primitive_attr_count = state.range(2);
266+
int64_t ds_size = state.range(3);
267+
int64_t num_leaves =
268+
static_cast<int64_t>(std::pow(schema_attr_count, schema_depth - 1)) *
269+
primitive_attr_count * ds_size;
270+
state.SetLabel(absl::StrFormat(
271+
"ds_size=%d; depth=%d; attrs_per_node=%d+%d; total_leaves=%d", ds_size,
272+
schema_depth, schema_attr_count, primitive_attr_count, num_leaves));
273+
274+
CancellationContext::ScopeGuard cancellation_scope;
275+
auto db_a = DataBagImpl::CreateEmptyDatabag();
276+
auto db_b = DataBagImpl::CreateEmptyDatabag();
277+
auto root_schema_a = DataItem(internal::AllocateExplicitSchema());
278+
auto root_schema_b = DataItem(internal::AllocateExplicitSchema());
279+
auto root_ds = DataSliceImpl::AllocateEmptyObjects(ds_size);
280+
std::vector<DataItem> schemas_a({root_schema_a});
281+
std::vector<DataItem> schemas_b({root_schema_b});
282+
std::vector<DataSliceImpl> ds({root_ds});
283+
int value = 0;
284+
for (int64_t i = 0; i < schema_depth; ++i) {
285+
std::vector<DataItem> child_schemas_a;
286+
std::vector<DataItem> child_schemas_b;
287+
std::vector<DataSliceImpl> child_ds;
288+
for (uint64_t id = 0; id < schemas_a.size(); ++id) {
289+
auto schema_a = schemas_a[id];
290+
auto schema_b = schemas_b[id];
291+
auto cur_ds = ds[id];
292+
for (int64_t j = 0; j < schema_attr_count; ++j)
293+
{
294+
std::string attr_name = absl::StrCat("layer_", i, "_child_", j);
295+
auto child_schema_a = DataItem(internal::AllocateExplicitSchema());
296+
auto child_schema_b = DataItem(internal::AllocateExplicitSchema());
297+
auto cur_child_ds = DataSliceImpl::AllocateEmptyObjects(ds_size);
298+
CHECK_OK(db_a->SetSchemaAttr(schema_a, attr_name, child_schema_a));
299+
CHECK_OK(db_b->SetSchemaAttr(schema_b, attr_name, child_schema_b));
300+
CHECK_OK(db_a->SetAttr(cur_ds, attr_name, cur_child_ds));
301+
child_ds.push_back(std::move(cur_child_ds));
302+
child_schemas_a.push_back(std::move(child_schema_a));
303+
child_schemas_b.push_back(std::move(child_schema_b));
304+
}
305+
for (int64_t j = 0; j < primitive_attr_count; ++j)
306+
{
307+
std::string attr_name = absl::StrCat("layer_", i, "_primitive_", j);
308+
auto child_schema_a = DataItem(schema::kInt32);
309+
auto child_schema_b = DataItem(schema::kFloat32);
310+
std::vector<std::optional<int>> values(ds_size);
311+
for (int64_t h = 0; h < ds_size; ++h) {
312+
values[h] = ++value;
313+
}
314+
DataSliceImpl::Create(
315+
arolla::CreateDenseArray<int>(values.begin(), values.end()));
316+
CHECK_OK(db_a->SetSchemaAttr(schema_a, attr_name, child_schema_a));
317+
CHECK_OK(db_b->SetSchemaAttr(schema_b, attr_name, child_schema_b));
318+
}
319+
}
320+
std::swap(ds, child_ds);
321+
std::swap(schemas_a, child_schemas_a);
322+
std::swap(schemas_b, child_schemas_b);
323+
}
324+
run_fn(state, root_ds, root_schema_a, db_a, {}, root_schema_b, db_b,
325+
{});
326+
}
327+
234328
} // namespace benchmarks_utils
235329

236330
} // namespace koladata::internal

koladata/internal/op_utils/extract_benchmarks.cc

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
//
1515
#include "benchmark/benchmark.h"
1616
#include "absl/status/status.h"
17+
#include "absl/status/statusor.h"
1718
#include "arolla/qtype/base_types.h"
19+
#include "koladata/internal/casting.h"
1820
#include "koladata/internal/data_bag.h"
1921
#include "koladata/internal/data_item.h"
2022
#include "koladata/internal/data_slice.h"
@@ -42,6 +44,32 @@ void RunBenchmarks(benchmark::State& state, DataSliceImpl& ds, DataItem& schema,
4244
}
4345
}
4446

47+
void RunCastingBenchmarks(benchmark::State& state, DataSliceImpl& ds,
48+
DataItem& schema, DataBagImplPtr& databag,
49+
DataBagImpl::FallbackSpan fallbacks,
50+
DataItem& new_schema, DataBagImplPtr& schema_databag,
51+
DataBagImpl::FallbackSpan schema_fallbacks) {
52+
auto cast_data_callback =
53+
static_cast<absl::StatusOr<internal::DataSliceImpl> (*)(
54+
const internal::DataSliceImpl&, const internal::DataItem&)>(
55+
schema::CastDataTo);
56+
while (state.KeepRunning()) {
57+
benchmark::DoNotOptimize(ds);
58+
benchmark::DoNotOptimize(schema);
59+
benchmark::DoNotOptimize(databag);
60+
benchmark::DoNotOptimize(fallbacks);
61+
benchmark::DoNotOptimize(new_schema);
62+
benchmark::DoNotOptimize(schema_databag);
63+
benchmark::DoNotOptimize(schema_fallbacks);
64+
auto result_db = DataBagImpl::CreateEmptyDatabag();
65+
ExtractOp(result_db.get())(ds, new_schema, *databag, fallbacks,
66+
&*schema_databag, schema_fallbacks,
67+
/*max_depth=*/-1, cast_data_callback)
68+
.IgnoreError();
69+
benchmark::DoNotOptimize(result_db);
70+
}
71+
}
72+
4573
void BM_DisjointChains(benchmark::State& state) {
4674
benchmarks_utils::BM_DisjointChains(state, RunBenchmarks);
4775
}
@@ -62,6 +90,11 @@ void BM_DAGObjects(benchmark::State& state) {
6290
}
6391
BENCHMARK(BM_DAGObjects)->Apply(benchmarks_utils::kLayersBenchmarkFn);
6492

93+
void BM_TreeShapedIntToFloat(benchmark::State& state) {
94+
benchmarks_utils::BM_TreeShapedIntToFloat(state, RunCastingBenchmarks);
95+
}
96+
BENCHMARK(BM_TreeShapedIntToFloat)->Apply(benchmarks_utils::kTreesBenchmarkFn);
97+
6598
inline void BM_ScalarPrimitive(benchmark::State& state) {
6699
auto input_db = DataBagImpl::CreateEmptyDatabag();
67100
auto result_db = DataBagImpl::CreateEmptyDatabag();
@@ -77,7 +110,6 @@ inline void BM_ScalarPrimitive(benchmark::State& state) {
77110
benchmark::DoNotOptimize(result_db);
78111
}
79112
}
80-
81113
BENCHMARK(BM_ScalarPrimitive);
82114

83115
} // namespace

0 commit comments

Comments
 (0)