diff --git a/crates/rattler_networking/src/mirror_middleware.rs b/crates/rattler_networking/src/mirror_middleware.rs index 2a4b0d68b7..d79d9dfc2f 100644 --- a/crates/rattler_networking/src/mirror_middleware.rs +++ b/crates/rattler_networking/src/mirror_middleware.rs @@ -118,11 +118,16 @@ impl Middleware for MirrorMiddleware { }; let mirror = &selected_mirror.mirror; - let selected_url = mirror.url.join(url_rest).map_err(|e| { - reqwest_middleware::Error::Middleware(anyhow::anyhow!( - "Failed to join mirror URL with '{url_rest}': {e}" - )) - })?; + let selected_url = { + let mut u = mirror.url.clone(); + let base_path = u.path().trim_end_matches('/'); + if url_rest.is_empty() { + u.set_path(&format!("{base_path}/")); + } else { + u.set_path(&format!("{base_path}/{url_rest}")); + } + u + }; // Short-circuit if the mirror does not support the file type if url_rest.ends_with(".json.zst") && mirror.no_zstd { @@ -305,4 +310,48 @@ mod test { len = path.0.len(); } } + + #[tokio::test] + async fn test_mirror_middleware_path_rewrite() { + // Start a server that serves at /channel/count + let state = String::from("mirror server"); + let router = Router::new() + .route("/channel/count", get(count)) + .with_state(state); + + let addr = SocketAddr::new([127, 0, 0, 1].into(), 0); + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(axum::serve(listener, router.into_make_service()).into_future()); + + let mirror_url: Url = format!("http://{}:{}/channel", addr.ip(), addr.port()) + .parse() + .unwrap(); + + let mut mirror_map = std::collections::HashMap::new(); + + // Upstream key includes a path segment (e.g. conda-forge) + // Mirror URL also has a path segment (e.g. channel) + // The mirror path must fully replace the upstream path. + mirror_map.insert( + "https://prefix.dev/conda-forge".parse().unwrap(), + vec![mirror_setting(mirror_url)], + ); + + let middleware = MirrorMiddleware::from_map(mirror_map); + let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()) + .with(middleware) + .build(); + + // Request to upstream: https://prefix.dev/conda-forge/count + // Should be rewritten to: http://127.0.0.1:PORT/channel/count + let res = client + .get("https://prefix.dev/conda-forge/count") + .send() + .await + .unwrap(); + assert!(res.status().is_success(), "status: {}", res.status()); + let body = res.text().await.unwrap(); + assert_eq!(body, "Hi from counter: mirror server"); + } } diff --git a/crates/rattler_repodata_gateway/src/gateway/query.rs b/crates/rattler_repodata_gateway/src/gateway/query.rs index 775133e583..eccb0dfffb 100644 --- a/crates/rattler_repodata_gateway/src/gateway/query.rs +++ b/crates/rattler_repodata_gateway/src/gateway/query.rs @@ -402,7 +402,7 @@ impl QueryExecutor { } } - /// Add matching records to the result. + /// Add matching records to the result and populate `package_names`. fn accumulate_records( &mut self, result_idx: usize, @@ -411,14 +411,21 @@ impl QueryExecutor { ) { let result = &mut self.result[result_idx]; + for record in &records { + // Always track package name for channel presence. + result + .package_names + .insert(record.package_record.name.clone()); + } + match request_specs { SourceSpecs::Transitive => { // All records match — extend with Arc clones (cheap refcount bumps). result.records.extend(records); } SourceSpecs::Input(specs) => { - // Only a subset matches — filter and clone matching Arcs. for record in &records { + // Only include records that match at least one input spec. if specs.iter().any(|s| s.matches(record.as_ref())) { result.records.push(record.clone()); } diff --git a/crates/rattler_repodata_gateway/src/gateway/repo_data.rs b/crates/rattler_repodata_gateway/src/gateway/repo_data.rs index 51f6019982..8bb530eea7 100644 --- a/crates/rattler_repodata_gateway/src/gateway/repo_data.rs +++ b/crates/rattler_repodata_gateway/src/gateway/repo_data.rs @@ -1,6 +1,6 @@ -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; -use rattler_conda_types::RepoDataRecord; +use rattler_conda_types::{PackageName, RepoDataRecord}; /// A container for [`RepoDataRecord`]s that are returned from the [`super::Gateway`]. /// @@ -12,6 +12,9 @@ use rattler_conda_types::RepoDataRecord; #[derive(Debug, Default, Clone)] pub struct RepoData { pub(crate) records: Vec>, + /// All package names present in this channel. + /// Includes names whose records were filtered out by version constraints. + pub(crate) package_names: HashSet, } impl RepoData { @@ -32,6 +35,11 @@ impl RepoData { self.records.is_empty() } + /// Returns the package names present in this channel. + pub fn package_names(&self) -> &HashSet { + &self.package_names + } + /// Returns an iterator over the Arc-wrapped records. /// /// This is useful when you want to clone records cheaply (Arc clone diff --git a/crates/rattler_solve/src/lib.rs b/crates/rattler_solve/src/lib.rs index cd1a46f663..db85cf0895 100644 --- a/crates/rattler_solve/src/lib.rs +++ b/crates/rattler_solve/src/lib.rs @@ -282,6 +282,7 @@ pub struct SolverTask { /// Dependency overrides that replace dependencies of matching packages. pub dependency_overrides: Vec<(MatchSpec, MatchSpec)>, + // channel_package_names field removed; package presence is now tracked in RepoData } impl<'r, I: IntoIterator> FromIterator diff --git a/crates/rattler_solve/src/resolvo/mod.rs b/crates/rattler_solve/src/resolvo/mod.rs index a88430376a..93d84778f4 100644 --- a/crates/rattler_solve/src/resolvo/mod.rs +++ b/crates/rattler_solve/src/resolvo/mod.rs @@ -50,12 +50,18 @@ pub struct DependencyOverride { pub struct RepoData<'a> { /// The actual records after parsing `repodata.json` pub records: Vec<&'a RepoDataRecord>, + /// The channel URL for this `RepoData` + pub channel: Option, + /// All package names present in this channel + pub package_names: std::collections::HashSet, } impl<'a> FromIterator<&'a RepoDataRecord> for RepoData<'a> { fn from_iter>(iter: T) -> Self { Self { records: Vec::from_iter(iter), + channel: None, + package_names: std::collections::HashSet::new(), } } } @@ -333,10 +339,23 @@ impl<'a> CondaDependencyProvider<'a> { .collect::>(); // Hashmap that maps the package name to the channel it was first found in. - let mut package_name_found_in_channel = HashMap::>::new(); + let mut package_name_found_in_channel = HashMap::>::new(); + + // Pre-populate channel ownership from RepoData.package_names. + let repodata_vec: Vec> = repodata.into_iter().collect(); + if channel_priority == ChannelPriority::Strict { + for repo_data in &repodata_vec { + let channel = repo_data.channel.clone(); + for name in repo_data.package_names.iter() { + package_name_found_in_channel + .entry(name.as_normalized().to_string()) + .or_insert_with(|| channel.clone()); + } + } + } // Add additional records - for repo_data in repodata { + for repo_data in &repodata_vec { // Iterate over all records and dedup records that refer to the same package // data but with different archive types. This can happen if you // have two variants of the same package but with different @@ -352,7 +371,7 @@ impl<'a> CondaDependencyProvider<'a> { let mut package_to_type: HashMap<&ArchiveIdentifier, (DistArchiveType, usize, bool)> = HashMap::with_capacity(repo_data.records.len()); - for record in repo_data.records { + for record in &repo_data.records { // Determine if this record will be excluded by exclude_newer. let excluded_by_newer = matches!((&exclude_newer, &record.package_record.timestamp), (Some(exclude_newer), Some(record_timestamp)) @@ -504,7 +523,7 @@ impl<'a> CondaDependencyProvider<'a> { channel_priority, ) { // Add the record to the excluded list when it is from a different channel. - if first_channel != &&record.channel { + if first_channel != &record.channel { if let Some(channel) = &record.channel { tracing::debug!( "Ignoring '{}' from '{}' because of strict channel priority.", @@ -531,7 +550,7 @@ impl<'a> CondaDependencyProvider<'a> { } else { package_name_found_in_channel.insert( record.package_record.name.as_normalized().to_string(), - &record.channel, + record.channel.clone(), ); } } diff --git a/crates/rattler_solve/tests/backends/main.rs b/crates/rattler_solve/tests/backends/main.rs index 53446c965b..6df50e9f0b 100644 --- a/crates/rattler_solve/tests/backends/main.rs +++ b/crates/rattler_solve/tests/backends/main.rs @@ -2,8 +2,10 @@ use std::{collections::BTreeMap, str::FromStr, time::Instant}; use chrono::{DateTime, Utc}; use once_cell::sync::Lazy; +use rattler_conda_types::package::{ + ArchiveIdentifier, CondaArchiveType, DistArchiveIdentifier, DistArchiveType, +}; use rattler_conda_types::{ - package::{ArchiveIdentifier, CondaArchiveType, DistArchiveIdentifier, DistArchiveType}, Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, NoArchType, PackageRecord, ParseMatchSpecOptions, ParseStrictness, RepoData, RepoDataRecord, SolverResult, Version, }; diff --git a/crates/rattler_solve/tests/sorting.rs b/crates/rattler_solve/tests/sorting.rs index 23dc93ee55..46ecc930f3 100644 --- a/crates/rattler_solve/tests/sorting.rs +++ b/crates/rattler_solve/tests/sorting.rs @@ -53,6 +53,7 @@ fn create_sorting_snapshot(package_name: &str, strategy: SolveStrategy) -> Strin None, // min_age strategy, Vec::new(), // dependency_overrides + // channel_package_names removed; package presence tracked in RepoData ) .expect("failed to create dependency provider");