Skip to content

Commit a627648

Browse files
Mikester95copybara-github
authored andcommitted
Add DerivedExecutor to wrap an existing Executor and add extra context guard.
PiperOrigin-RevId: 869886940 Change-Id: Ie3fb3653ad7badbc1f3032a4f398e421f97d4d4b
1 parent 8c0cb81 commit a627648

File tree

4 files changed

+253
-0
lines changed

4 files changed

+253
-0
lines changed

koladata/functor/parallel/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,31 @@ cc_library(
10841084
],
10851085
)
10861086

1087+
cc_test(
1088+
name = "derived_executor_test",
1089+
srcs = ["derived_executor_test.cc"],
1090+
deps = [
1091+
":context_guard",
1092+
":derived_executor",
1093+
":executor",
1094+
"@com_google_absl//absl/status",
1095+
"@com_google_absl//absl/status:status_matchers",
1096+
"@com_google_googletest//:gtest_main",
1097+
],
1098+
)
1099+
1100+
cc_library(
1101+
name = "derived_executor",
1102+
srcs = ["derived_executor.cc"],
1103+
hdrs = ["derived_executor.h"],
1104+
deps = [
1105+
":executor",
1106+
"@com_google_absl//absl/base:nullability",
1107+
"@com_google_absl//absl/strings",
1108+
"@com_google_arolla//arolla/util",
1109+
],
1110+
)
1111+
10871112
cc_test(
10881113
name = "parallel_call_utils_test",
10891114
srcs = ["parallel_call_utils_test.cc"],
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
#include "koladata/functor/parallel/derived_executor.h"
16+
17+
#include <memory>
18+
#include <string>
19+
#include <utility>
20+
21+
#include "absl/base/nullability.h"
22+
#include "absl/strings/str_cat.h"
23+
#include "arolla/util/fast_dynamic_downcast_final.h"
24+
#include "koladata/functor/parallel/executor.h"
25+
26+
namespace koladata::functor::parallel {
27+
28+
// The task of setting up the extra ContextGuard is delegated to the base class.
29+
DerivedExecutor::DerivedExecutor(
30+
absl_nonnull ExecutorPtr base_executor,
31+
ContextGuardInitializer extra_context_initializer)
32+
: Executor(std::move(extra_context_initializer)),
33+
base_executor_(std::move(base_executor)) {}
34+
35+
std::string DerivedExecutor::Repr() const noexcept {
36+
int derived_depth = 1;
37+
38+
Executor* executor = base_executor_.get();
39+
DerivedExecutor* derived_executor;
40+
while (
41+
(derived_executor = arolla::fast_dynamic_downcast_final<DerivedExecutor*>(
42+
executor)) != nullptr) {
43+
++derived_depth;
44+
executor = derived_executor->base_executor_.get();
45+
}
46+
47+
// executor points to the first non-derived executor in the chain.
48+
return absl::StrCat("derived_executor[", executor->Repr(), ", ",
49+
derived_depth, "]");
50+
}
51+
52+
void DerivedExecutor::DoSchedule(TaskFn task_fn) noexcept {
53+
base_executor_->Schedule(std::move(task_fn));
54+
}
55+
56+
} // namespace koladata::functor::parallel
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
#ifndef KOLADATA_FUNCTOR_PARALLEL_DERIVED_EXECUTOR_H_
16+
#define KOLADATA_FUNCTOR_PARALLEL_DERIVED_EXECUTOR_H_
17+
18+
#include <memory>
19+
#include <string>
20+
21+
#include "absl/base/nullability.h"
22+
#include "koladata/functor/parallel/executor.h"
23+
24+
namespace koladata::functor::parallel {
25+
26+
// An executor that adds an extra ContextGuard to scheduled tasks.
27+
class DerivedExecutor final : public Executor {
28+
public:
29+
DerivedExecutor(absl_nonnull ExecutorPtr base_executor,
30+
ContextGuardInitializer extra_context_initializer);
31+
32+
std::string Repr() const noexcept final;
33+
34+
private:
35+
void DoSchedule(TaskFn task_fn) noexcept final;
36+
37+
ExecutorPtr base_executor_;
38+
};
39+
40+
using DerivedExecutorPtr = std::shared_ptr<DerivedExecutor>;
41+
42+
} // namespace koladata::functor::parallel
43+
44+
#endif // KOLADATA_FUNCTOR_PARALLEL_DERIVED_EXECUTOR_H_
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
#include "koladata/functor/parallel/derived_executor.h"
16+
17+
#include <functional>
18+
#include <memory>
19+
#include <string>
20+
#include <utility>
21+
#include <vector>
22+
23+
#include "gmock/gmock.h"
24+
#include "gtest/gtest.h"
25+
#include "absl/status/status.h"
26+
#include "absl/status/status_matchers.h"
27+
#include "koladata/functor/parallel/context_guard.h"
28+
#include "koladata/functor/parallel/executor.h"
29+
30+
namespace koladata::functor::parallel {
31+
namespace {
32+
33+
using ::absl_testing::IsOkAndHolds;
34+
using ::absl_testing::StatusIs;
35+
36+
class TestExecutor : public Executor {
37+
public:
38+
using Executor::Executor;
39+
40+
void DoSchedule(TaskFn task_fn) noexcept override {
41+
std::move(task_fn)();
42+
}
43+
44+
std::string Repr() const noexcept override { return "test_executor"; }
45+
};
46+
47+
class TestScopeGuard {
48+
public:
49+
TestScopeGuard(std::function<void()> on_construct_fn,
50+
std::function<void()> on_destruct_fn)
51+
: on_construct_fn_(std::move(on_construct_fn)),
52+
on_destruct_fn_(std::move(on_destruct_fn)) {
53+
on_construct_fn_();
54+
}
55+
56+
~TestScopeGuard() noexcept { on_destruct_fn_(); }
57+
58+
private:
59+
std::function<void()> on_construct_fn_;
60+
std::function<void()> on_destruct_fn_;
61+
};
62+
63+
TEST(DerivedExecutorTest, Repr) {
64+
auto base_executor = std::make_shared<TestExecutor>();
65+
auto derived_executor =
66+
std::make_shared<DerivedExecutor>(base_executor, nullptr);
67+
EXPECT_EQ(derived_executor->Repr(), "derived_executor[test_executor, 1]");
68+
69+
auto derived_executor_2 = std::make_shared<DerivedExecutor>(
70+
derived_executor, nullptr);
71+
EXPECT_EQ(derived_executor_2->Repr(), "derived_executor[test_executor, 2]");
72+
73+
auto derived_executor_3 = std::make_shared<DerivedExecutor>(
74+
derived_executor_2, nullptr);
75+
EXPECT_EQ(derived_executor_3->Repr(), "derived_executor[test_executor, 3]");
76+
}
77+
78+
TEST(DerivedExecutorTest, ScheduleCallsBaseExecutor) {
79+
auto base_executor = std::make_shared<TestExecutor>();
80+
DerivedExecutorPtr derived_executor =
81+
std::make_shared<DerivedExecutor>(base_executor, nullptr);
82+
83+
bool task_executed = false;
84+
derived_executor->Schedule([&]() mutable { task_executed = true; });
85+
EXPECT_TRUE(task_executed);
86+
}
87+
88+
TEST(DerivedExecutorTest, CurrentExecutor) {
89+
EXPECT_THAT(CurrentExecutor(), StatusIs(absl::StatusCode::kInvalidArgument));
90+
auto base_executor = std::make_shared<TestExecutor>();
91+
auto derived_executor =
92+
std::make_shared<DerivedExecutor>(base_executor, nullptr);
93+
bool task_executed = false;
94+
derived_executor->Schedule([&] {
95+
task_executed = true;
96+
EXPECT_THAT(CurrentExecutor(), IsOkAndHolds(derived_executor));
97+
});
98+
EXPECT_THAT(CurrentExecutor(), StatusIs(absl::StatusCode::kInvalidArgument));
99+
EXPECT_TRUE(task_executed);
100+
}
101+
102+
TEST(DerivedExecutorTest, ExtraContextGuard) {
103+
std::vector<std::string> log;
104+
auto base_guard_initializer = [&](ContextGuard& context_guard) {
105+
context_guard.init<TestScopeGuard>(
106+
[&] { log.push_back("enter_base_scope"); },
107+
[&] { log.push_back("leave_base_scope"); });
108+
};
109+
auto derived_guard_initializer = [&](ContextGuard& context_guard) {
110+
context_guard.init<TestScopeGuard>(
111+
[&] { log.push_back("enter_derived_scope"); },
112+
[&] { log.push_back("leave_derived_scope"); });
113+
};
114+
auto base_executor =
115+
std::make_shared<TestExecutor>(std::move(base_guard_initializer));
116+
auto derived_executor = std::make_shared<DerivedExecutor>(
117+
base_executor, std::move(derived_guard_initializer));
118+
derived_executor->Schedule([&] { log.push_back("run_task_1"); });
119+
derived_executor->Schedule([&] { log.push_back("run_task_2"); });
120+
EXPECT_THAT(log, testing::ElementsAre(
121+
"enter_base_scope", "enter_derived_scope", "run_task_1",
122+
"leave_derived_scope", "leave_base_scope",
123+
"enter_base_scope", "enter_derived_scope", "run_task_2",
124+
"leave_derived_scope", "leave_base_scope"));
125+
}
126+
127+
} // namespace
128+
} // namespace koladata::functor::parallel

0 commit comments

Comments
 (0)