1818//! against brute-force attacks. It tracks failed authentication attempts per IP
1919//! and automatically bans IPs that exceed the configured threshold.
2020
21- use std:: collections:: HashMap ;
21+ use std:: collections:: { HashMap , HashSet } ;
2222use std:: net:: IpAddr ;
2323use std:: sync:: Arc ;
2424use std:: time:: { Duration , Instant } ;
@@ -31,6 +31,7 @@ use tokio::sync::RwLock;
3131/// - `window`: Time window for counting attempts
3232/// - `ban_duration`: How long to ban an IP
3333/// - `whitelist`: IPs that are never banned
34+ /// - `max_tracked_ips`: Maximum IPs to track (prevents memory exhaustion)
3435#[ derive( Debug , Clone ) ]
3536pub struct AuthRateLimitConfig {
3637 /// Maximum failed attempts before ban.
@@ -39,8 +40,11 @@ pub struct AuthRateLimitConfig {
3940 pub window : Duration ,
4041 /// Ban duration after exceeding max attempts.
4142 pub ban_duration : Duration ,
42- /// Whitelist IPs (never banned).
43- pub whitelist : Vec < IpAddr > ,
43+ /// Whitelist IPs (never banned). Uses HashSet for O(1) lookups.
44+ pub whitelist : HashSet < IpAddr > ,
45+ /// Maximum number of IPs to track (prevents memory exhaustion).
46+ /// When exceeded, oldest entries are removed.
47+ pub max_tracked_ips : usize ,
4448}
4549
4650impl Default for AuthRateLimitConfig {
@@ -49,7 +53,8 @@ impl Default for AuthRateLimitConfig {
4953 max_attempts : 5 ,
5054 window : Duration :: from_secs ( 300 ) , // 5 minutes
5155 ban_duration : Duration :: from_secs ( 300 ) , // 5 minutes
52- whitelist : vec ! [ ] ,
56+ whitelist : HashSet :: new ( ) ,
57+ max_tracked_ips : 10000 , // Limit memory usage
5358 }
5459 }
5560}
@@ -67,20 +72,25 @@ impl AuthRateLimitConfig {
6772 max_attempts,
6873 window : Duration :: from_secs ( window_secs) ,
6974 ban_duration : Duration :: from_secs ( ban_duration_secs) ,
70- whitelist : vec ! [ ] ,
75+ whitelist : HashSet :: new ( ) ,
76+ max_tracked_ips : 10000 ,
7177 }
7278 }
7379
7480 /// Add an IP to the whitelist.
7581 pub fn add_whitelist ( & mut self , ip : IpAddr ) {
76- if !self . whitelist . contains ( & ip) {
77- self . whitelist . push ( ip) ;
78- }
82+ self . whitelist . insert ( ip) ;
7983 }
8084
8185 /// Set the whitelist from a list of IPs.
8286 pub fn with_whitelist ( mut self , whitelist : Vec < IpAddr > ) -> Self {
83- self . whitelist = whitelist;
87+ self . whitelist = whitelist. into_iter ( ) . collect ( ) ;
88+ self
89+ }
90+
91+ /// Set the maximum number of IPs to track.
92+ pub fn with_max_tracked_ips ( mut self , max : usize ) -> Self {
93+ self . max_tracked_ips = max;
8494 self
8595 }
8696}
@@ -168,28 +178,57 @@ impl AuthRateLimiter {
168178 return false ;
169179 }
170180
171- let mut failures = self . failures . write ( ) . await ;
172- let now = Instant :: now ( ) ;
181+ let should_ban;
182+ {
183+ let mut failures = self . failures . write ( ) . await ;
184+ let now = Instant :: now ( ) ;
173185
174- let record = failures. entry ( ip) . or_insert_with ( || FailureRecord {
175- count : 0 ,
176- first_failure : now,
177- last_failure : now,
178- } ) ;
186+ // Enforce capacity limit to prevent memory exhaustion
187+ // If at capacity and this is a new IP, remove oldest entry
188+ if failures. len ( ) >= self . config . max_tracked_ips && !failures. contains_key ( & ip) {
189+ // Find and remove the oldest entry by last_failure time
190+ if let Some ( oldest_ip) = failures
191+ . iter ( )
192+ . min_by_key ( |( _, record) | record. last_failure )
193+ . map ( |( ip, _) | * ip)
194+ {
195+ failures. remove ( & oldest_ip) ;
196+ tracing:: debug!(
197+ removed_ip = %oldest_ip,
198+ capacity = self . config. max_tracked_ips,
199+ "Removed oldest failure record due to capacity limit"
200+ ) ;
201+ }
202+ }
179203
180- // Reset if window expired
181- if now. duration_since ( record. first_failure ) > self . config . window {
182- record. count = 1 ;
183- record. first_failure = now;
184- } else {
185- record. count += 1 ;
186- }
187- record. last_failure = now;
204+ let record = failures. entry ( ip) . or_insert_with ( || FailureRecord {
205+ count : 0 ,
206+ first_failure : now,
207+ last_failure : now,
208+ } ) ;
209+
210+ // Reset if window expired
211+ if now. duration_since ( record. first_failure ) > self . config . window {
212+ record. count = 1 ;
213+ record. first_failure = now;
214+ } else {
215+ record. count += 1 ;
216+ }
217+ record. last_failure = now;
218+
219+ // Check if should ban - record the decision while holding the lock
220+ should_ban = record. count >= self . config . max_attempts ;
221+
222+ // If banning, remove from failures while we still hold the lock
223+ // This prevents race conditions with concurrent record_failure calls
224+ if should_ban {
225+ failures. remove ( & ip) ;
226+ }
227+ } // failures lock released here
188228
189- // Check if should ban
190- if record. count >= self . config . max_attempts {
191- drop ( failures) ; // Release lock before acquiring ban lock
192- self . ban ( ip) . await ;
229+ // Now apply the ban if needed
230+ if should_ban {
231+ self . ban_internal ( ip) . await ;
193232 return true ;
194233 }
195234
@@ -209,6 +248,18 @@ impl AuthRateLimiter {
209248 /// The IP will be banned for the configured ban duration.
210249 /// Also clears the failure record for the IP.
211250 pub async fn ban ( & self , ip : IpAddr ) {
251+ // Clean up failure record first
252+ {
253+ let mut failures = self . failures . write ( ) . await ;
254+ failures. remove ( & ip) ;
255+ }
256+
257+ self . ban_internal ( ip) . await ;
258+ }
259+
260+ /// Internal method to apply a ban without modifying failure records.
261+ /// Used by record_failure which has already cleaned up the failure record.
262+ async fn ban_internal ( & self , ip : IpAddr ) {
212263 tracing:: warn!(
213264 ip = %ip,
214265 duration_secs = self . config. ban_duration. as_secs( ) ,
@@ -218,11 +269,6 @@ impl AuthRateLimiter {
218269 let mut bans = self . bans . write ( ) . await ;
219270 let expiry = Instant :: now ( ) + self . config . ban_duration ;
220271 bans. insert ( ip, expiry) ;
221-
222- // Clean up failure record
223- drop ( bans) ;
224- let mut failures = self . failures . write ( ) . await ;
225- failures. remove ( & ip) ;
226272 }
227273
228274 /// Manually unban an IP address.
@@ -462,7 +508,8 @@ mod tests {
462508 max_attempts : 3 ,
463509 window : Duration :: from_millis ( 50 ) ,
464510 ban_duration : Duration :: from_secs ( 300 ) ,
465- whitelist : vec ! [ ] ,
511+ whitelist : HashSet :: new ( ) ,
512+ max_tracked_ips : 10000 ,
466513 } ;
467514 let limiter = AuthRateLimiter :: new ( config) ;
468515
@@ -489,7 +536,8 @@ mod tests {
489536 max_attempts : 2 ,
490537 window : Duration :: from_millis ( 10 ) ,
491538 ban_duration : Duration :: from_millis ( 10 ) ,
492- whitelist : vec ! [ ] ,
539+ whitelist : HashSet :: new ( ) ,
540+ max_tracked_ips : 10000 ,
493541 } ;
494542 let limiter = AuthRateLimiter :: new ( config) ;
495543
@@ -610,4 +658,29 @@ mod tests {
610658 assert_eq ! ( limiter. config( ) . window. as_secs( ) , 600 ) ;
611659 assert_eq ! ( limiter. config( ) . ban_duration. as_secs( ) , 1800 ) ;
612660 }
661+
662+ #[ tokio:: test]
663+ async fn test_capacity_limit ( ) {
664+ // Test that capacity limit prevents unbounded memory growth
665+ let config = AuthRateLimitConfig :: new ( 5 , 300 , 300 ) . with_max_tracked_ips ( 3 ) ;
666+ let limiter = AuthRateLimiter :: new ( config) ;
667+
668+ let ip1: IpAddr = "192.168.1.1" . parse ( ) . unwrap ( ) ;
669+ let ip2: IpAddr = "192.168.1.2" . parse ( ) . unwrap ( ) ;
670+ let ip3: IpAddr = "192.168.1.3" . parse ( ) . unwrap ( ) ;
671+ let ip4: IpAddr = "192.168.1.4" . parse ( ) . unwrap ( ) ;
672+
673+ // Record failures for first 3 IPs
674+ limiter. record_failure ( ip1) . await ;
675+ limiter. record_failure ( ip2) . await ;
676+ limiter. record_failure ( ip3) . await ;
677+ assert_eq ! ( limiter. tracked_count( ) . await , 3 ) ;
678+
679+ // Recording for 4th IP should evict the oldest
680+ limiter. record_failure ( ip4) . await ;
681+ assert_eq ! ( limiter. tracked_count( ) . await , 3 ) ;
682+
683+ // ip4 should be tracked, ip1 should be evicted (it was oldest)
684+ assert_eq ! ( limiter. remaining_attempts( & ip4) . await , 4 ) ;
685+ }
613686}
0 commit comments