-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrate-limiting.cpp
More file actions
377 lines (295 loc) · 12.9 KB
/
rate-limiting.cpp
File metadata and controls
377 lines (295 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <print>
#include <queue>
#include <thread>
using Clock = std::chrono::steady_clock;
using TimePoint = std::chrono::time_point<Clock>;
/// Converts a rate (requests per second) into the minimum spacing between
/// consecutive requests, truncated to whole nanoseconds.
constexpr auto get_interval(double requests_per_second) {
    const auto seconds_per_request = std::chrono::duration<double>{1.0 / requests_per_second};
    return std::chrono::duration_cast<std::chrono::nanoseconds>(seconds_per_request);
}
/// Blocks until at least `minimum_wait_duration` has elapsed since `last_call`,
/// then returns the current time (used by callers as the new `last_call`).
auto wait(TimePoint last_call, std::chrono::nanoseconds minimum_wait_duration) -> TimePoint {
    const auto deadline = last_call + minimum_wait_duration;
    constexpr auto use_sleep = false;
    if constexpr (use_sleep) {
        // Sleep-based pacing: kind to the CPU, but wake-up latency caps
        // throughput at around 60 calls/s in practice.
        if (Clock::now() < deadline) {
            std::this_thread::sleep_until(deadline);
        }
        return Clock::now();
    } else {
        // Busy-spin pacing: sleeping/waking is too slow for ~20k calls/s, so
        // poll Clock::now() until the deadline has passed.
        auto now = Clock::now();
        while (now < deadline) {
            now = Clock::now();
        }
        return now;
    }
}
///
/// @brief Limits function executions to a particular rate. No queue or jobs lost. Blocks when rate is exceeded.
///
template <typename F> class Limiter {
public:
    /// @param reqs_per_second Maximum sustained call rate for `func`.
    /// @param func Callable invoked once per request, outside the internal lock.
    Limiter(double reqs_per_second, F func)
        // Start last_call_ one interval in the past so the very first request
        // proceeds without waiting. Initialized in the member-init list instead
        // of default-constructing and reassigning in the body.
        : last_call_(Clock::now() - get_interval(reqs_per_second)),
          minimum_wait_duration_(get_interval(reqs_per_second)),
          func_(std::move(func)) {
        std::println("dt = {}ns", minimum_wait_duration_.count());
    }
    /// Blocks until this call's slot arrives, then invokes the wrapped function.
    /// Only the pacing is done under the lock; func_ runs unlocked, so calls may
    /// overlap if func_ is slow — NOTE(review): confirm func_ is thread-safe.
    template <typename... Args> auto request(Args&&... args) -> void {
        {
            auto lock = std::unique_lock{mtx_};
            last_call_ = wait(last_call_, minimum_wait_duration_);
        }
        func_(std::forward<Args>(args)...);
    }
private:
    std::mutex mtx_;                                 // serializes pacing updates
    TimePoint last_call_;                            // time the last slot was granted
    std::chrono::nanoseconds minimum_wait_duration_; // spacing between slots
    F func_;
};
/// Exercises Limiter at 1 request/s from 4 competing threads; each granted
/// call prints the thread, the request number, and seconds since the start.
auto test_limiter() -> void {
    auto stdout_mutex = std::mutex{};
    auto start = Clock::now();
    auto count = 0;
    auto limiter = Limiter(1, [&](int thread, int arg) {
        const auto lock = std::unique_lock{stdout_mutex};
        const auto elapsed = Clock::now() - start;
        const auto elapsed_seconds = std::chrono::duration_cast<std::chrono::duration<double>>(elapsed).count();
        std::println("thread: {}, arg: {}, dt: {}", thread, arg, elapsed_seconds);
    });
    auto threads = std::array<std::jthread, 4>{};
    for (auto thread_id = 0; thread_id < 4; ++thread_id) {
        threads[thread_id] = std::jthread([&, thread_id] {
            for (auto request_nb = 0; request_nb < 1000000000; ++request_nb) {
                limiter.request(thread_id, request_nb);
            }
        });
    }
}
/* template <typename F> class RateLimiter {
public:
explicit RateLimiter(double reqs_per_second, F func) noexcept : minimum_wait_duration_(get_interval(reqs_per_second)), func_(std::move(func)) {
last_call_ = Clock::now() - minimum_wait_duration_;
std::println("dt = {}ns", minimum_wait_duration_.count());
}
auto add_request(int thread, int arg) -> void {
{
auto lock = std::unique_lock{mtx_};
requests_.push({thread, arg});
}
cv_.notify_one();
}
auto process_requests() {
while (true) {
const auto args = get_next_args();
func_(std::move(args[0]), std::move(args[1]));
}
}
private:
auto get_next_args() -> std::array<int, 2> {
auto lock = std::unique_lock{mtx_};
cv_.wait(lock, [&] { return !requests_.empty(); });
auto args = std::move(requests_.front());
requests_.pop();
last_call_ = wait(last_call_, minimum_wait_duration_);
return args;
}
std::queue<std::array<int, 2>> requests_;
std::mutex mtx_;
std::condition_variable cv_;
F func_;
TimePoint last_call_;
std::chrono::nanoseconds minimum_wait_duration_;
};
auto test_limiter_queue() -> void {
auto stdout_mutex = std::mutex{};
auto start = Clock::now();
auto count = 0;
auto limiter = RateLimiter(20000, [&](int thread, int arg) {
const auto lock = std::unique_lock{stdout_mutex};
const auto now = Clock::now();
const auto dt = now - start;
if (dt >= std::chrono::seconds{1}) {
std::println("thread: {}, arg: {}, dt: {}, count: {}", thread, arg, dt.count(), count);
start = now;
count = 1;
} else {
++count;
}
});
auto processing_thread = std::jthread([&] { limiter.process_requests(); });
auto threads = std::array<std::jthread, 4>{};
for (auto t = 0; t < 4; ++t) {
threads[t] = std::jthread([&, t] {
for (auto i = 0; i < 1000000; ++i) {
limiter.add_request(t, i);
}
});
}
}*/
// Guards stdout so interleaved println output from worker threads stays readable.
static std::mutex stdout_mtx;
///
/// @brief Lock-free token-bucket rate limiter: allows bursts of up to max_tokens
/// calls, replenishing one token per interval derived from reqs_per_second.
///
template <typename F> class TokenBucketLimiter {
public:
    ///
    /// @brief Constructor. Limits function calls to a maximum number of requests per second with a maximum token capacity, allowing for bursting (set to 1 to disable bursting).
    ///
    TokenBucketLimiter(double reqs_per_second, int max_tokens, F func)
        // Initializer list matches member declaration order (last_call_,
        // minimum_wait_duration_, max_tokens_, func_) — fixes the -Wreorder
        // warning; actual initialization always follows declaration order.
        : last_call_(Clock::now()), minimum_wait_duration_(get_interval(reqs_per_second)), max_tokens_(max_tokens), func_(std::move(func)) {}
    ///
    /// @brief Attempts to invoke the function synchronously, returning true if successful. Otherwise immediately returns false.
    ///
    template <typename... Args> auto try_request(Args... args) -> bool {
        const auto now = Clock::now();
        // NOTE: If we are ONLY interested in dropping failed requests, last_call_ can probably be non-atomic
        const auto previous_call = last_call_.exchange(now);
        return try_request(now, previous_call, args...);
    }
    ///
    /// @brief Blocks (busy-waiting via wait()) until a token is available, then invokes the function.
    ///
    template <typename... Args> auto request_block(Args... args) -> void {
        auto now = Clock::now();
        auto previous_call = last_call_.load();
        while (true) {
            // Claim the current time slot; on CAS failure previous_call is
            // refreshed with the competing thread's timestamp.
            if (last_call_.compare_exchange_weak(previous_call, now)) {
                assert(now >= previous_call);
                if (try_request(now, previous_call, args...)) {
                    return;
                }
                previous_call = now;
            }
            // Spin until one replenish interval past the last observed call.
            now = wait(previous_call, minimum_wait_duration_);
        }
    }
private:
    /// Core accounting: replenish tokens for the elapsed time, spend one, and
    /// invoke func_ if the budget allows. Returns false (token restored) otherwise.
    template <typename... Args> auto try_request(TimePoint now, TimePoint last_attempt, Args... args) -> bool {
        assert(now >= last_attempt);
        const auto dt = now - last_attempt;
        if (dt >= minimum_wait_duration_) {
            // waited long enough for tokens to regenerate so execute the function anyway
            const auto replenished_tokens = std::min(static_cast<int>(dt / minimum_wait_duration_), max_tokens_);
            // add the newly accrued tokens and spend one for this request
            const auto remaining_tokens = current_tokens_.fetch_add(replenished_tokens - 1);
            // NOTE(review): replenish + spend is not one atomic step with the
            // last_call_ update, so counts can drift under heavy contention —
            // confirm the approximation is acceptable for this use case.
            func_(remaining_tokens, args...);
            return true;
        }
        const auto remaining_tokens = current_tokens_.fetch_sub(1);
        // only execute the function if we had enough tokens
        // 0 is ok because we just removed 1
        if (remaining_tokens >= 0) {
            func_(remaining_tokens, args...);
            return true;
        }
        // had negative tokens, function didn't run, put the token back
        ++current_tokens_;
        return false;
    }
    std::atomic<TimePoint> last_call_;               // timestamp of the last attempt, claimed via CAS
    std::chrono::nanoseconds minimum_wait_duration_; // time to regenerate one token
    int max_tokens_;                                 // bucket capacity (burst limit)
    std::atomic<int> current_tokens_ = max_tokens_;  // relies on max_tokens_ being declared above
    F func_;
};
/// Stress test: 4 threads hammer a 20k req/s TokenBucketLimiter (burst capacity
/// 100k) in large batches, pausing 15s between batches so the bucket refills.
/// The callback prints a summary roughly once per second.
auto test_limiter_queue() -> void {
    auto start = Clock::now();
    auto count = 0;
    auto limiter = TokenBucketLimiter(20000, 100000, [&](int tokens, int thread, int batch, int value) {
        const auto lock = std::unique_lock{stdout_mtx};
        const auto now = Clock::now();
        const auto dt = now - start;
        if (dt >= std::chrono::seconds{1}) {
            std::println("tokens: {}, thread: {}, batch: {}, value: {}, dt: {}, count: {}", tokens, thread, batch, value, dt.count(), count);
            start = now;
            count = 1;
        } else {
            ++count;
        }
    });
    constexpr auto nb_threads = 4;
    auto threads = std::array<std::jthread, nb_threads>{};
    for (auto thread_id = 0; thread_id < nb_threads; ++thread_id) {
        threads[thread_id] = std::jthread([&, thread_id] {
            for (auto batch = 0; batch < 100000; ++batch) {
                constexpr auto batch_size = 500000;
                for (auto value = 0; value < batch_size; ++value) {
                    limiter.request_block(thread_id, batch, value);
                }
                {
                    const auto lock = std::unique_lock{stdout_mtx};
                    std::println("Thread {} pausing...", thread_id);
                }
                std::this_thread::sleep_for(std::chrono::seconds{15});
                {
                    const auto lock = std::unique_lock{stdout_mtx};
                    std::println("Thread {} resuming...", thread_id);
                }
            }
        });
    }
}
// Entry point. Runs the token-bucket stress test; test_limiter() is kept but
// disabled. NOTE: test_limiter_queue() spawns jthreads that run effectively
// forever, so main does not return in practice.
auto main() -> int {
// test_limiter();
test_limiter_queue();
// Sketch of an alternative lock-free limiter where the "last call" timestamp
// itself encodes the token budget (requests pretend to occupy uniformly spaced
// slots in the past) — exploratory, not compiled.
/* const auto time_per_token = std::chrono::seconds{1};
const auto capacity = 5;
auto time_ = std::atomic<TimePoint>{Clock::now() - std::chrono::seconds(10)}; // stores the last time a successful invocation ocurred
auto now = Clock::now(); // now, the time of the current request
auto time_full_burst_ago = now - capacity * time_per_token; // the time when a request should have been made if we wanted to have full capacity at this request time
auto last_attempt_time = time_.load(); // stores the last time an attempt was made (successful or not)
//
// algorithm works by essentially pretending to handle requests at time uniformly spread out in the past, according to the buckets current capacity
// e.g. for these values, the requests could be handled at -5, -4, -3, -2, -1 seconds in the past
//
// however, if the last attempt was made 2 seconds ago, we should only go back as far as -2, -1
//
// if another thread handles the request in the -5 slot, we proceed to try -4, then -3, ... until the request slot has not been taken
// if we use all our slots, we reject the request
//
// in particular, if (last_attempt_time + time_per_token > now), the next available request time is in the future, so we must reject it
//
// on the other hand, if the last_attempt_time was long ago, we should only add as many slots as the maximum capacity
//
auto next_allowed_attempt_time = (time_full_burst_ago > last_attempt_time ? time_full_burst_ago : last_attempt_time) + time_per_token;
//
while (true) {
// we cannot handle the request until the future, reject it
if (next_allowed_attempt_time > now) {
return false;
}
// if time_ == last_attempt_time, no other thread has modified it
// therefore we can handle the request and set the value of time_ to be as if it was invoked on the next token interval (if the requests occur at a uniform rate)
// otherwise, some other thread got there first
// last_attempt_time now holds the value of time_, which has now increased to the next possible token interval
// we need to now try again
if (time_.compare_exchange_weak(last_attempt_time, next_allowed_attempt_time)) {
return true;
}
// increase the next attempt time by a token interval
next_allowed_attempt_time = last_attempt_time + time_per_token;
}*/
return 0;
}