Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 164 additions & 17 deletions crates/uv-distribution/src/distribution_database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use uv_distribution_types::{
use uv_extract::hash::Hasher;
use uv_fs::write_atomic;
use uv_install_wheel::validate_and_heal_record;
use uv_normalize::PackageName;
use uv_platform_tags::Tags;
use uv_pypi_types::{HashDigest, HashDigests, PyProjectToml};
use uv_redacted::DisplaySafeUrl;
Expand Down Expand Up @@ -668,10 +669,12 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
async {
let size = size.or_else(|| content_length(&response));

// `DownloadProgress` finalizes the progress bar on drop, including on the
// retry-failure path where we exit early via `?`. See the type's docs.
let progress = self
.reporter
.as_ref()
.map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
.as_deref()
.map(|reporter| DownloadProgress::start(reporter, dist.name(), size));

let reader = response
.bytes_stream()
Expand All @@ -687,9 +690,10 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
.map_err(Error::CacheWrite)?;

let files = match progress {
Some((reporter, progress)) => {
let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
let files = match progress.as_ref() {
Some(progress) => {
let mut reader =
ProgressReader::new(&mut hasher, progress.id(), progress.reporter());
match extension {
WheelExtension::Whl => {
uv_extract::stream::unzip(query_url, &mut reader, temp_dir.path())
Expand Down Expand Up @@ -734,9 +738,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
.await
.map_err(Error::CacheRead)?;

if let Some((reporter, progress)) = progress {
reporter.on_download_complete(dist.name(), progress);
}
drop(progress);

Ok(Archive::new(
id,
Expand Down Expand Up @@ -845,10 +847,12 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
async {
let size = size.or_else(|| content_length(&response));

// `DownloadProgress` finalizes the progress bar on drop, including on the
// retry-failure path where we exit early via `?`. See the type's docs.
let progress = self
.reporter
.as_ref()
.map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
.as_deref()
.map(|reporter| DownloadProgress::start(reporter, dist.name(), size));

let reader = response
.bytes_stream()
Expand All @@ -863,13 +867,16 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
fs_err::File::from_parts(temp_file, self.build_context.cache().root()),
));

match progress {
Some((reporter, progress)) => {
match progress.as_ref() {
Some(progress) => {
// Wrap the reader in a progress reporter. This will report 100% progress
// after the download is complete, even if we still have to unzip and hash
// part of the file.
let mut reader =
ProgressReader::new(reader.compat(), progress, &**reporter);
let mut reader = ProgressReader::new(
reader.compat(),
progress.id(),
progress.reporter(),
);

tokio::io::copy(&mut reader, &mut writer)
.await
Expand Down Expand Up @@ -942,9 +949,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
.await
.map_err(Error::CacheRead)?;

if let Some((reporter, progress)) = progress {
reporter.on_download_complete(dist.name(), progress);
}
drop(progress);

Ok(Archive::new(id, hashes, filename.clone()))
}
Expand Down Expand Up @@ -1271,6 +1276,46 @@ fn content_length(response: &reqwest::Response) -> Option<u64> {
.and_then(|val| val.parse::<u64>().ok())
}

/// A RAII guard for a wheel download progress bar.
///
/// Calls [`Reporter::on_download_start`] on construction and
/// [`Reporter::on_download_complete`] on drop, guaranteeing that every bar
/// started is also finalized. This matters on the retry path: without the
/// guard, a download attempt that fails mid-stream returns early via `?` and
/// never reaches the explicit `on_download_complete` call, leaving a stale
/// progress bar on the screen. The outer retry loop then starts a fresh bar
/// for the next attempt, so long-running retries stack up ghost bars (see
/// <https://github.com/astral-sh/uv/issues/19110>).
struct DownloadProgress<'a> {
reporter: &'a dyn Reporter,
name: &'a PackageName,
id: usize,
}

impl<'a> DownloadProgress<'a> {
/// Start a download progress bar.
fn start(reporter: &'a dyn Reporter, name: &'a PackageName, size: Option<u64>) -> Self {
let id = reporter.on_download_start(name, size);
Self { reporter, name, id }
}

/// The progress bar id, for reporting incremental progress.
fn id(&self) -> usize {
self.id
}

/// The underlying reporter, for passing into [`ProgressReader`].
fn reporter(&self) -> &'a dyn Reporter {
self.reporter
}
}

impl Drop for DownloadProgress<'_> {
fn drop(&mut self) {
self.reporter.on_download_complete(self.name, self.id);
}
}

/// An asynchronous reader that reports progress as bytes are read.
struct ProgressReader<'a, R> {
reader: R,
Expand Down Expand Up @@ -1447,8 +1492,110 @@ fn add_tar_zst_extension(mut url: DisplaySafeUrl) -> DisplaySafeUrl {

#[cfg(test)]
mod tests {
use std::str::FromStr;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};

use super::*;

/// A [`Reporter`] that counts `on_download_start` / `on_download_complete`
/// calls. Other callbacks are no-ops since only the download lifecycle
/// matters for [`DownloadProgress`].
#[derive(Default)]
struct CountingReporter {
started: AtomicUsize,
completed: AtomicUsize,
next_id: AtomicUsize,
events: Mutex<Vec<(String, String, usize)>>,
}

impl Reporter for CountingReporter {
fn on_build_start(&self, _source: &BuildableSource) -> usize {
0
}
fn on_build_complete(&self, _source: &BuildableSource, _id: usize) {}
fn on_checkout_start(&self, _url: &DisplaySafeUrl, _rev: &str) -> usize {
0
}
fn on_checkout_complete(&self, _url: &DisplaySafeUrl, _rev: &str, _id: usize) {}

fn on_download_start(&self, name: &PackageName, _size: Option<u64>) -> usize {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
self.started.fetch_add(1, Ordering::SeqCst);
self.events
.lock()
.unwrap()
.push(("start".into(), name.to_string(), id));
id
}

fn on_download_progress(&self, _id: usize, _inc: u64) {}

fn on_download_complete(&self, name: &PackageName, id: usize) {
self.completed.fetch_add(1, Ordering::SeqCst);
self.events
.lock()
.unwrap()
.push(("complete".into(), name.to_string(), id));
}
}

#[test]
fn download_progress_guard_completes_on_drop() {
let reporter = CountingReporter::default();
let name = PackageName::from_str("flask").unwrap();
{
let guard = DownloadProgress::start(&reporter, &name, Some(100));
assert_eq!(guard.id(), 0);
assert_eq!(reporter.started.load(Ordering::SeqCst), 1);
assert_eq!(reporter.completed.load(Ordering::SeqCst), 0);
}
assert_eq!(reporter.completed.load(Ordering::SeqCst), 1);
}

/// Simulates the download closure being invoked once, returning `Err`
/// for a transient failure and `Ok` for success. Mirrors the `?` early
/// return in `stream_wheel` / `download_wheel`.
fn simulate_download_attempt(
reporter: &CountingReporter,
name: &PackageName,
succeed: bool,
) -> Result<(), &'static str> {
let _progress = DownloadProgress::start(reporter, name, None);
if succeed { Ok(()) } else { Err("network") }
}

/// Regression test for <https://github.com/astral-sh/uv/issues/19110>.
///
/// The outer HTTP retry loop reruns the download closure on transient
/// network failures. Each attempt must leave the progress-bar count
/// balanced, otherwise the previous attempt's bar lingers on screen and
/// the next attempt draws a fresh one, stacking "ghost" bars.
#[test]
fn download_progress_guard_balances_start_and_complete_across_retries() {
let reporter = CountingReporter::default();
let name = PackageName::from_str("flask").unwrap();

// Two transient failures (early `?` return) followed by a success.
assert!(simulate_download_attempt(&reporter, &name, false).is_err());
assert!(simulate_download_attempt(&reporter, &name, false).is_err());
assert!(simulate_download_attempt(&reporter, &name, true).is_ok());

assert_eq!(reporter.started.load(Ordering::SeqCst), 3);
assert_eq!(reporter.completed.load(Ordering::SeqCst), 3);

// Every `start` event is immediately followed by a `complete` event
// with the matching id.
let events = reporter.events.lock().unwrap();
assert_eq!(events.len(), 6);
for pair in events.chunks(2) {
assert_eq!(pair[0].0, "start");
assert_eq!(pair[1].0, "complete");
assert_eq!(pair[0].1, pair[1].1);
assert_eq!(pair[0].2, pair[1].2);
}
}

#[test]
fn test_add_tar_zst_extension() {
let url =
Expand Down
Loading