|
1 | | -use std::collections::HashMap; |
| 1 | +use std::collections::{hash_map::Entry, HashMap}; |
2 | 2 |
|
3 | 3 | use smg_mcp::McpConfig; |
4 | 4 |
|
@@ -580,6 +580,59 @@ impl RouterConfigBuilder { |
580 | 580 | self |
581 | 581 | } |
582 | 582 |
|
| 583 | + pub fn multi_tenant_rate_limit_enabled(mut self, enabled: bool) -> Self { |
| 584 | + self.config.multi_tenant_rate_limit.enabled = enabled; |
| 585 | + self |
| 586 | + } |
| 587 | + |
| 588 | + pub fn default_tokens_per_minute(mut self, limit: u32) -> Self { |
| 589 | + self.config |
| 590 | + .multi_tenant_rate_limit |
| 591 | + .default_tokens_per_minute = limit; |
| 592 | + self |
| 593 | + } |
| 594 | + |
| 595 | + pub fn default_requests_per_minute(mut self, limit: u32) -> Self { |
| 596 | + self.config |
| 597 | + .multi_tenant_rate_limit |
| 598 | + .default_requests_per_minute = limit; |
| 599 | + self |
| 600 | + } |
| 601 | + |
| 602 | + pub fn tenant_rate_limit<S: Into<String>>( |
| 603 | + mut self, |
| 604 | + tenant_key: S, |
| 605 | + tokens_per_minute: u32, |
| 606 | + requests_per_minute: u32, |
| 607 | + ) -> Self { |
| 608 | + let tenant_key = tenant_key.into(); |
| 609 | + let new_policy = crate::rate_limit::TenantTokenPolicy { |
| 610 | + tokens_per_minute, |
| 611 | + requests_per_minute, |
| 612 | + }; |
| 613 | + |
| 614 | + match self |
| 615 | + .config |
| 616 | + .multi_tenant_rate_limit |
| 617 | + .tenants |
| 618 | + .entry(tenant_key.clone()) |
| 619 | + { |
| 620 | + Entry::Vacant(entry) => { |
| 621 | + entry.insert(new_policy); |
| 622 | + } |
| 623 | + Entry::Occupied(mut entry) => { |
| 624 | + tracing::warn!( |
| 625 | + tenant_key = %tenant_key, |
| 626 | + previous_policy = ?entry.get(), |
| 627 | + new_policy = ?new_policy, |
| 628 | + "overwriting duplicate tenant rate limit policy" |
| 629 | + ); |
| 630 | + entry.insert(new_policy); |
| 631 | + } |
| 632 | + } |
| 633 | + self |
| 634 | + } |
| 635 | + |
583 | 636 | pub fn maybe_model_path(mut self, path: Option<impl Into<String>>) -> Self { |
584 | 637 | self.config.model_path = path.map(|p| p.into()); |
585 | 638 | self |
@@ -930,6 +983,56 @@ mod tests { |
930 | 983 | assert!(modified.trace_config.is_some()); |
931 | 984 | } |
932 | 985 |
|
/// Every multi-tenant rate limit setting passed to the builder must be
/// readable from the built config, including each per-tenant override.
#[test]
fn test_builder_multi_tenant_rate_limit_round_trip() {
    let config = RouterConfigBuilder::new()
        .regular_mode(vec!["http://worker1:8000".to_string()])
        .multi_tenant_rate_limit_enabled(true)
        .default_tokens_per_minute(10_000)
        .default_requests_per_minute(60)
        .tenant_rate_limit("team-a", 50_000, 600)
        .tenant_rate_limit("team-b", 100_000, 1_200)
        .build()
        .unwrap();

    let rate_limit = &config.multi_tenant_rate_limit;
    assert!(rate_limit.enabled);
    assert_eq!(rate_limit.default_tokens_per_minute, 10_000);
    assert_eq!(rate_limit.default_requests_per_minute, 60);

    let team_a = rate_limit
        .tenants
        .get("team-a")
        .expect("team-a override registered");
    assert_eq!(team_a.tokens_per_minute, 50_000);
    assert_eq!(team_a.requests_per_minute, 600);

    // Check the second tenant as well, so that limits stored under the wrong
    // key cannot slip past the length check below.
    let team_b = rate_limit
        .tenants
        .get("team-b")
        .expect("team-b override registered");
    assert_eq!(team_b.tokens_per_minute, 100_000);
    assert_eq!(team_b.requests_per_minute, 1_200);

    assert_eq!(rate_limit.tenants.len(), 2);
}
| 1016 | + |
/// Registering the same tenant key twice must keep a single entry that holds
/// the values from the most recent call.
#[test]
fn test_builder_duplicate_tenant_rate_limit_overwrites_latest_policy() {
    let builder = RouterConfigBuilder::new()
        .regular_mode(vec!["http://worker1:8000".to_string()])
        .tenant_rate_limit("team-a", 50_000, 600)
        .tenant_rate_limit("team-a", 100_000, 1_200);
    let config = builder.build().unwrap();

    let tenants = &config.multi_tenant_rate_limit.tenants;
    let policy = tenants.get("team-a").expect("team-a override registered");

    assert_eq!(policy.tokens_per_minute, 100_000);
    assert_eq!(policy.requests_per_minute, 1_200);
    assert_eq!(tenants.len(), 1);
}
| 1035 | + |
933 | 1036 | /// Test complex routing mode helper method |
934 | 1037 | #[test] |
935 | 1038 | fn test_builder_prefill_decode_mode() { |
|
0 commit comments