Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions model_gateway/src/app_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::{
middleware::TokenBucket,
observability::inflight_tracker::InFlightRequestTracker,
policies::PolicyRegistry,
rate_limit::LocalTokenRateLimiter,
routers::{
common::openai_bridge::FormatRegistry, grpc::multimodal::MultimodalConfigRegistry,
openai::realtime::RealtimeRegistry, router_manager::RouterManager,
Expand Down Expand Up @@ -51,6 +52,7 @@ pub struct AppContext {
pub client: Client,
pub router_config: RouterConfig,
pub rate_limiter: Option<Arc<TokenBucket>>,
pub token_rate_limiter: Option<Arc<LocalTokenRateLimiter>>,
pub tokenizer_registry: Arc<TokenizerRegistry>,
pub multimodal_config_registry: Arc<MultimodalConfigRegistry>,
pub reasoning_parser_factory: Option<ReasoningParserFactory>,
Expand Down Expand Up @@ -91,6 +93,7 @@ pub struct AppContextBuilder {
client: Option<Client>,
router_config: Option<RouterConfig>,
rate_limiter: Option<Arc<TokenBucket>>,
token_rate_limiter: Option<Arc<LocalTokenRateLimiter>>,
tokenizer_registry: Option<Arc<TokenizerRegistry>>,
reasoning_parser_factory: Option<ReasoningParserFactory>,
tool_parser_factory: Option<ToolParserFactory>,
Expand Down Expand Up @@ -144,6 +147,7 @@ impl AppContextBuilder {
client: None,
router_config: None,
rate_limiter: None,
token_rate_limiter: None,
tokenizer_registry: None,
reasoning_parser_factory: None,
tool_parser_factory: None,
Expand Down Expand Up @@ -337,6 +341,7 @@ impl AppContextBuilder {
.ok_or(AppContextBuildError::MissingField("client"))?,
router_config,
rate_limiter: self.rate_limiter,
token_rate_limiter: self.token_rate_limiter,
tokenizer_registry: self
.tokenizer_registry
.ok_or(AppContextBuildError::MissingField("tokenizer_registry"))?,
Expand Down Expand Up @@ -488,6 +493,11 @@ impl AppContextBuilder {
)))
}
};
self.token_rate_limiter = config.multi_tenant_rate_limit.enabled.then(|| {
Arc::new(LocalTokenRateLimiter::new(
config.multi_tenant_rate_limit.clone(),
))
});
self
}

Expand Down
107 changes: 105 additions & 2 deletions model_gateway/src/config/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{hash_map::Entry, HashMap};

use smg_mcp::McpConfig;

Expand All @@ -7,7 +7,7 @@ use super::{
HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, PostgresConfig, RedisConfig,
RetryConfig, RouterConfig, RoutingMode, TokenizerCacheConfig, TraceConfig,
};
use crate::worker::ConnectionMode;
use crate::{rate_limit::TenantTokenPolicy, worker::ConnectionMode};

/// Builder for RouterConfig that wraps the config itself
/// This eliminates field duplication and stays in sync automatically
Expand Down Expand Up @@ -608,6 +608,59 @@ impl RouterConfigBuilder {
self
}

/// Sets the `enabled` flag of the multi-tenant rate limit section.
pub fn multi_tenant_rate_limit_enabled(mut self, enabled: bool) -> Self {
    let limits = &mut self.config.multi_tenant_rate_limit;
    limits.enabled = enabled;
    self
}

/// Sets `default_tokens_per_minute` in the multi-tenant rate limit section.
pub fn default_tokens_per_minute(mut self, limit: u32) -> Self {
    let limits = &mut self.config.multi_tenant_rate_limit;
    limits.default_tokens_per_minute = limit;
    self
}

/// Sets `default_requests_per_minute` in the multi-tenant rate limit section.
pub fn default_requests_per_minute(mut self, limit: u32) -> Self {
    let limits = &mut self.config.multi_tenant_rate_limit;
    limits.default_requests_per_minute = limit;
    self
}

/// Registers a rate limit policy for `tenant_key`.
///
/// Registering the same key more than once keeps the most recent policy and
/// emits a warning that shows both the previous and the new policy.
pub fn tenant_rate_limit<S: Into<String>>(
    mut self,
    tenant_key: S,
    tokens_per_minute: u32,
    requests_per_minute: u32,
) -> Self {
    let new_policy = TenantTokenPolicy {
        tokens_per_minute,
        requests_per_minute,
    };

    // The map entry takes ownership of the key; the warning reads it back
    // through `key()` so the string never has to be cloned.
    match self
        .config
        .multi_tenant_rate_limit
        .tenants
        .entry(tenant_key.into())
    {
        Entry::Vacant(entry) => {
            entry.insert(new_policy);
        }
        Entry::Occupied(mut entry) => {
            tracing::warn!(
                tenant_key = %entry.key(),
                previous_policy = ?entry.get(),
                new_policy = ?new_policy,
                "overwriting duplicate tenant rate limit policy"
            );
            entry.insert(new_policy);
        }
    }
    self
}
Comment thread
shenoyvvarun marked this conversation as resolved.
Comment thread
shenoyvvarun marked this conversation as resolved.

pub fn maybe_model_path(mut self, path: Option<impl Into<String>>) -> Self {
self.config.model_path = path.map(|p| p.into());
self
Expand Down Expand Up @@ -897,6 +950,56 @@ mod tests {
assert!(modified.trace_config.is_some());
}

/// Values passed to the multi-tenant rate limit setters must be readable from the built config.
#[test]
fn test_builder_multi_tenant_rate_limit_round_trip() {
    let built = RouterConfigBuilder::new()
        .regular_mode(vec!["http://worker1:8000".to_string()])
        .multi_tenant_rate_limit_enabled(true)
        .default_tokens_per_minute(10_000)
        .default_requests_per_minute(60)
        .tenant_rate_limit("team-a", 50_000, 600)
        .tenant_rate_limit("team-b", 100_000, 1_200)
        .build()
        .unwrap();
    let limits = &built.multi_tenant_rate_limit;

    assert!(limits.enabled);
    assert_eq!(limits.default_tokens_per_minute, 10_000);
    assert_eq!(limits.default_requests_per_minute, 60);

    let team_a = limits
        .tenants
        .get("team-a")
        .expect("team-a override registered");
    assert_eq!(team_a.tokens_per_minute, 50_000);
    assert_eq!(team_a.requests_per_minute, 600);
    assert_eq!(limits.tenants.len(), 2);
}

/// Registering the same tenant key twice must leave a single entry holding the later policy.
#[test]
fn test_builder_duplicate_tenant_rate_limit_overwrites_latest_policy() {
    let built = RouterConfigBuilder::new()
        .regular_mode(vec!["http://worker1:8000".to_string()])
        .tenant_rate_limit("team-a", 50_000, 600)
        .tenant_rate_limit("team-a", 100_000, 1_200)
        .build()
        .unwrap();
    let tenants = &built.multi_tenant_rate_limit.tenants;

    let team_a = tenants.get("team-a").expect("team-a override registered");
    assert_eq!(team_a.tokens_per_minute, 100_000);
    assert_eq!(team_a.requests_per_minute, 1_200);
    assert_eq!(tenants.len(), 1);
}

/// Test complex routing mode helper method
#[test]
fn test_builder_prefill_decode_mode() {
Expand Down
8 changes: 7 additions & 1 deletion model_gateway/src/config/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ pub use smg_data_connector::{
};

use super::{validation::ConfigValidator, ConfigResult};
use crate::{tenant::DEFAULT_TENANT_HEADER_NAME, worker::ConnectionMode};
use crate::{
rate_limit::MultiTenantRateLimitConfig, tenant::DEFAULT_TENANT_HEADER_NAME,
worker::ConnectionMode,
};

/// Main router configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -54,6 +57,8 @@ pub struct RouterConfig {
pub storage_context_headers: HashMap<String, String>,
#[serde(default)]
pub tenant_resolution: TenantResolutionConfig,
#[serde(default)]
pub multi_tenant_rate_limit: MultiTenantRateLimitConfig,
/// Set to -1 to disable rate limiting
pub max_concurrent_requests: i32,
pub queue_size: usize,
Expand Down Expand Up @@ -680,6 +685,7 @@ impl Default for RouterConfig {
request_id_headers: None,
storage_context_headers: HashMap::new(),
tenant_resolution: TenantResolutionConfig::default(),
multi_tenant_rate_limit: MultiTenantRateLimitConfig::default(),
max_concurrent_requests: -1,
queue_size: 100,
queue_timeout_secs: 60,
Expand Down
1 change: 1 addition & 0 deletions model_gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod mesh;
pub mod middleware;
pub mod observability;
pub mod policies;
pub mod rate_limit;
pub mod routers;
pub mod server;
pub mod service_discovery;
Expand Down
43 changes: 43 additions & 0 deletions model_gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ struct CliArgs {
#[arg(long, help_heading = "Rate Limiting")]
rate_limit_tokens_per_second: Option<i32>,

/// Enable tenant-aware token rate limiting.
#[arg(long, default_value_t = false, help_heading = "Rate Limiting")]
multi_tenant_rate_limit_enabled: bool,

/// Default token budget per minute for tenants without an explicit override.
#[arg(long, default_value_t = 0, help_heading = "Rate Limiting")]
default_tokens_per_minute: u32,

/// Default request budget per minute for tenants without an explicit override.
#[arg(long, default_value_t = 0, help_heading = "Rate Limiting")]
default_requests_per_minute: u32,

/// Per-tenant override in the form tenant_key:tpm:rpm, where tenant_key may itself contain ':' (e.g. header:team-a:1000:10)
#[arg(long = "tenant-rate-limit", num_args = 0.., help_heading = "Rate Limiting")]
Comment thread
shenoyvvarun marked this conversation as resolved.
tenant_rate_limits: Vec<String>,

// ==================== Retry Configuration ====================
/// Maximum number of retry attempts
#[arg(long, default_value_t = 5, help_heading = "Retry Configuration")]
Expand Down Expand Up @@ -1330,6 +1346,9 @@ impl CliArgs {
.trust_tenant_header(self.trust_tenant_header)
.tenant_header_name(&self.tenant_header_name)
.maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
.multi_tenant_rate_limit_enabled(self.multi_tenant_rate_limit_enabled)
.default_tokens_per_minute(self.default_tokens_per_minute)
.default_requests_per_minute(self.default_requests_per_minute)
.maybe_model_path(self.model_path.as_ref())
.maybe_tokenizer_path(self.tokenizer_path.as_ref())
.maybe_chat_template(self.chat_template.as_ref())
Expand All @@ -1348,6 +1367,30 @@ impl CliArgs {
.dp_minimum_tokens_scheduler(self.dp_minimum_tokens_scheduler)
.maybe_server_cert_and_key(self.tls_cert_path.as_ref(), self.tls_key_path.as_ref());

let mut builder = builder;
for spec in &self.tenant_rate_limits {
let mut parts = spec.rsplitn(3, ':');
let rpm = parts.next().and_then(|s| s.parse::<u32>().ok());
let tpm = parts.next().and_then(|s| s.parse::<u32>().ok());
let tenant_key = parts.next();
if let (Some(tenant_key), Some(tpm), Some(rpm)) = (tenant_key, tpm, rpm) {
if tenant_key.is_empty() {
return Err(ConfigError::ValidationFailed {
reason: format!(
"invalid --tenant-rate-limit '{spec}'; expected tenant_key:tpm:rpm"
),
});
}
builder = builder.tenant_rate_limit(tenant_key, tpm, rpm);
} else {
return Err(ConfigError::ValidationFailed {
reason: format!(
"invalid --tenant-rate-limit '{spec}'; expected tenant_key:tpm:rpm"
),
});
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

builder.build()
}

Expand Down
49 changes: 49 additions & 0 deletions model_gateway/src/rate_limit/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use axum::{
http::{self, header::RETRY_AFTER, HeaderValue},
response::Response,
};

use super::local::TERMINAL_REJECTION_RETRY_AFTER_SECS;
use crate::routers::error::create_error;

/// Builds the HTTP response returned when a tenant rate limit rejects a request.
///
/// When `retry_after_secs` equals [`TERMINAL_REJECTION_RETRY_AFTER_SECS`] the
/// response is `413 Payload Too Large` and is returned as-is, without adding a
/// `Retry-After` header. Any other value yields `429 Too Many Requests` with a
/// `Retry-After` header of at least one second.
pub fn rate_limit_exceeded_response(retry_after_secs: u64) -> Response {
    if retry_after_secs == TERMINAL_REJECTION_RETRY_AFTER_SECS {
        return create_error(
            http::StatusCode::PAYLOAD_TOO_LARGE,
            "tenant_rate_limit_exceeded",
            "Request exceeds the tenant capacity limit and cannot be retried without reducing its size",
        );
    }

    let mut response = create_error(
        http::StatusCode::TOO_MANY_REQUESTS,
        "tenant_rate_limit_exceeded",
        "Tenant rate limit exceeded for this request",
    );

    // Never advertise a zero-second wait. Converting an integer into a
    // `HeaderValue` cannot fail, so there is no error branch to drop.
    response
        .headers_mut()
        .insert(RETRY_AFTER, HeaderValue::from(retry_after_secs.max(1)));

    response
}

#[cfg(test)]
mod tests {
    use axum::http::{header::RETRY_AFTER, StatusCode};

    use super::{rate_limit_exceeded_response, TERMINAL_REJECTION_RETRY_AFTER_SECS};
    use crate::routers::error::extract_error_code_from_response;

    /// The terminal sentinel must map to a 413 that carries no `Retry-After` header.
    #[test]
    fn returns_distinct_terminal_rate_limit_rejection() {
        let rejection = rate_limit_exceeded_response(TERMINAL_REJECTION_RETRY_AFTER_SECS);

        assert_eq!(rejection.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert!(!rejection.headers().contains_key(RETRY_AFTER));
        assert_eq!(
            extract_error_code_from_response(&rejection),
            "tenant_rate_limit_exceeded"
        );
    }
}
Loading
Loading