|
1 | 1 | use axum::{ |
2 | | - extract::{Query, State}, |
3 | | - http::{header, HeaderMap, HeaderValue, StatusCode}, |
| 2 | + extract::{ConnectInfo, Query, State}, |
| 3 | + http::{header, HeaderValue, StatusCode}, |
4 | 4 | response::{Html, IntoResponse, Response}, |
5 | 5 | routing::{get, post}, |
6 | 6 | Json, Router, |
7 | 7 | }; |
8 | 8 | use serde::{Deserialize, Serialize}; |
9 | 9 | use std::collections::HashMap; |
| 10 | +use std::net::SocketAddr; |
10 | 11 | use std::sync::Arc; |
11 | 12 | use std::time::Duration; |
12 | 13 | use std::time::Instant; |
@@ -147,35 +148,8 @@ fn is_valid_dns_label(label: &str) -> bool { |
147 | 148 | && label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') |
148 | 149 | } |
149 | 150 |
|
150 | | -fn request_host_is_loopback(headers: &HeaderMap) -> bool { |
151 | | - headers |
152 | | - .get(header::HOST) |
153 | | - .and_then(|host| host.to_str().ok()) |
154 | | - .is_some_and(authority_is_loopback) |
155 | | -} |
156 | | - |
157 | | -fn authority_is_loopback(authority: &str) -> bool { |
158 | | - if authority.contains('@') { |
159 | | - return false; |
160 | | - } |
161 | | - |
162 | | - let host = if let Some(remainder) = authority.strip_prefix('[') { |
163 | | - remainder |
164 | | - .split_once(']') |
165 | | - .map(|(host, _)| host) |
166 | | - .unwrap_or(remainder) |
167 | | - } else { |
168 | | - authority.split(':').next().unwrap_or(authority) |
169 | | - }; |
170 | | - |
171 | | - let host = host.to_ascii_lowercase(); |
172 | | - host == "localhost" |
173 | | - || host |
174 | | - .parse::<std::net::Ipv4Addr>() |
175 | | - .is_ok_and(|addr| addr.is_loopback()) |
176 | | - || host |
177 | | - .parse::<std::net::Ipv6Addr>() |
178 | | - .is_ok_and(|addr| addr.is_loopback()) |
| 151 | +fn peer_addr_is_loopback(peer_addr: &SocketAddr) -> bool { |
| 152 | + peer_addr.ip().is_loopback() |
179 | 153 | } |
180 | 154 |
|
181 | 155 | fn parse_domains(domains: Option<&String>) -> Vec<String> { |
@@ -248,16 +222,16 @@ fn build_outer_csp( |
248 | 222 |
|
249 | 223 | async fn mcp_app_proxy( |
250 | 224 | State(state): State<AppState>, |
251 | | - headers: HeaderMap, |
| 225 | + ConnectInfo(peer_addr): ConnectInfo<SocketAddr>, |
252 | 226 | Query(params): Query<ProxyQuery>, |
253 | 227 | ) -> Response { |
254 | 228 | if params.secret != state.secret_key { |
255 | 229 | return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response(); |
256 | 230 | } |
257 | | - if !request_host_is_loopback(&headers) { |
| 231 | + if !peer_addr_is_loopback(&peer_addr) { |
258 | 232 | return ( |
259 | 233 | StatusCode::BAD_REQUEST, |
260 | | - "MCP app proxy is only available from a loopback host", |
| 234 | + "MCP app proxy is only available to loopback clients", |
261 | 235 | ) |
262 | 236 | .into_response(); |
263 | 237 | } |
@@ -289,16 +263,16 @@ async fn mcp_app_proxy( |
289 | 263 |
|
290 | 264 | async fn store_guest_html( |
291 | 265 | State(state): State<AppState>, |
292 | | - headers: HeaderMap, |
| 266 | + ConnectInfo(peer_addr): ConnectInfo<SocketAddr>, |
293 | 267 | Json(body): Json<StoreGuestBody>, |
294 | 268 | ) -> Response { |
295 | 269 | if body.secret != state.secret_key { |
296 | 270 | return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response(); |
297 | 271 | } |
298 | | - if !request_host_is_loopback(&headers) { |
| 272 | + if !peer_addr_is_loopback(&peer_addr) { |
299 | 273 | return ( |
300 | 274 | StatusCode::BAD_REQUEST, |
301 | | - "MCP app guest storage is only available from a loopback host", |
| 275 | + "MCP app guest storage is only available to loopback clients", |
302 | 276 | ) |
303 | 277 | .into_response(); |
304 | 278 | } |
@@ -414,7 +388,8 @@ pub(crate) fn routes(secret_key: String) -> Router { |
414 | 388 |
|
415 | 389 | #[cfg(test)] |
416 | 390 | mod tests { |
417 | | - use super::{authority_is_loopback, normalize_csp_source, parse_domains}; |
| 391 | + use super::{normalize_csp_source, parse_domains, peer_addr_is_loopback}; |
| 392 | + use std::net::SocketAddr; |
418 | 393 |
|
419 | 394 | #[test] |
420 | 395 | fn normalizes_url_sources_to_origins() { |
@@ -470,10 +445,15 @@ mod tests { |
470 | 445 | } |
471 | 446 |
|
472 | 447 | #[test] |
473 | | - fn detects_loopback_authorities() { |
474 | | - assert!(authority_is_loopback("127.0.0.1:12345")); |
475 | | - assert!(authority_is_loopback("localhost:12345")); |
476 | | - assert!(authority_is_loopback("[::1]:12345")); |
477 | | - assert!(!authority_is_loopback("example.test")); |
| 448 | + fn detects_loopback_peer_addresses() { |
| 449 | + assert!(peer_addr_is_loopback( |
| 450 | + &"127.0.0.1:12345".parse::<SocketAddr>().unwrap() |
| 451 | + )); |
| 452 | + assert!(peer_addr_is_loopback( |
| 453 | + &"[::1]:12345".parse::<SocketAddr>().unwrap() |
| 454 | + )); |
| 455 | + assert!(!peer_addr_is_loopback( |
| 456 | + &"192.168.1.10:12345".parse::<SocketAddr>().unwrap() |
| 457 | + )); |
478 | 458 | } |
479 | 459 | } |
0 commit comments