Skip to content

Commit d8a6714

Browse files
fix: Update on ports now work on active connections (#51)
* fix: Update on ports now work on active connections * Keep consumer on request, but only use key for limiting
1 parent 29aeec1 commit d8a6714

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

Diff for: proxy/src/limiter.rs

+32-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
use futures_util::future::join_all;
22
use leaky_bucket::RateLimiter;
33
use std::sync::Arc;
4+
use std::{error::Error, fmt::Display};
45

56
use crate::{tiers::Tier, Consumer, State};
67

7-
async fn has_limiter(state: &State, consumer: &Consumer) -> bool {
8+
#[derive(Debug)]
9+
pub enum LimiterError {
10+
PortDeleted,
11+
InvalidTier,
12+
}
13+
impl Display for LimiterError {
14+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15+
match self {
16+
LimiterError::PortDeleted => f.write_str("Port was deleted"),
17+
LimiterError::InvalidTier => f.write_str("Tier is invalid"),
18+
}
19+
}
20+
}
21+
impl Error for LimiterError {}
22+
23+
async fn has_limiter(state: &State, consumer_key: &String) -> bool {
824
let rate_limiter_map = state.limiter.read().await;
9-
rate_limiter_map.get(&consumer.key).is_some()
25+
rate_limiter_map.get(consumer_key).is_some()
1026
}
1127

1228
async fn add_limiter(state: &State, consumer: &Consumer, tier: &Tier) {
@@ -31,16 +47,24 @@ async fn add_limiter(state: &State, consumer: &Consumer, tier: &Tier) {
3147
.insert(consumer.key.clone(), rates);
3248
}
3349

34-
pub async fn limiter(state: Arc<State>, consumer: &Consumer) {
35-
let tiers = state.tiers.read().await.clone();
36-
let tier = tiers.get(&consumer.tier).unwrap();
37-
38-
if !has_limiter(&state, consumer).await {
50+
pub async fn limiter(state: Arc<State>, consumer_key: String) -> Result<(), LimiterError> {
51+
if !has_limiter(&state, &consumer_key).await {
52+
let consumers = state.consumers.read().await.clone();
53+
let consumer = match consumers.get(&consumer_key) {
54+
Some(consumer) => consumer,
55+
None => return Err(LimiterError::PortDeleted),
56+
};
57+
let tiers = state.tiers.read().await.clone();
58+
let tier = match tiers.get(&consumer.tier) {
59+
Some(tier) => tier,
60+
None => return Err(LimiterError::InvalidTier),
61+
};
3962
add_limiter(&state, consumer, tier).await;
4063
}
4164

4265
let rate_limiter_map = state.limiter.read().await.clone();
43-
let rates = rate_limiter_map.get(&consumer.key).unwrap();
66+
let rates = rate_limiter_map.get(&consumer_key).unwrap();
4467

4568
join_all(rates.iter().map(|r| async { r.acquire_one().await })).await;
69+
Ok(())
4670
}

Diff for: proxy/src/proxy.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,13 @@ async fn handle_websocket(
199199
while let Some(result) = client_incoming.next().await {
200200
match result {
201201
Ok(data) => {
202-
limiter(state.clone(), proxy_req.consumer.as_ref().unwrap()).await;
202+
if let Err(err) =
203+
limiter(state.clone(), proxy_req.consumer.clone().unwrap().key)
204+
.await
205+
{
206+
error!(error = err.to_string(), "Failed to run limiter");
207+
break;
208+
};
203209
if let Err(err) = instance_outgoing.send(data).await {
204210
error!(
205211
error = err.to_string(),

0 commit comments

Comments
 (0)