Skip to content

Commit c6be761

Browse files
authored
fix: Faster implementation of work queue (#2887)
1 parent f33f15c commit c6be761

File tree

4 files changed

+198
-127
lines changed

4 files changed

+198
-127
lines changed

benchmarks/rpc/WorkQueueBenchmarks.cpp

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,18 @@
2929

3030
#include <benchmark/benchmark.h>
3131
#include <boost/asio/steady_timer.hpp>
32+
#include <boost/asio/thread_pool.hpp>
33+
#include <boost/json/object.hpp>
3234

35+
#include <algorithm>
3336
#include <atomic>
3437
#include <cassert>
3538
#include <chrono>
3639
#include <cstddef>
3740
#include <cstdint>
3841
#include <mutex>
42+
#include <thread>
43+
#include <vector>
3944

4045
using namespace rpc;
4146
using namespace util::config;
@@ -75,36 +80,56 @@ benchmarkWorkQueue(benchmark::State& state)
7580
{
7681
init();
7782

78-
auto const total = static_cast<size_t>(state.range(0));
79-
auto const numThreads = static_cast<uint32_t>(state.range(1));
80-
auto const maxSize = static_cast<uint32_t>(state.range(2));
81-
auto const delayMs = static_cast<uint32_t>(state.range(3));
83+
auto const wqThreads = static_cast<uint32_t>(state.range(0));
84+
auto const maxQueueSize = static_cast<uint32_t>(state.range(1));
85+
auto const clientThreads = static_cast<uint32_t>(state.range(2));
86+
auto const itemsPerClient = static_cast<uint32_t>(state.range(3));
87+
auto const clientProcessingMs = static_cast<uint32_t>(state.range(4));
8288

8389
for (auto _ : state) {
8490
std::atomic_size_t totalExecuted = 0uz;
8591
std::atomic_size_t totalQueued = 0uz;
8692

8793
state.PauseTiming();
88-
WorkQueue queue(numThreads, maxSize);
94+
WorkQueue queue(wqThreads, maxQueueSize);
8995
state.ResumeTiming();
9096

91-
for (auto i = 0uz; i < total; ++i) {
92-
totalQueued += static_cast<std::size_t>(queue.postCoro(
93-
[&delayMs, &totalExecuted](auto yield) {
94-
++totalExecuted;
95-
96-
boost::asio::steady_timer timer(yield.get_executor(), std::chrono::milliseconds{delayMs});
97-
timer.async_wait(yield);
98-
},
99-
/* isWhiteListed = */ false
100-
));
97+
std::vector<std::thread> threads;
98+
threads.reserve(clientThreads);
99+
100+
for (auto t = 0uz; t < clientThreads; ++t) {
101+
threads.emplace_back([&] {
102+
for (auto i = 0uz; i < itemsPerClient; ++i) {
103+
totalQueued += static_cast<std::size_t>(queue.postCoro(
104+
[&clientProcessingMs, &totalExecuted](auto yield) {
105+
++totalExecuted;
106+
107+
boost::asio::steady_timer timer(
108+
yield.get_executor(), std::chrono::milliseconds{clientProcessingMs}
109+
);
110+
timer.async_wait(yield);
111+
112+
std::this_thread::sleep_for(std::chrono::microseconds{10});
113+
},
114+
/* isWhiteListed = */ false
115+
));
116+
}
117+
});
101118
}
102119

120+
for (auto& t : threads)
121+
t.join();
122+
103123
queue.stop();
104124

105125
ASSERT(totalExecuted == totalQueued, "Totals don't match");
106-
ASSERT(totalQueued <= total, "Queued more than requested");
107-
ASSERT(totalQueued >= maxSize, "Queued less than maxSize");
126+
ASSERT(totalQueued <= itemsPerClient * clientThreads, "Queued more than requested");
127+
128+
if (maxQueueSize == 0) {
129+
ASSERT(totalQueued == itemsPerClient * clientThreads, "Queued exactly the expected amount");
130+
} else {
131+
ASSERT(totalQueued >= std::min(maxQueueSize, itemsPerClient * clientThreads), "Queued less than expected");
132+
}
108133
}
109134
}
110135

@@ -118,5 +143,5 @@ benchmarkWorkQueue(benchmark::State& state)
118143
*/
119144
// TODO: figure out what happens on 1 thread
120145
BENCHMARK(benchmarkWorkQueue)
121-
->ArgsProduct({{1'000, 10'000, 100'000}, {2, 4, 8}, {0, 5'000}, {10, 100, 250}})
146+
->ArgsProduct({{2, 4, 8, 16}, {0, 5'000}, {4, 8, 16}, {1'000, 10'000}, {10, 100, 250}})
122147
->Unit(benchmark::kMillisecond);

src/rpc/WorkQueue.cpp

Lines changed: 66 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525
#include "util/prometheus/Label.hpp"
2626
#include "util/prometheus/Prometheus.hpp"
2727

28-
#include <boost/asio/post.hpp>
2928
#include <boost/asio/spawn.hpp>
30-
#include <boost/asio/strand.hpp>
3129
#include <boost/json/object.hpp>
3230

3331
#include <chrono>
@@ -39,6 +37,27 @@
3937

4038
namespace rpc {
4139

40+
void
41+
WorkQueue::OneTimeCallable::setCallable(std::function<void()> func)
42+
{
43+
func_ = std::move(func);
44+
}
45+
46+
void
47+
WorkQueue::OneTimeCallable::operator()()
48+
{
49+
if (not called_) {
50+
func_();
51+
called_ = true;
52+
}
53+
}
54+
55+
/**
 * @brief Report whether a function has been stored.
 *
 * The result reflects func_ only; it does not depend on whether the function was already invoked.
 *
 * @return true if func_ holds a callable target; false otherwise
 */
WorkQueue::OneTimeCallable::
operator bool() const
{
    return static_cast<bool>(func_);
}
60+
4261
WorkQueue::WorkQueue(DontStartProcessingTag, std::uint32_t numWorkers, uint32_t maxSize)
4362
: queued_{PrometheusService::counterInt(
4463
"work_queue_queued_total_number",
@@ -56,8 +75,6 @@ WorkQueue::WorkQueue(DontStartProcessingTag, std::uint32_t numWorkers, uint32_t
5675
"The current number of tasks in the queue"
5776
)}
5877
, ioc_{numWorkers}
59-
, strand_{ioc_.get_executor()}
60-
, waitTimer_(ioc_)
6178
{
6279
if (maxSize != 0)
6380
maxSize_ = maxSize;
@@ -77,12 +94,14 @@ WorkQueue::~WorkQueue()
7794
void
7895
WorkQueue::startProcessing()
7996
{
80-
util::spawn(strand_, [this](auto yield) {
81-
ASSERT(not hasDispatcher_, "Dispatcher already running");
97+
ASSERT(not processingStarted_, "Attempt to start processing work queue more than once");
98+
processingStarted_ = true;
8299

83-
hasDispatcher_ = true;
84-
dispatcherLoop(yield);
85-
});
100+
// Spawn workers for all tasks that were queued before processing started
101+
auto const numTasks = size();
102+
for (auto i = 0uz; i < numTasks; ++i) {
103+
util::spawn(ioc_, [this](auto yield) { executeTask(yield); });
104+
}
86105
}
87106

88107
bool
@@ -98,93 +117,28 @@ WorkQueue::postCoro(TaskType func, bool isWhiteListed, Priority priority)
98117
return false;
99118
}
100119

101-
++curSize_.get();
102-
auto needsWakeup = false;
103-
104120
{
105-
auto state = dispatcherState_.lock();
106-
107-
needsWakeup = std::exchange(state->isIdle, false);
108-
121+
auto state = queueState_.lock();
109122
state->push(priority, std::move(func));
110123
}
111124

112-
if (needsWakeup)
113-
boost::asio::post(strand_, [this] { waitTimer_.cancel(); });
114-
115-
return true;
116-
}
117-
118-
void
119-
WorkQueue::dispatcherLoop(boost::asio::yield_context yield)
120-
{
121-
LOG(log_.info()) << "WorkQueue dispatcher starting";
122-
123-
// all ongoing tasks must be completed before stopping fully
124-
while (not stopping_ or size() > 0) {
125-
std::optional<TaskType> task;
126-
127-
{
128-
auto state = dispatcherState_.lock();
129-
130-
if (state->empty()) {
131-
state->isIdle = true;
132-
} else {
133-
task = state->popNext();
134-
}
135-
}
136-
137-
if (not stopping_ and not task.has_value()) {
138-
waitTimer_.expires_at(std::chrono::steady_clock::time_point::max());
139-
boost::system::error_code ec;
140-
waitTimer_.async_wait(yield[ec]);
141-
} else if (task.has_value()) {
142-
util::spawn(
143-
ioc_,
144-
[this, spawnedAt = std::chrono::system_clock::now(), task = std::move(*task)](auto yield) mutable {
145-
auto const takenAt = std::chrono::system_clock::now();
146-
auto const waited =
147-
std::chrono::duration_cast<std::chrono::microseconds>(takenAt - spawnedAt).count();
148-
149-
++queued_.get();
150-
durationUs_.get() += waited;
151-
LOG(log_.info()) << "WorkQueue wait time: " << waited << ", queue size: " << size();
152-
153-
task(yield);
154-
155-
--curSize_.get();
156-
}
157-
);
158-
}
159-
}
125+
++curSize_.get();
160126

161-
LOG(log_.info()) << "WorkQueue dispatcher shutdown requested - time to execute onTasksComplete";
127+
if (not processingStarted_)
128+
return true;
162129

163-
{
164-
auto onTasksComplete = onQueueEmpty_.lock();
165-
ASSERT(onTasksComplete->operator bool(), "onTasksComplete must be set when stopping is true.");
166-
onTasksComplete->operator()();
167-
}
130+
util::spawn(ioc_, [this](auto yield) { executeTask(yield); });
168131

169-
LOG(log_.info()) << "WorkQueue dispatcher finished";
132+
return true;
170133
}
171134

172135
void
173136
WorkQueue::requestStop(std::function<void()> onQueueEmpty)
174137
{
175138
auto handler = onQueueEmpty_.lock();
176-
*handler = std::move(onQueueEmpty);
139+
handler->setCallable(std::move(onQueueEmpty));
177140

178141
stopping_ = true;
179-
auto needsWakeup = false;
180-
181-
{
182-
auto state = dispatcherState_.lock();
183-
needsWakeup = std::exchange(state->isIdle, false);
184-
}
185-
186-
if (needsWakeup)
187-
boost::asio::post(strand_, [this] { waitTimer_.cancel(); });
188142
}
189143

190144
void
@@ -194,6 +148,12 @@ WorkQueue::stop()
194148
requestStop();
195149

196150
ioc_.join();
151+
152+
{
153+
auto onTasksComplete = onQueueEmpty_.lock();
154+
ASSERT(onTasksComplete->operator bool(), "onTasksComplete must be set when stopping is true.");
155+
onTasksComplete->operator()();
156+
}
197157
}
198158

199159
WorkQueue
@@ -227,4 +187,29 @@ WorkQueue::size() const
227187
return curSize_.get().value();
228188
}
229189

190+
void
191+
WorkQueue::executeTask(boost::asio::yield_context yield)
192+
{
193+
std::optional<TaskWithTimestamp> taskWithTimestamp;
194+
{
195+
auto state = queueState_.lock();
196+
taskWithTimestamp = state->popNext();
197+
}
198+
199+
ASSERT(
200+
taskWithTimestamp.has_value(),
201+
"Queue should not be empty as we spawn a coro with executeTask for each postCoro."
202+
);
203+
auto const takenAt = std::chrono::system_clock::now();
204+
auto const waited =
205+
std::chrono::duration_cast<std::chrono::microseconds>(takenAt - taskWithTimestamp->queuedAt).count();
206+
207+
++queued_.get();
208+
durationUs_.get() += waited;
209+
LOG(log_.info()) << "WorkQueue wait time: " << waited << ", queue size: " << size();
210+
211+
taskWithTimestamp->task(yield);
212+
--curSize_.get();
213+
}
214+
230215
} // namespace rpc

0 commit comments

Comments
 (0)