Skip to content

Commit 0dee594

Browse files
committed
fix: Address PR review feedback for auth rate limiting
- Use configuration values instead of hardcoded values for auth_window and ban_time - Integrate whitelist_ips from configuration with validation and logging - Fix TOCTOU race condition in record_failure by removing entry atomically - Add capacity limit (max_tracked_ips) to prevent memory exhaustion DoS - Use HashSet for whitelist O(1) lookups instead of Vec O(n) - Add auth rate limit config fields to ServerConfig - Propagate security config from ServerFileConfig to ServerConfig - Add test for capacity limit enforcement
1 parent b576086 commit 0dee594

3 files changed

Lines changed: 165 additions & 39 deletions

File tree

src/server/config/mod.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,22 @@ pub struct ServerConfig {
149149
/// Configuration for command execution.
150150
#[serde(default)]
151151
pub exec: ExecConfig,
152+
153+
/// Time window for counting authentication attempts in seconds.
154+
///
155+
/// Default: 300 (5 minutes)
156+
#[serde(default = "default_auth_window_secs")]
157+
pub auth_window_secs: u64,
158+
159+
/// Ban duration in seconds after exceeding max auth attempts.
160+
///
161+
/// Default: 300 (5 minutes)
162+
#[serde(default = "default_ban_time_secs")]
163+
pub ban_time_secs: u64,
164+
165+
/// IP addresses that are never banned (whitelist).
166+
#[serde(default)]
167+
pub whitelist_ips: Vec<String>,
152168
}
153169

154170
/// Serializable configuration for public key authentication.
@@ -213,6 +229,14 @@ fn default_idle_timeout_secs() -> u64 {
213229
0 // 0 means no timeout
214230
}
215231

232+
fn default_auth_window_secs() -> u64 {
233+
300 // 5 minutes
234+
}
235+
236+
fn default_ban_time_secs() -> u64 {
237+
300 // 5 minutes
238+
}
239+
216240
fn default_true() -> bool {
217241
true
218242
}
@@ -233,6 +257,9 @@ impl Default for ServerConfig {
233257
publickey_auth: PublicKeyAuthConfigSerde::default(),
234258
password_auth: PasswordAuthConfigSerde::default(),
235259
exec: ExecConfig::default(),
260+
auth_window_secs: default_auth_window_secs(),
261+
ban_time_secs: default_ban_time_secs(),
262+
whitelist_ips: Vec::new(),
236263
}
237264
}
238265
}
@@ -521,6 +548,9 @@ impl ServerFileConfig {
521548
allowed_commands: None,
522549
blocked_commands: Vec::new(),
523550
},
551+
auth_window_secs: self.security.auth_window,
552+
ban_time_secs: self.security.ban_time,
553+
whitelist_ips: self.security.whitelist_ips,
524554
}
525555
}
526556
}

src/server/mod.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,34 @@ impl BsshServer {
217217
let rate_limiter = RateLimiter::with_simple_config(100, 10.0);
218218

219219
// Create auth rate limiter with configuration
220-
let auth_rate_limiter = AuthRateLimiter::new(AuthRateLimitConfig::new(
220+
// Parse whitelist IPs from config
221+
let whitelist_ips: Vec<std::net::IpAddr> = self
222+
.config
223+
.whitelist_ips
224+
.iter()
225+
.filter_map(|s| {
226+
s.parse().map_err(|e| {
227+
tracing::warn!(ip = %s, error = %e, "Invalid whitelist IP address in config, skipping");
228+
e
229+
}).ok()
230+
})
231+
.collect();
232+
233+
let auth_config = AuthRateLimitConfig::new(
221234
self.config.max_auth_attempts,
222-
300, // Default 5 minute window
223-
300, // Default 5 minute ban
224-
));
235+
self.config.auth_window_secs,
236+
self.config.ban_time_secs,
237+
).with_whitelist(whitelist_ips);
238+
239+
let auth_rate_limiter = AuthRateLimiter::new(auth_config);
240+
241+
tracing::info!(
242+
max_attempts = self.config.max_auth_attempts,
243+
auth_window_secs = self.config.auth_window_secs,
244+
ban_time_secs = self.config.ban_time_secs,
245+
whitelist_count = self.config.whitelist_ips.len(),
246+
"Auth rate limiter configured"
247+
);
225248

226249
// Start background cleanup task for auth rate limiter
227250
let cleanup_limiter = auth_rate_limiter.clone();

src/server/security/rate_limit.rs

Lines changed: 108 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! against brute-force attacks. It tracks failed authentication attempts per IP
1919
//! and automatically bans IPs that exceed the configured threshold.
2020
21-
use std::collections::HashMap;
21+
use std::collections::{HashMap, HashSet};
2222
use std::net::IpAddr;
2323
use std::sync::Arc;
2424
use std::time::{Duration, Instant};
@@ -31,6 +31,7 @@ use tokio::sync::RwLock;
3131
/// - `window`: Time window for counting attempts
3232
/// - `ban_duration`: How long to ban an IP
3333
/// - `whitelist`: IPs that are never banned
34+
/// - `max_tracked_ips`: Maximum IPs to track (prevents memory exhaustion)
3435
#[derive(Debug, Clone)]
3536
pub struct AuthRateLimitConfig {
3637
/// Maximum failed attempts before ban.
@@ -39,8 +40,11 @@ pub struct AuthRateLimitConfig {
3940
pub window: Duration,
4041
/// Ban duration after exceeding max attempts.
4142
pub ban_duration: Duration,
42-
/// Whitelist IPs (never banned).
43-
pub whitelist: Vec<IpAddr>,
43+
/// Whitelist IPs (never banned). Uses HashSet for O(1) lookups.
44+
pub whitelist: HashSet<IpAddr>,
45+
/// Maximum number of IPs to track (prevents memory exhaustion).
46+
/// When exceeded, oldest entries are removed.
47+
pub max_tracked_ips: usize,
4448
}
4549

4650
impl Default for AuthRateLimitConfig {
@@ -49,7 +53,8 @@ impl Default for AuthRateLimitConfig {
4953
max_attempts: 5,
5054
window: Duration::from_secs(300), // 5 minutes
5155
ban_duration: Duration::from_secs(300), // 5 minutes
52-
whitelist: vec![],
56+
whitelist: HashSet::new(),
57+
max_tracked_ips: 10000, // Limit memory usage
5358
}
5459
}
5560
}
@@ -67,20 +72,25 @@ impl AuthRateLimitConfig {
6772
max_attempts,
6873
window: Duration::from_secs(window_secs),
6974
ban_duration: Duration::from_secs(ban_duration_secs),
70-
whitelist: vec![],
75+
whitelist: HashSet::new(),
76+
max_tracked_ips: 10000,
7177
}
7278
}
7379

7480
/// Add an IP to the whitelist.
7581
pub fn add_whitelist(&mut self, ip: IpAddr) {
76-
if !self.whitelist.contains(&ip) {
77-
self.whitelist.push(ip);
78-
}
82+
self.whitelist.insert(ip);
7983
}
8084

8185
/// Set the whitelist from a list of IPs.
8286
pub fn with_whitelist(mut self, whitelist: Vec<IpAddr>) -> Self {
83-
self.whitelist = whitelist;
87+
self.whitelist = whitelist.into_iter().collect();
88+
self
89+
}
90+
91+
/// Set the maximum number of IPs to track.
92+
pub fn with_max_tracked_ips(mut self, max: usize) -> Self {
93+
self.max_tracked_ips = max;
8494
self
8595
}
8696
}
@@ -168,28 +178,57 @@ impl AuthRateLimiter {
168178
return false;
169179
}
170180

171-
let mut failures = self.failures.write().await;
172-
let now = Instant::now();
181+
let should_ban;
182+
{
183+
let mut failures = self.failures.write().await;
184+
let now = Instant::now();
173185

174-
let record = failures.entry(ip).or_insert_with(|| FailureRecord {
175-
count: 0,
176-
first_failure: now,
177-
last_failure: now,
178-
});
186+
// Enforce capacity limit to prevent memory exhaustion
187+
// If at capacity and this is a new IP, remove oldest entry
188+
if failures.len() >= self.config.max_tracked_ips && !failures.contains_key(&ip) {
189+
// Find and remove the oldest entry by last_failure time
190+
if let Some(oldest_ip) = failures
191+
.iter()
192+
.min_by_key(|(_, record)| record.last_failure)
193+
.map(|(ip, _)| *ip)
194+
{
195+
failures.remove(&oldest_ip);
196+
tracing::debug!(
197+
removed_ip = %oldest_ip,
198+
capacity = self.config.max_tracked_ips,
199+
"Removed oldest failure record due to capacity limit"
200+
);
201+
}
202+
}
179203

180-
// Reset if window expired
181-
if now.duration_since(record.first_failure) > self.config.window {
182-
record.count = 1;
183-
record.first_failure = now;
184-
} else {
185-
record.count += 1;
186-
}
187-
record.last_failure = now;
204+
let record = failures.entry(ip).or_insert_with(|| FailureRecord {
205+
count: 0,
206+
first_failure: now,
207+
last_failure: now,
208+
});
209+
210+
// Reset if window expired
211+
if now.duration_since(record.first_failure) > self.config.window {
212+
record.count = 1;
213+
record.first_failure = now;
214+
} else {
215+
record.count += 1;
216+
}
217+
record.last_failure = now;
218+
219+
// Check if should ban - record the decision while holding the lock
220+
should_ban = record.count >= self.config.max_attempts;
221+
222+
// If banning, remove from failures while we still hold the lock
223+
// This prevents race conditions with concurrent record_failure calls
224+
if should_ban {
225+
failures.remove(&ip);
226+
}
227+
} // failures lock released here
188228

189-
// Check if should ban
190-
if record.count >= self.config.max_attempts {
191-
drop(failures); // Release lock before acquiring ban lock
192-
self.ban(ip).await;
229+
// Now apply the ban if needed
230+
if should_ban {
231+
self.ban_internal(ip).await;
193232
return true;
194233
}
195234

@@ -209,6 +248,18 @@ impl AuthRateLimiter {
209248
/// The IP will be banned for the configured ban duration.
210249
/// Also clears the failure record for the IP.
211250
pub async fn ban(&self, ip: IpAddr) {
251+
// Clean up failure record first
252+
{
253+
let mut failures = self.failures.write().await;
254+
failures.remove(&ip);
255+
}
256+
257+
self.ban_internal(ip).await;
258+
}
259+
260+
/// Internal method to apply a ban without modifying failure records.
261+
/// Used by record_failure which has already cleaned up the failure record.
262+
async fn ban_internal(&self, ip: IpAddr) {
212263
tracing::warn!(
213264
ip = %ip,
214265
duration_secs = self.config.ban_duration.as_secs(),
@@ -218,11 +269,6 @@ impl AuthRateLimiter {
218269
let mut bans = self.bans.write().await;
219270
let expiry = Instant::now() + self.config.ban_duration;
220271
bans.insert(ip, expiry);
221-
222-
// Clean up failure record
223-
drop(bans);
224-
let mut failures = self.failures.write().await;
225-
failures.remove(&ip);
226272
}
227273

228274
/// Manually unban an IP address.
@@ -462,7 +508,8 @@ mod tests {
462508
max_attempts: 3,
463509
window: Duration::from_millis(50),
464510
ban_duration: Duration::from_secs(300),
465-
whitelist: vec![],
511+
whitelist: HashSet::new(),
512+
max_tracked_ips: 10000,
466513
};
467514
let limiter = AuthRateLimiter::new(config);
468515

@@ -489,7 +536,8 @@ mod tests {
489536
max_attempts: 2,
490537
window: Duration::from_millis(10),
491538
ban_duration: Duration::from_millis(10),
492-
whitelist: vec![],
539+
whitelist: HashSet::new(),
540+
max_tracked_ips: 10000,
493541
};
494542
let limiter = AuthRateLimiter::new(config);
495543

@@ -610,4 +658,29 @@ mod tests {
610658
assert_eq!(limiter.config().window.as_secs(), 600);
611659
assert_eq!(limiter.config().ban_duration.as_secs(), 1800);
612660
}
661+
662+
#[tokio::test]
663+
async fn test_capacity_limit() {
664+
// Test that capacity limit prevents unbounded memory growth
665+
let config = AuthRateLimitConfig::new(5, 300, 300).with_max_tracked_ips(3);
666+
let limiter = AuthRateLimiter::new(config);
667+
668+
let ip1: IpAddr = "192.168.1.1".parse().unwrap();
669+
let ip2: IpAddr = "192.168.1.2".parse().unwrap();
670+
let ip3: IpAddr = "192.168.1.3".parse().unwrap();
671+
let ip4: IpAddr = "192.168.1.4".parse().unwrap();
672+
673+
// Record failures for first 3 IPs
674+
limiter.record_failure(ip1).await;
675+
limiter.record_failure(ip2).await;
676+
limiter.record_failure(ip3).await;
677+
assert_eq!(limiter.tracked_count().await, 3);
678+
679+
// Recording for 4th IP should evict the oldest
680+
limiter.record_failure(ip4).await;
681+
assert_eq!(limiter.tracked_count().await, 3);
682+
683+
// ip4 should be tracked, ip1 should be evicted (it was oldest)
684+
assert_eq!(limiter.remaining_attempts(&ip4).await, 4);
685+
}
613686
}

0 commit comments

Comments
 (0)