Skip to content

Commit af3e758

Browse files
author
Rafał Hibner
committed
Add BackpressureCombiner
1 parent 0ecc472 commit af3e758

File tree

4 files changed

+244
-0
lines changed

4 files changed

+244
-0
lines changed

cpp/src/arrow/acero/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ set(ARROW_ACERO_REQUIRED_DEPENDENCIES Arrow ArrowCompute)
3333

3434
set(ARROW_ACERO_SRCS
3535
accumulation_queue.cc
36+
backpressure.cc
3637
scalar_aggregate_node.cc
3738
groupby_aggregate_node.cc
3839
aggregate_internal.cc
@@ -173,6 +174,7 @@ function(add_arrow_acero_test REL_TEST_NAME)
173174
${ARG_UNPARSED_ARGUMENTS})
174175
endfunction()
175176

177+
add_arrow_acero_test(backpressure_test SOURCES backpressure_test.cc)
176178
add_arrow_acero_test(plan_test SOURCES plan_test.cc test_nodes_test.cc)
177179
add_arrow_acero_test(source_node_test SOURCES source_node_test.cc)
178180
add_arrow_acero_test(fetch_node_test SOURCES fetch_node_test.cc)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "arrow/acero/backpressure.h"
19+
namespace arrow::acero {
20+
BackpressureCombiner::BackpressureCombiner(
21+
std::unique_ptr<BackpressureControl> backpressure_control)
22+
: backpressure_control_(std::move(backpressure_control)) {}
23+
24+
// Called from Source nodes
25+
void BackpressureCombiner::Pause(Source* output, bool strong_connection) {
26+
std::lock_guard<std::mutex> lg(mutex_);
27+
auto& paused_ = strong_connection ? strong_paused_ : weak_paused_;
28+
auto& paused_count_ = strong_connection ? strong_paused_count_ : weak_paused_count_;
29+
30+
if (!paused_[output]) {
31+
paused_[output] = true;
32+
paused_count_++;
33+
UpdatePauseStateUnlocked();
34+
}
35+
}
36+
37+
// Called from Source nodes
38+
void BackpressureCombiner::Resume(Source* output, bool strong_connection) {
39+
std::lock_guard<std::mutex> lg(mutex_);
40+
auto& paused_ = strong_connection ? strong_paused_ : weak_paused_;
41+
auto& paused_count_ = strong_connection ? strong_paused_count_ : weak_paused_count_;
42+
if (paused_.find(output) == paused_.end()) {
43+
paused_[output] = false;
44+
UpdatePauseStateUnlocked();
45+
}
46+
if (paused_[output]) {
47+
paused_[output] = false;
48+
paused_count_--;
49+
UpdatePauseStateUnlocked();
50+
}
51+
}
52+
53+
void BackpressureCombiner::UpdatePauseStateUnlocked() {
54+
bool should_be_paused =
55+
strong_paused_count_ > 0 || weak_paused_count_ == weak_paused_.size();
56+
if (should_be_paused) {
57+
if (!paused) {
58+
backpressure_control_->Pause();
59+
paused = true;
60+
}
61+
} else {
62+
if (paused) {
63+
backpressure_control_->Resume();
64+
paused = false;
65+
}
66+
}
67+
}
68+
69+
BackpressureCombiner::Source::Source(BackpressureCombiner* ctrl, bool strong_connection) {
70+
if (ctrl) {
71+
AddController(ctrl, strong_connection);
72+
}
73+
}
74+
75+
void BackpressureCombiner::Source::AddController(BackpressureCombiner* ctrl,
76+
bool strong_connection) {
77+
ctrl->Resume(this, strong_connection); // populate map in controller
78+
connections_.push_back(Connection{ctrl, strong_connection});
79+
}
80+
void BackpressureCombiner::Source::Pause() {
81+
for (auto& conn_ : connections_) {
82+
conn_.ctrl->Pause(this, conn_.strong);
83+
}
84+
}
85+
void BackpressureCombiner::Source::Resume() {
86+
for (auto& conn_ : connections_) {
87+
conn_.ctrl->Resume(this, conn_.strong);
88+
}
89+
}
90+
91+
} // namespace arrow::acero

cpp/src/arrow/acero/backpressure.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#pragma once
19+
#include "arrow/acero/options.h"
20+
21+
#include <mutex>
22+
namespace arrow::acero {
23+
24+
// Provides infrastructure of combining multiple backpressure sources and propagate the
25+
// result into BackpressureControl There are two types of Source: strong - pause on any
26+
// strong Source within controller
27+
class ARROW_ACERO_EXPORT BackpressureCombiner {
28+
public:
29+
explicit BackpressureCombiner(
30+
std::unique_ptr<BackpressureControl> backpressure_control);
31+
32+
// Instances of Source can be used as usual BackpresureControl.
33+
// Source can be connected with one or more BackpressureCombiner
34+
class ARROW_ACERO_EXPORT Source : public BackpressureControl {
35+
public:
36+
// strong - strong_connection=true
37+
// weak - strong_connection=false
38+
explicit Source(BackpressureCombiner* ctrl = nullptr, bool strong_connection = true);
39+
void AddController(BackpressureCombiner* ctrl, bool strong_connection = true);
40+
void Pause() override;
41+
void Resume() override;
42+
43+
private:
44+
struct Connection {
45+
BackpressureCombiner* ctrl;
46+
bool strong;
47+
};
48+
std::vector<Connection> connections_;
49+
};
50+
51+
private:
52+
friend class Source;
53+
void Pause(Source* output, bool strong_connection);
54+
void Resume(Source* output, bool strong_connection);
55+
56+
void UpdatePauseStateUnlocked();
57+
std::unique_ptr<BackpressureControl> backpressure_control_;
58+
std::mutex mutex_;
59+
std::unordered_map<Source*, bool> strong_paused_;
60+
std::unordered_map<Source*, bool> weak_paused_;
61+
size_t strong_paused_count_{0};
62+
size_t weak_paused_count_{0};
63+
bool paused{false};
64+
};
65+
66+
} // namespace arrow::acero
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include <gtest/gtest.h>
19+
20+
#include "arrow/acero/backpressure.h"
21+
22+
namespace arrow {
23+
namespace acero {
24+
25+
class MonitorBackpressureControl : public acero::BackpressureControl {
26+
public:
27+
explicit MonitorBackpressureControl(std::atomic<bool>& paused) : paused(paused) {}
28+
virtual void Pause() { paused = true; }
29+
virtual void Resume() { paused = false; }
30+
std::atomic<bool>& paused;
31+
};
32+
33+
TEST(BackpressureCombiner, Basic) {
34+
std::atomic<bool> paused;
35+
BackpressureCombiner combiner(std::make_unique<MonitorBackpressureControl>(paused));
36+
37+
BackpressureCombiner::Source weak_source1(&combiner, false);
38+
BackpressureCombiner::Source weak_source2;
39+
weak_source2.AddController(&combiner, false);
40+
BackpressureCombiner::Source strong_source1(&combiner);
41+
BackpressureCombiner::Source strong_source2;
42+
strong_source2.AddController(&combiner);
43+
44+
// Any strong causes pause
45+
ASSERT_FALSE(paused);
46+
strong_source1.Pause();
47+
ASSERT_TRUE(paused);
48+
strong_source2.Pause();
49+
ASSERT_TRUE(paused);
50+
strong_source1.Resume();
51+
ASSERT_TRUE(paused);
52+
strong_source2.Resume();
53+
ASSERT_FALSE(paused);
54+
55+
// All weak cause pause
56+
ASSERT_FALSE(paused);
57+
weak_source1.Pause();
58+
ASSERT_FALSE(paused);
59+
weak_source2.Pause();
60+
ASSERT_TRUE(paused);
61+
weak_source1.Resume();
62+
ASSERT_FALSE(paused);
63+
weak_source2.Resume();
64+
ASSERT_FALSE(paused);
65+
66+
// mixed use
67+
strong_source1.Pause();
68+
ASSERT_TRUE(paused);
69+
70+
ASSERT_TRUE(paused);
71+
weak_source1.Pause();
72+
ASSERT_TRUE(paused);
73+
weak_source2.Pause();
74+
75+
strong_source1.Resume();
76+
ASSERT_TRUE(paused);
77+
78+
weak_source1.Resume();
79+
ASSERT_FALSE(paused);
80+
weak_source2.Resume();
81+
ASSERT_FALSE(paused);
82+
}
83+
84+
} // namespace acero
85+
} // namespace arrow

0 commit comments

Comments
 (0)