diff --git a/Cargo.lock b/Cargo.lock index 0357b655c..adcab959e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5198,6 +5198,7 @@ dependencies = [ "itertools 0.14.0", "keyring", "netrc-rs", + "rattler_conda_types", "rattler_config", "reqwest 0.12.28", "reqwest-middleware", diff --git a/crates/rattler_networking/Cargo.toml b/crates/rattler_networking/Cargo.toml index 21c6df169..c98542057 100644 --- a/crates/rattler_networking/Cargo.toml +++ b/crates/rattler_networking/Cargo.toml @@ -22,6 +22,7 @@ system-integration = ["keyring", "netrc-rs", "dirs"] features = ["gcs", "s3"] [dependencies] +rattler_conda_types = { workspace = true } anyhow = { workspace = true } async-once-cell = { workspace = true } async-trait = { workspace = true } diff --git a/crates/rattler_networking/src/mirror_middleware.rs b/crates/rattler_networking/src/mirror_middleware.rs index f36ee396b..7ba967fab 100644 --- a/crates/rattler_networking/src/mirror_middleware.rs +++ b/crates/rattler_networking/src/mirror_middleware.rs @@ -1,4 +1,5 @@ //! Middleware to handle mirrors +use rattler_conda_types::utils::url_with_trailing_slash::UrlWithTrailingSlash; use std::{ collections::HashMap, sync::atomic::{self, AtomicUsize}, @@ -118,7 +119,9 @@ impl Middleware for MirrorMiddleware { }; let mirror = &selected_mirror.mirror; - let selected_url = mirror.url.join(url_rest).unwrap(); + + let base = UrlWithTrailingSlash::from(mirror.url.clone()); + let selected_url = base.as_ref().join(url_rest).unwrap(); // Short-circuit if the mirror does not support the file type if url_rest.ends_with(".json.zst") && mirror.no_zstd { @@ -301,4 +304,48 @@ mod test { len = path.0.len(); } } + + #[tokio::test] + async fn test_mirror_middleware_path_rewrite() { + // Start a server that serves at /channel/count + let state = String::from("mirror server"); + let router = Router::new() + .route("/channel/count", get(count)) + .with_state(state); + + let addr = SocketAddr::new([127, 0, 0, 1].into(), 0); + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(axum::serve(listener, router.into_make_service()).into_future()); + + let mirror_url: Url = format!("http://{}:{}/channel", addr.ip(), addr.port()) + .parse() + .unwrap(); + + let mut mirror_map = std::collections::HashMap::new(); + + // Upstream key includes a path segment (e.g. conda-forge) + // Mirror URL also has a path segment (e.g. channel) + // The mirror path must fully replace the upstream path. + mirror_map.insert( + "https://prefix.dev/conda-forge".parse().unwrap(), + vec![mirror_setting(mirror_url)], + ); + + let middleware = MirrorMiddleware::from_map(mirror_map); + let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()) + .with(middleware) + .build(); + + // Request to upstream: https://prefix.dev/conda-forge/count + // Should be rewritten to: http://127.0.0.1:PORT/channel/count + let res = client + .get("https://prefix.dev/conda-forge/count") + .send() + .await + .unwrap(); + assert!(res.status().is_success(), "status: {}", res.status()); + let body = res.text().await.unwrap(); + assert_eq!(body, "Hi from counter: mirror server"); + } }