1
1
use futures_util:: future:: join_all;
2
2
use leaky_bucket:: RateLimiter ;
3
3
use std:: sync:: Arc ;
4
+ use std:: { error:: Error , fmt:: Display } ;
4
5
5
6
use crate :: { tiers:: Tier , Consumer , State } ;
6
7
7
- async fn has_limiter ( state : & State , consumer : & Consumer ) -> bool {
8
+ #[ derive( Debug ) ]
9
+ pub enum LimiterError {
10
+ PortDeleted ,
11
+ InvalidTier ,
12
+ }
13
+ impl Display for LimiterError {
14
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
15
+ match self {
16
+ LimiterError :: PortDeleted => f. write_str ( "Port was deleted" ) ,
17
+ LimiterError :: InvalidTier => f. write_str ( "Tier is invalid" ) ,
18
+ }
19
+ }
20
+ }
21
+ impl Error for LimiterError { }
22
+
23
+ async fn has_limiter ( state : & State , consumer_key : & String ) -> bool {
8
24
let rate_limiter_map = state. limiter . read ( ) . await ;
9
- rate_limiter_map. get ( & consumer . key ) . is_some ( )
25
+ rate_limiter_map. get ( consumer_key ) . is_some ( )
10
26
}
11
27
12
28
async fn add_limiter ( state : & State , consumer : & Consumer , tier : & Tier ) {
@@ -31,16 +47,24 @@ async fn add_limiter(state: &State, consumer: &Consumer, tier: &Tier) {
31
47
. insert ( consumer. key . clone ( ) , rates) ;
32
48
}
33
49
34
- pub async fn limiter ( state : Arc < State > , consumer : & Consumer ) {
35
- let tiers = state. tiers . read ( ) . await . clone ( ) ;
36
- let tier = tiers. get ( & consumer. tier ) . unwrap ( ) ;
37
-
38
- if !has_limiter ( & state, consumer) . await {
50
+ pub async fn limiter ( state : Arc < State > , consumer_key : String ) -> Result < ( ) , LimiterError > {
51
+ if !has_limiter ( & state, & consumer_key) . await {
52
+ let consumers = state. consumers . read ( ) . await . clone ( ) ;
53
+ let consumer = match consumers. get ( & consumer_key) {
54
+ Some ( consumer) => consumer,
55
+ None => return Err ( LimiterError :: PortDeleted ) ,
56
+ } ;
57
+ let tiers = state. tiers . read ( ) . await . clone ( ) ;
58
+ let tier = match tiers. get ( & consumer. tier ) {
59
+ Some ( tier) => tier,
60
+ None => return Err ( LimiterError :: InvalidTier ) ,
61
+ } ;
39
62
add_limiter ( & state, consumer, tier) . await ;
40
63
}
41
64
42
65
let rate_limiter_map = state. limiter . read ( ) . await . clone ( ) ;
43
- let rates = rate_limiter_map. get ( & consumer . key ) . unwrap ( ) ;
66
+ let rates = rate_limiter_map. get ( & consumer_key ) . unwrap ( ) ;
44
67
45
68
join_all ( rates. iter ( ) . map ( |r| async { r. acquire_one ( ) . await } ) ) . await ;
69
+ Ok ( ( ) )
46
70
}
0 commit comments