Skip to content

Commit bb99ad1

Browse files
committed
feat(rate_limiter) Adding tenant based rate limiting based on input tokens.
Signed-off-by: Varun Shenoy <varun.vinayak.shenoy@oracle.com>
1 parent cb84bbb commit bb99ad1

11 files changed

Lines changed: 624 additions & 7 deletions

File tree

model_gateway/src/app_context.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::{
2222
middleware::TokenBucket,
2323
observability::inflight_tracker::InFlightRequestTracker,
2424
policies::PolicyRegistry,
25+
rate_limit::LocalTokenRateLimiter,
2526
routers::{
2627
grpc::multimodal::MultimodalConfigRegistry, openai::realtime::RealtimeRegistry,
2728
router_manager::RouterManager,
@@ -54,6 +55,7 @@ pub struct AppContext {
5455
pub client: Client,
5556
pub router_config: RouterConfig,
5657
pub rate_limiter: Option<Arc<TokenBucket>>,
58+
pub token_rate_limiter: Option<Arc<LocalTokenRateLimiter>>,
5759
pub tokenizer_registry: Arc<TokenizerRegistry>,
5860
pub multimodal_config_registry: Arc<MultimodalConfigRegistry>,
5961
pub reasoning_parser_factory: Option<ReasoningParserFactory>,
@@ -97,6 +99,7 @@ pub struct AppContextBuilder {
9799
client: Option<Client>,
98100
router_config: Option<RouterConfig>,
99101
rate_limiter: Option<Arc<TokenBucket>>,
102+
token_rate_limiter: Option<Arc<LocalTokenRateLimiter>>,
100103
tokenizer_registry: Option<Arc<TokenizerRegistry>>,
101104
reasoning_parser_factory: Option<ReasoningParserFactory>,
102105
tool_parser_factory: Option<ToolParserFactory>,
@@ -152,6 +155,7 @@ impl AppContextBuilder {
152155
client: None,
153156
router_config: None,
154157
rate_limiter: None,
158+
token_rate_limiter: None,
155159
tokenizer_registry: None,
156160
reasoning_parser_factory: None,
157161
tool_parser_factory: None,
@@ -362,6 +366,7 @@ impl AppContextBuilder {
362366
.ok_or(AppContextBuildError::MissingField("client"))?,
363367
router_config,
364368
rate_limiter: self.rate_limiter,
369+
token_rate_limiter: self.token_rate_limiter,
365370
tokenizer_registry: self
366371
.tokenizer_registry
367372
.ok_or(AppContextBuildError::MissingField("tokenizer_registry"))?,
@@ -520,6 +525,11 @@ impl AppContextBuilder {
520525
)))
521526
}
522527
};
528+
self.token_rate_limiter = config.multi_tenant_rate_limit.enabled.then(|| {
529+
Arc::new(LocalTokenRateLimiter::new(
530+
config.multi_tenant_rate_limit.clone(),
531+
))
532+
});
523533
self
524534
}
525535

model_gateway/src/config/builder.rs

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::collections::HashMap;
1+
use std::collections::{hash_map::Entry, HashMap};
22

33
use smg_mcp::McpConfig;
44

@@ -580,6 +580,59 @@ impl RouterConfigBuilder {
580580
self
581581
}
582582

583+
pub fn multi_tenant_rate_limit_enabled(mut self, enabled: bool) -> Self {
584+
self.config.multi_tenant_rate_limit.enabled = enabled;
585+
self
586+
}
587+
588+
/// Stores `limit` as `default_tokens_per_minute` in the multi-tenant rate
/// limit section of the configuration being built.
///
/// Returns the builder so further setters can be chained.
pub fn default_tokens_per_minute(mut self, limit: u32) -> Self {
    let limits = &mut self.config.multi_tenant_rate_limit;
    limits.default_tokens_per_minute = limit;
    self
}
594+
595+
/// Stores `limit` as `default_requests_per_minute` in the multi-tenant rate
/// limit section of the configuration being built.
///
/// Returns the builder so further setters can be chained.
pub fn default_requests_per_minute(mut self, limit: u32) -> Self {
    let limits = &mut self.config.multi_tenant_rate_limit;
    limits.default_requests_per_minute = limit;
    self
}
601+
602+
pub fn tenant_rate_limit<S: Into<String>>(
603+
mut self,
604+
tenant_key: S,
605+
tokens_per_minute: u32,
606+
requests_per_minute: u32,
607+
) -> Self {
608+
let tenant_key = tenant_key.into();
609+
let new_policy = crate::rate_limit::TenantTokenPolicy {
610+
tokens_per_minute,
611+
requests_per_minute,
612+
};
613+
614+
match self
615+
.config
616+
.multi_tenant_rate_limit
617+
.tenants
618+
.entry(tenant_key.clone())
619+
{
620+
Entry::Vacant(entry) => {
621+
entry.insert(new_policy);
622+
}
623+
Entry::Occupied(mut entry) => {
624+
tracing::warn!(
625+
tenant_key = %tenant_key,
626+
previous_policy = ?entry.get(),
627+
new_policy = ?new_policy,
628+
"overwriting duplicate tenant rate limit policy"
629+
);
630+
entry.insert(new_policy);
631+
}
632+
}
633+
self
634+
}
635+
583636
pub fn maybe_model_path(mut self, path: Option<impl Into<String>>) -> Self {
584637
self.config.model_path = path.map(|p| p.into());
585638
self
@@ -930,6 +983,56 @@ mod tests {
930983
assert!(modified.trace_config.is_some());
931984
}
932985

986+
/// Every multi-tenant rate limit setter on the builder must be reflected in
/// the built configuration.
#[test]
fn test_builder_multi_tenant_rate_limit_round_trip() {
    let builder = RouterConfigBuilder::new()
        .regular_mode(vec!["http://worker1:8000".to_string()])
        .multi_tenant_rate_limit_enabled(true)
        .default_tokens_per_minute(10_000)
        .default_requests_per_minute(60)
        .tenant_rate_limit("team-a", 50_000, 600)
        .tenant_rate_limit("team-b", 100_000, 1_200);
    let config = builder.build().unwrap();
    let limits = &config.multi_tenant_rate_limit;

    // Global switch and defaults.
    assert!(limits.enabled);
    assert_eq!(limits.default_tokens_per_minute, 10_000);
    assert_eq!(limits.default_requests_per_minute, 60);

    // Per-tenant override for team-a, plus the total number of overrides.
    let team_a = limits
        .tenants
        .get("team-a")
        .expect("team-a override registered");
    assert_eq!(team_a.tokens_per_minute, 50_000);
    assert_eq!(team_a.requests_per_minute, 600);
    assert_eq!(limits.tenants.len(), 2);
}
1016+
1017+
/// Registering the same tenant key twice must leave a single entry holding
/// the policy from the second registration.
#[test]
fn test_builder_duplicate_tenant_rate_limit_overwrites_latest_policy() {
    let builder = RouterConfigBuilder::new()
        .regular_mode(vec!["http://worker1:8000".to_string()])
        .tenant_rate_limit("team-a", 50_000, 600)
        .tenant_rate_limit("team-a", 100_000, 1_200);
    let config = builder.build().unwrap();
    let tenants = &config.multi_tenant_rate_limit.tenants;

    let team_a = tenants
        .get("team-a")
        .expect("team-a override registered");
    assert_eq!(team_a.tokens_per_minute, 100_000);
    assert_eq!(team_a.requests_per_minute, 1_200);
    assert_eq!(tenants.len(), 1);
}
1035+
9331036
/// Test complex routing mode helper method
9341037
#[test]
9351038
fn test_builder_prefill_decode_mode() {

model_gateway/src/config/types.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ pub struct RouterConfig {
110110
pub background: BackgroundConfig,
111111
#[serde(default)]
112112
pub tenant_resolution: TenantResolutionConfig,
113+
#[serde(default)]
114+
pub multi_tenant_rate_limit: crate::rate_limit::MultiTenantRateLimitConfig,
113115
/// Set to -1 to disable rate limiting
114116
pub max_concurrent_requests: i32,
115117
pub queue_size: usize,
@@ -646,6 +648,7 @@ impl Default for RouterConfig {
646648
memory_runtime: MemoryRuntimeConfig::default(),
647649
background: BackgroundConfig::default(),
648650
tenant_resolution: TenantResolutionConfig::default(),
651+
multi_tenant_rate_limit: crate::rate_limit::MultiTenantRateLimitConfig::default(),
649652
max_concurrent_requests: -1,
650653
queue_size: 100,
651654
queue_timeout_secs: 60,

model_gateway/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub mod mesh;
55
pub mod middleware;
66
pub mod observability;
77
pub mod policies;
8+
pub mod rate_limit;
89
pub mod routers;
910
pub mod server;
1011
pub mod service_discovery;

model_gateway/src/main.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,22 @@ struct CliArgs {
355355
#[arg(long, help_heading = "Rate Limiting")]
356356
rate_limit_tokens_per_second: Option<i32>,
357357

358+
/// Enable tenant-aware token rate limiting.
359+
#[arg(long, default_value_t = false, help_heading = "Rate Limiting")]
360+
multi_tenant_rate_limit_enabled: bool,
361+
362+
/// Default token budget per minute for tenants without an explicit override.
363+
#[arg(long, default_value_t = 0, help_heading = "Rate Limiting")]
364+
default_tokens_per_minute: u32,
365+
366+
/// Default request budget per minute for tenants without an explicit override.
367+
#[arg(long, default_value_t = 0, help_heading = "Rate Limiting")]
368+
default_requests_per_minute: u32,
369+
370+
/// Per-tenant override in the form tenant_key:tpm:rpm, e.g. header:team-a:1000:10
371+
#[arg(long = "tenant-rate-limit", num_args = 0.., help_heading = "Rate Limiting")]
372+
tenant_rate_limits: Vec<String>,
373+
358374
// ==================== Retry Configuration ====================
359375
/// Maximum number of retry attempts
360376
#[arg(long, default_value_t = 5, help_heading = "Retry Configuration")]
@@ -1249,6 +1265,9 @@ impl CliArgs {
12491265
.trust_tenant_header(self.trust_tenant_header)
12501266
.tenant_header_name(&self.tenant_header_name)
12511267
.maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
1268+
.multi_tenant_rate_limit_enabled(self.multi_tenant_rate_limit_enabled)
1269+
.default_tokens_per_minute(self.default_tokens_per_minute)
1270+
.default_requests_per_minute(self.default_requests_per_minute)
12521271
.maybe_model_path(self.model_path.as_ref())
12531272
.maybe_tokenizer_path(self.tokenizer_path.as_ref())
12541273
.maybe_chat_template(self.chat_template.as_ref())
@@ -1269,6 +1288,23 @@ impl CliArgs {
12691288
.dp_minimum_tokens_scheduler(self.dp_minimum_tokens_scheduler)
12701289
.maybe_server_cert_and_key(self.tls_cert_path.as_ref(), self.tls_key_path.as_ref());
12711290

1291+
let mut builder = builder;
1292+
for spec in &self.tenant_rate_limits {
1293+
let mut parts = spec.rsplitn(3, ':');
1294+
let rpm = parts.next().and_then(|s| s.parse::<u32>().ok());
1295+
let tpm = parts.next().and_then(|s| s.parse::<u32>().ok());
1296+
let tenant_key = parts.next();
1297+
if let (Some(tenant_key), Some(tpm), Some(rpm)) = (tenant_key, tpm, rpm) {
1298+
builder = builder.tenant_rate_limit(tenant_key, tpm, rpm);
1299+
} else {
1300+
return Err(ConfigError::ValidationFailed {
1301+
reason: format!(
1302+
"invalid --tenant-rate-limit '{spec}'; expected tenant_key:tpm:rpm"
1303+
),
1304+
});
1305+
}
1306+
}
1307+
12721308
builder.build()
12731309
}
12741310

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
use axum::{
2+
http::{self, header::RETRY_AFTER, HeaderValue},
3+
response::Response,
4+
};
5+
6+
use super::local::TERMINAL_REJECTION_RETRY_AFTER_SECS;
7+
use crate::routers::error::create_error;
8+
9+
pub fn rate_limit_exceeded_response(retry_after_secs: u64) -> Response {
10+
let mut response = create_error(
11+
http::StatusCode::TOO_MANY_REQUESTS,
12+
"tenant_rate_limit_exceeded",
13+
"Tenant rate limit exceeded for this request",
14+
);
15+
16+
if retry_after_secs == TERMINAL_REJECTION_RETRY_AFTER_SECS {
17+
return response;
18+
}
19+
20+
if let Ok(v) = HeaderValue::from_str(&retry_after_secs.max(1).to_string()) {
21+
response.headers_mut().insert(RETRY_AFTER, v);
22+
}
23+
24+
response
25+
}
26+
27+
#[cfg(test)]
mod tests {
    use axum::http::header::RETRY_AFTER;

    use super::{rate_limit_exceeded_response, TERMINAL_REJECTION_RETRY_AFTER_SECS};

    /// Passing the terminal-rejection sentinel must yield a response that
    /// carries no `Retry-After` header.
    #[test]
    fn omits_retry_after_for_terminal_rate_limit_rejection() {
        let rejected = rate_limit_exceeded_response(TERMINAL_REJECTION_RETRY_AFTER_SECS);
        let retry_after = rejected.headers().get(RETRY_AFTER);

        assert!(retry_after.is_none());
    }
}

0 commit comments

Comments
 (0)