11use axum:: {
22 extract:: { Query , State } ,
3- http:: { header, HeaderValue , StatusCode } ,
3+ http:: { header, HeaderMap , HeaderValue , StatusCode } ,
44 response:: { Html , IntoResponse , Response } ,
55 routing:: { get, post} ,
66 Json , Router ,
@@ -11,7 +11,6 @@ use std::sync::Arc;
1111use std:: time:: Duration ;
1212use std:: time:: Instant ;
1313use tokio:: sync:: RwLock ;
14- use url:: Url ;
1514use uuid:: Uuid ;
1615
1716const GUEST_HTML_TTL_SECS : u64 = 300 ;
@@ -78,12 +77,105 @@ fn normalize_csp_source(source: &str) -> Option<String> {
7877 return None ;
7978 }
8079
81- let url = Url :: parse ( source) . ok ( ) ?;
82- if !matches ! ( url. scheme( ) , "http" | "https" | "ws" | "wss" ) {
83- return None ;
80+ if let Some ( ( scheme, rest) ) = source. split_once ( "://" ) {
81+ let scheme = scheme. to_ascii_lowercase ( ) ;
82+ if !matches ! ( scheme. as_str( ) , "http" | "https" | "ws" | "wss" ) {
83+ return None ;
84+ }
85+
86+ let authority = rest. split ( [ '/' , '?' , '#' ] ) . next ( ) ?;
87+ if !is_valid_csp_host_source ( authority) {
88+ return None ;
89+ }
90+
91+ return Some ( format ! ( "{scheme}://{}" , authority. to_ascii_lowercase( ) ) ) ;
92+ }
93+
94+ if is_valid_csp_host_source ( source) {
95+ return Some ( source. to_ascii_lowercase ( ) ) ;
8496 }
85- url. host_str ( ) ?;
86- Some ( url. origin ( ) . ascii_serialization ( ) )
97+
98+ None
99+ }
100+
101+ fn is_valid_csp_host_source ( source : & str ) -> bool {
102+ if source. is_empty ( ) || source == "*" || source. contains ( '@' ) {
103+ return false ;
104+ }
105+
106+ let ( host, port) = split_host_and_port ( source) ;
107+ if host. is_empty ( ) {
108+ return false ;
109+ }
110+ if port. is_some_and ( |port| port. is_empty ( ) || port. parse :: < u16 > ( ) . is_err ( ) ) {
111+ return false ;
112+ }
113+
114+ let host = host. strip_prefix ( "*." ) . unwrap_or ( host) ;
115+ if host. eq_ignore_ascii_case ( "localhost" )
116+ || host. parse :: < std:: net:: Ipv4Addr > ( ) . is_ok ( )
117+ || host. parse :: < std:: net:: Ipv6Addr > ( ) . is_ok ( )
118+ {
119+ return true ;
120+ }
121+
122+ !host. is_empty ( )
123+ && host. contains ( '.' )
124+ && host
125+ . split ( '.' )
126+ . all ( |label| is_valid_dns_label ( label) && label != "*" )
127+ }
128+
129+ fn split_host_and_port ( source : & str ) -> ( & str , Option < & str > ) {
130+ if let Some ( remainder) = source. strip_prefix ( '[' ) {
131+ if let Some ( ( host, tail) ) = remainder. split_once ( ']' ) {
132+ let port = tail. strip_prefix ( ':' ) ;
133+ return ( host, port) ;
134+ }
135+ }
136+
137+ match source. rsplit_once ( ':' ) {
138+ Some ( ( host, port) ) if !host. contains ( ':' ) => ( host, Some ( port) ) ,
139+ _ => ( source, None ) ,
140+ }
141+ }
142+
143+ fn is_valid_dns_label ( label : & str ) -> bool {
144+ !label. is_empty ( )
145+ && !label. starts_with ( '-' )
146+ && !label. ends_with ( '-' )
147+ && label. chars ( ) . all ( |c| c. is_ascii_alphanumeric ( ) || c == '-' )
148+ }
149+
150+ fn request_host_is_loopback ( headers : & HeaderMap ) -> bool {
151+ headers
152+ . get ( header:: HOST )
153+ . and_then ( |host| host. to_str ( ) . ok ( ) )
154+ . is_some_and ( authority_is_loopback)
155+ }
156+
157+ fn authority_is_loopback ( authority : & str ) -> bool {
158+ if authority. contains ( '@' ) {
159+ return false ;
160+ }
161+
162+ let host = if let Some ( remainder) = authority. strip_prefix ( '[' ) {
163+ remainder
164+ . split_once ( ']' )
165+ . map ( |( host, _) | host)
166+ . unwrap_or ( remainder)
167+ } else {
168+ authority. split ( ':' ) . next ( ) . unwrap_or ( authority)
169+ } ;
170+
171+ let host = host. to_ascii_lowercase ( ) ;
172+ host == "localhost"
173+ || host
174+ . parse :: < std:: net:: Ipv4Addr > ( )
175+ . is_ok_and ( |addr| addr. is_loopback ( ) )
176+ || host
177+ . parse :: < std:: net:: Ipv6Addr > ( )
178+ . is_ok_and ( |addr| addr. is_loopback ( ) )
87179}
88180
89181fn parse_domains ( domains : Option < & String > ) -> Vec < String > {
@@ -156,11 +248,19 @@ fn build_outer_csp(
156248
157249async fn mcp_app_proxy (
158250 State ( state) : State < AppState > ,
251+ headers : HeaderMap ,
159252 Query ( params) : Query < ProxyQuery > ,
160253) -> Response {
161254 if params. secret != state. secret_key {
162255 return ( StatusCode :: UNAUTHORIZED , "Unauthorized" ) . into_response ( ) ;
163256 }
257+ if !request_host_is_loopback ( & headers) {
258+ return (
259+ StatusCode :: BAD_REQUEST ,
260+ "MCP app proxy is only available from a loopback host" ,
261+ )
262+ . into_response ( ) ;
263+ }
164264
165265 let html = MCP_APP_PROXY_HTML . replace (
166266 "{{OUTER_CSP}}" ,
@@ -189,11 +289,19 @@ async fn mcp_app_proxy(
189289
190290async fn store_guest_html (
191291 State ( state) : State < AppState > ,
292+ headers : HeaderMap ,
192293 Json ( body) : Json < StoreGuestBody > ,
193294) -> Response {
194295 if body. secret != state. secret_key {
195296 return ( StatusCode :: UNAUTHORIZED , "Unauthorized" ) . into_response ( ) ;
196297 }
298+ if !request_host_is_loopback ( & headers) {
299+ return (
300+ StatusCode :: BAD_REQUEST ,
301+ "MCP app guest storage is only available from a loopback host" ,
302+ )
303+ . into_response ( ) ;
304+ }
197305
198306 let nonce = Uuid :: new_v4 ( ) . to_string ( ) ;
199307 let csp = body. csp . unwrap_or_default ( ) ;
@@ -303,3 +411,69 @@ pub(crate) fn routes(secret_key: String) -> Router {
303411 . route ( "/mcp-app-guest" , post ( store_guest_html) )
304412 . with_state ( state)
305413}
414+
415+ #[ cfg( test) ]
416+ mod tests {
417+ use super :: { authority_is_loopback, normalize_csp_source, parse_domains} ;
418+
419+ #[ test]
420+ fn normalizes_url_sources_to_origins ( ) {
421+ assert_eq ! (
422+ normalize_csp_source( "https://cdn.example.com/assets/app.js" ) ,
423+ Some ( "https://cdn.example.com" . to_string( ) )
424+ ) ;
425+ assert_eq ! (
426+ normalize_csp_source( "wss://api.example.com/socket" ) ,
427+ Some ( "wss://api.example.com" . to_string( ) )
428+ ) ;
429+ }
430+
431+ #[ test]
432+ fn accepts_wildcard_and_host_sources ( ) {
433+ assert_eq ! (
434+ normalize_csp_source( "https://*.cloudflare.com" ) ,
435+ Some ( "https://*.cloudflare.com" . to_string( ) )
436+ ) ;
437+ assert_eq ! (
438+ normalize_csp_source( "cdn.example.com" ) ,
439+ Some ( "cdn.example.com" . to_string( ) )
440+ ) ;
441+ assert_eq ! (
442+ normalize_csp_source( "localhost:3000" ) ,
443+ Some ( "localhost:3000" . to_string( ) )
444+ ) ;
445+ }
446+
447+ #[ test]
448+ fn rejects_unsafe_csp_sources ( ) {
449+ assert_eq ! ( normalize_csp_source( "*" ) , None ) ;
450+ assert_eq ! ( normalize_csp_source( "'unsafe-inline'" ) , None ) ;
451+ assert_eq ! ( normalize_csp_source( "javascript:alert(1)" ) , None ) ;
452+ assert_eq ! ( normalize_csp_source( "https://example.com;" ) , None ) ;
453+ assert_eq ! ( normalize_csp_source( "https://user@example.com" ) , None ) ;
454+ }
455+
456+ #[ test]
457+ fn parse_domains_filters_invalid_sources ( ) {
458+ let domains =
459+ "https://cdn.example.com/app.js, https://*.cloudflare.com, *, cdn.example.com"
460+ . to_string ( ) ;
461+
462+ assert_eq ! (
463+ parse_domains( Some ( & domains) ) ,
464+ vec![
465+ "https://cdn.example.com" . to_string( ) ,
466+ "https://*.cloudflare.com" . to_string( ) ,
467+ "cdn.example.com" . to_string( ) ,
468+ ]
469+ ) ;
470+ }
471+
472+ #[ test]
473+ fn detects_loopback_authorities ( ) {
474+ assert ! ( authority_is_loopback( "127.0.0.1:12345" ) ) ;
475+ assert ! ( authority_is_loopback( "localhost:12345" ) ) ;
476+ assert ! ( authority_is_loopback( "[::1]:12345" ) ) ;
477+ assert ! ( !authority_is_loopback( "example.test" ) ) ;
478+ }
479+ }
0 commit comments