From cdebb999dd4bedda3f273d4e3bb3bc8d7c852c9a Mon Sep 17 00:00:00 2001 From: Oshadha Gunawardena Date: Wed, 22 Apr 2026 13:31:58 +0530 Subject: [PATCH] Stop stacking progress bars when a wheel download is retried --- .../src/distribution_database.rs | 181 ++++++++++++++++-- 1 file changed, 164 insertions(+), 17 deletions(-) diff --git a/crates/uv-distribution/src/distribution_database.rs b/crates/uv-distribution/src/distribution_database.rs index cff84342fd0d6..79ea05ff9280a 100644 --- a/crates/uv-distribution/src/distribution_database.rs +++ b/crates/uv-distribution/src/distribution_database.rs @@ -25,6 +25,7 @@ use uv_distribution_types::{ use uv_extract::hash::Hasher; use uv_fs::write_atomic; use uv_install_wheel::validate_and_heal_record; +use uv_normalize::PackageName; use uv_platform_tags::Tags; use uv_pypi_types::{HashDigest, HashDigests, PyProjectToml}; use uv_redacted::DisplaySafeUrl; @@ -668,10 +669,12 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { async { let size = size.or_else(|| content_length(&response)); + // `DownloadProgress` finalizes the progress bar on drop, including on the + // retry-failure path where we exit early via `?`. See the type's docs. let progress = self .reporter - .as_ref() - .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size))); + .as_deref() + .map(|reporter| DownloadProgress::start(reporter, dist.name(), size)); let reader = response .bytes_stream() @@ -687,9 +690,10 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { let temp_dir = tempfile::tempdir_in(self.build_context.cache().root()) .map_err(Error::CacheWrite)?; - let files = match progress { - Some((reporter, progress)) => { - let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter); + let files = match progress.as_ref() { + Some(progress) => { + let mut reader = + ProgressReader::new(&mut hasher, progress.id(), progress.reporter()); match extension { WheelExtension::Whl => { uv_extract::stream::unzip(query_url, &mut reader, temp_dir.path()) @@ -734,9 +738,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { .await .map_err(Error::CacheRead)?; - if let Some((reporter, progress)) = progress { - reporter.on_download_complete(dist.name(), progress); - } + drop(progress); Ok(Archive::new( id, @@ -845,10 +847,12 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { async { let size = size.or_else(|| content_length(&response)); + // `DownloadProgress` finalizes the progress bar on drop, including on the + // retry-failure path where we exit early via `?`. See the type's docs. let progress = self .reporter - .as_ref() - .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size))); + .as_deref() + .map(|reporter| DownloadProgress::start(reporter, dist.name(), size)); let reader = response .bytes_stream() @@ -863,13 +867,16 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { fs_err::File::from_parts(temp_file, self.build_context.cache().root()), )); - match progress { - Some((reporter, progress)) => { + match progress.as_ref() { + Some(progress) => { // Wrap the reader in a progress reporter. This will report 100% progress // after the download is complete, even if we still have to unzip and hash // part of the file. - let mut reader = - ProgressReader::new(reader.compat(), progress, &**reporter); + let mut reader = ProgressReader::new( + reader.compat(), + progress.id(), + progress.reporter(), + ); tokio::io::copy(&mut reader, &mut writer) .await @@ -942,9 +949,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { .await .map_err(Error::CacheRead)?; - if let Some((reporter, progress)) = progress { - reporter.on_download_complete(dist.name(), progress); - } + drop(progress); Ok(Archive::new(id, hashes, filename.clone())) } @@ -1271,6 +1276,46 @@ fn content_length(response: &reqwest::Response) -> Option { .and_then(|val| val.parse::().ok()) } +/// A RAII guard for a wheel download progress bar. +/// +/// Calls [`Reporter::on_download_start`] on construction and +/// [`Reporter::on_download_complete`] on drop, guaranteeing that every bar +/// started is also finalized. This matters on the retry path: without the +/// guard, a download attempt that fails mid-stream returns early via `?` and +/// never reaches the explicit `on_download_complete` call, leaving a stale +/// progress bar on the screen. The outer retry loop then starts a fresh bar +/// for the next attempt, so long-running retries stack up ghost bars (see +/// ). +struct DownloadProgress<'a> { + reporter: &'a dyn Reporter, + name: &'a PackageName, + id: usize, +} + +impl<'a> DownloadProgress<'a> { + /// Start a download progress bar. + fn start(reporter: &'a dyn Reporter, name: &'a PackageName, size: Option) -> Self { + let id = reporter.on_download_start(name, size); + Self { reporter, name, id } + } + + /// The progress bar id, for reporting incremental progress. + fn id(&self) -> usize { + self.id + } + + /// The underlying reporter, for passing into [`ProgressReader`]. + fn reporter(&self) -> &'a dyn Reporter { + self.reporter + } +} + +impl Drop for DownloadProgress<'_> { + fn drop(&mut self) { + self.reporter.on_download_complete(self.name, self.id); + } +} + /// An asynchronous reader that reports progress as bytes are read. struct ProgressReader<'a, R> { reader: R, @@ -1447,8 +1492,110 @@ fn add_tar_zst_extension(mut url: DisplaySafeUrl) -> DisplaySafeUrl { #[cfg(test)] mod tests { + use std::str::FromStr; + use std::sync::Mutex; + use std::sync::atomic::{AtomicUsize, Ordering}; + use super::*; + /// A [`Reporter`] that counts `on_download_start` / `on_download_complete` + /// calls. Other callbacks are no-ops since only the download lifecycle + /// matters for [`DownloadProgress`]. + #[derive(Default)] + struct CountingReporter { + started: AtomicUsize, + completed: AtomicUsize, + next_id: AtomicUsize, + events: Mutex>, + } + + impl Reporter for CountingReporter { + fn on_build_start(&self, _source: &BuildableSource) -> usize { + 0 + } + fn on_build_complete(&self, _source: &BuildableSource, _id: usize) {} + fn on_checkout_start(&self, _url: &DisplaySafeUrl, _rev: &str) -> usize { + 0 + } + fn on_checkout_complete(&self, _url: &DisplaySafeUrl, _rev: &str, _id: usize) {} + + fn on_download_start(&self, name: &PackageName, _size: Option) -> usize { + let id = self.next_id.fetch_add(1, Ordering::SeqCst); + self.started.fetch_add(1, Ordering::SeqCst); + self.events + .lock() + .unwrap() + .push(("start".into(), name.to_string(), id)); + id + } + + fn on_download_progress(&self, _id: usize, _inc: u64) {} + + fn on_download_complete(&self, name: &PackageName, id: usize) { + self.completed.fetch_add(1, Ordering::SeqCst); + self.events + .lock() + .unwrap() + .push(("complete".into(), name.to_string(), id)); + } + } + + #[test] + fn download_progress_guard_completes_on_drop() { + let reporter = CountingReporter::default(); + let name = PackageName::from_str("flask").unwrap(); + { + let guard = DownloadProgress::start(&reporter, &name, Some(100)); + assert_eq!(guard.id(), 0); + assert_eq!(reporter.started.load(Ordering::SeqCst), 1); + assert_eq!(reporter.completed.load(Ordering::SeqCst), 0); + } + assert_eq!(reporter.completed.load(Ordering::SeqCst), 1); + } + + /// Simulates the download closure being invoked once, returning `Err` + /// for a transient failure and `Ok` for success. Mirrors the `?` early + /// return in `stream_wheel` / `download_wheel`. + fn simulate_download_attempt( + reporter: &CountingReporter, + name: &PackageName, + succeed: bool, + ) -> Result<(), &'static str> { + let _progress = DownloadProgress::start(reporter, name, None); + if succeed { Ok(()) } else { Err("network") } + } + + /// Regression test for . + /// + /// The outer HTTP retry loop reruns the download closure on transient + /// network failures. Each attempt must leave the progress-bar count + /// balanced, otherwise the previous attempt's bar lingers on screen and + /// the next attempt draws a fresh one, stacking "ghost" bars. + #[test] + fn download_progress_guard_balances_start_and_complete_across_retries() { + let reporter = CountingReporter::default(); + let name = PackageName::from_str("flask").unwrap(); + + // Two transient failures (early `?` return) followed by a success. + assert!(simulate_download_attempt(&reporter, &name, false).is_err()); + assert!(simulate_download_attempt(&reporter, &name, false).is_err()); + assert!(simulate_download_attempt(&reporter, &name, true).is_ok()); + + assert_eq!(reporter.started.load(Ordering::SeqCst), 3); + assert_eq!(reporter.completed.load(Ordering::SeqCst), 3); + + // Every `start` event is immediately followed by a `complete` event + // with the matching id. + let events = reporter.events.lock().unwrap(); + assert_eq!(events.len(), 6); + for pair in events.chunks(2) { + assert_eq!(pair[0].0, "start"); + assert_eq!(pair[1].0, "complete"); + assert_eq!(pair[0].1, pair[1].1); + assert_eq!(pair[0].2, pair[1].2); + } + } + #[test] fn test_add_tar_zst_extension() { let url =