From c852e509208267034ce8550e46a79550a5c9921b Mon Sep 17 00:00:00 2001 From: Wolf Vollprecht Date: Fri, 6 Mar 2026 08:22:43 +0100 Subject: [PATCH] add backoff-retry to shards downloading, 429 handling --- .../src/gateway/sharded_subdir/mod.rs | 156 +++++++++++++++++- .../src/gateway/sharded_subdir/tokio/mod.rs | 110 +++++++++--- .../src/gateway/sharded_subdir/wasm/mod.rs | 108 +++++++++--- 3 files changed, 319 insertions(+), 55 deletions(-) diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs index 47f365583..1f36af21b 100644 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs @@ -15,6 +15,34 @@ use crate::{ GatewayError, }; +/// Returns `true` if the error is transient and the request should be retried. +/// This includes 429 (Too Many Requests), 5xx server errors, and connection +/// errors. +fn is_transient_error(err: &GatewayError) -> bool { + match err { + GatewayError::FetchRepoDataError(fetch_err) => match fetch_err { + FetchRepoDataError::HttpError(err) => is_transient_reqwest_middleware_error(err), + _ => false, + }, + GatewayError::ReqwestError(err) => is_transient_reqwest_error(err), + _ => false, + } +} + +fn is_transient_reqwest_middleware_error(err: &reqwest_middleware::Error) -> bool { + match err { + reqwest_middleware::Error::Reqwest(err) => is_transient_reqwest_error(err), + _ => false, + } +} + +fn is_transient_reqwest_error(err: &reqwest::Error) -> bool { + err.status() + .is_some_and(|s| s == http::StatusCode::TOO_MANY_REQUESTS || s.is_server_error()) + || err.is_connect() + || err.is_timeout() +} + cfg_if! { if #[cfg(target_arch = "wasm32")] { mod wasm; @@ -168,23 +196,34 @@ mod tests { use crate::gateway::subdir::SubdirClient; use axum::{ body::Body, + extract::State, http::{Response, StatusCode}, routing::get, Router, }; - use rattler_conda_types::{Channel, ShardedRepodata, ShardedSubdirInfo}; + use rattler_conda_types::{Channel, Shard, ShardedRepodata, ShardedSubdirInfo}; use rattler_digest::{parse_digest_from_hex, Sha256}; use std::future::IntoFuture; use std::net::SocketAddr; + use std::sync::atomic::{AtomicU32, Ordering}; + use std::sync::Arc; use tokio::sync::oneshot; use url::Url; use super::ShardedSubdir; + /// Shared state for the mock server to track request counts. + #[derive(Clone)] + struct MockState { + shard_response: MockShardResponse, + request_count: Arc, + } + /// A mock server that serves a sharded repodata index but returns /// configurable responses for shard requests. struct MockShardedServer { local_addr: SocketAddr, + request_count: Arc, _shutdown_sender: oneshot::Sender<()>, } @@ -213,6 +252,12 @@ mod tests { let index_bytes = rmp_serde::to_vec(&sharded_index).unwrap(); let compressed_index = zstd::encode_all(index_bytes.as_slice(), 3).unwrap(); + let request_count = Arc::new(AtomicU32::new(0)); + let state = MockState { + shard_response, + request_count: request_count.clone(), + }; + let app = Router::new() .route( "/linux-64/repodata_shards.msgpack.zst", @@ -226,8 +271,9 @@ mod tests { ) .route( "/linux-64/shards/{shard_file}", - get(move || async move { - match shard_response { + get(|State(state): State| async move { + let count = state.request_count.fetch_add(1, Ordering::SeqCst); + match state.shard_response { MockShardResponse::Empty => Response::builder() .status(StatusCode::OK) .body(Body::empty()) @@ -239,9 +285,33 @@ mod tests { .body(Body::from(vec![0x28, 0xb5, 0x2f, 0xfd])) .unwrap() } + MockShardResponse::TooManyRequests { fail_count } => { + if count < fail_count { + Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .body(Body::empty()) + .unwrap() + } else { + // Return a valid shard + let shard = Shard { + packages: Default::default(), + conda_packages: Default::default(), + removed: Default::default(), + experimental_v3: Default::default(), + }; + let shard_bytes = rmp_serde::to_vec(&shard).unwrap(); + let compressed = + zstd::encode_all(shard_bytes.as_slice(), 3).unwrap(); + Response::builder() + .status(StatusCode::OK) + .body(Body::from(compressed)) + .unwrap() + } + } } }), - ); + ) + .with_state(state); let addr = SocketAddr::new([127, 0, 0, 1].into(), 0); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); @@ -258,6 +328,7 @@ mod tests { Self { local_addr, + request_count, _shutdown_sender: tx, } } @@ -269,12 +340,20 @@ mod tests { fn channel(&self) -> Channel { Channel::from_url(self.url()) } + + fn request_count(&self) -> u32 { + self.request_count.load(Ordering::SeqCst) + } } #[derive(Clone, Copy)] enum MockShardResponse { Empty, Truncated, + /// Return 429 for the first `fail_count` requests, then succeed. + TooManyRequests { + fail_count: u32, + }, } #[tokio::test] @@ -346,4 +425,73 @@ mod tests { insta::assert_snapshot!("truncated_shard_response_error", err_string); } + + #[tokio::test] + async fn test_429_retry_succeeds() { + // Server returns 429 twice, then succeeds on the 3rd request + let server = + MockShardedServer::new(MockShardResponse::TooManyRequests { fail_count: 2 }).await; + let channel = server.channel(); + let cache_dir = tempfile::tempdir().unwrap(); + + let client = rattler_networking::LazyClient::default(); + + let subdir = ShardedSubdir::new( + channel, + "linux-64".to_string(), + client, + cache_dir.path().to_path_buf(), + CacheAction::NoCache, + None, + None, + ) + .await + .unwrap(); + + let package_name = "test-package".parse().unwrap(); + let result = subdir.fetch_package_records(&package_name, None).await; + + // Should succeed after retries + assert!( + result.is_ok(), + "expected success after retries, got: {result:?}" + ); + // Should have made 3 requests (2 failures + 1 success) + assert_eq!(server.request_count(), 3); + } + + #[tokio::test] + async fn test_429_retry_exhausted() { + // Server always returns 429 (more failures than retries allow) + let server = + MockShardedServer::new(MockShardResponse::TooManyRequests { fail_count: 100 }).await; + let channel = server.channel(); + let cache_dir = tempfile::tempdir().unwrap(); + + let client = rattler_networking::LazyClient::default(); + + let subdir = ShardedSubdir::new( + channel, + "linux-64".to_string(), + client, + cache_dir.path().to_path_buf(), + CacheAction::NoCache, + None, + None, + ) + .await + .unwrap(); + + let package_name = "test-package".parse().unwrap(); + let result = subdir.fetch_package_records(&package_name, None).await; + + // Should fail after exhausting retries + let err = result.expect_err("should fail after exhausting retries"); + assert!( + err.to_string().contains("429"), + "error should mention 429: {err}" + ); + // default_retry_policy retries 3 times, so 4 total requests (1 initial + 3 retries) + assert_eq!(server.request_count(), 4); + } } diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/tokio/mod.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/tokio/mod.rs index cda3e83dd..03ab285b1 100644 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/tokio/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/tokio/mod.rs @@ -8,7 +8,7 @@ use std::{ use rattler_conda_types::Platform; -use super::{add_trailing_slash, decode_zst_bytes_async, parse_records}; +use super::{add_trailing_slash, decode_zst_bytes_async, is_transient_error, parse_records}; use crate::{ fetch::{CacheAction, FetchRepoDataError}, gateway::{ @@ -22,7 +22,8 @@ use fs_err::tokio as tokio_fs; use futures::future::OptionFuture; use http::{header::CACHE_CONTROL, HeaderValue, StatusCode}; use rattler_conda_types::{Channel, PackageName, ShardedRepodata}; -use rattler_networking::LazyClient; +use rattler_networking::{retry_policies::default_retry_policy, LazyClient}; +use retry_policies::{RetryDecision, RetryPolicy}; use simple_spawn_blocking::tokio::run_blocking_task; use url::Url; @@ -38,6 +39,9 @@ pub struct ShardedSubdir { concurrent_requests_semaphore: Option>, cache_dir: PathBuf, cache_action: CacheAction, + /// Shared backoff deadline. When a 429 is received, this is set so that + /// all concurrent requests to the same host wait before retrying. + backoff_until: Arc>>, } impl ShardedSubdir { @@ -113,6 +117,7 @@ impl ShardedSubdir { cache_dir, cache_action, concurrent_requests_semaphore, + backoff_until: Arc::default(), }) } @@ -210,42 +215,95 @@ impl SubdirClient for ShardedSubdir { .join(&format!("{shard:x}.msgpack.zst")) .expect("invalid shard url"); - let shard_request = self - .client - .client() - .get(shard_url.clone()) - .header(CACHE_CONTROL, HeaderValue::from_static("no-store")) - .build() - .expect("failed to build shard request"); + let retry_policy = default_retry_policy(); + let mut retry_count = 0u32; + + let shard_bytes = loop { + // If another request recently received a 429, wait for the shared + // backoff deadline before sending a new request. + { + let deadline = *self.backoff_until.lock().await; + if let Some(deadline) = deadline { + tokio::time::sleep_until(deadline).await; + } + } + + let shard_request = self + .client + .client() + .get(shard_url.clone()) + .header(CACHE_CONTROL, HeaderValue::from_static("no-store")) + .build() + .expect("failed to build shard request"); - let shard_bytes = { let _request_permit = OptionFuture::from( self.concurrent_requests_semaphore .as_deref() .map(tokio::sync::Semaphore::acquire), ) .await; + + let request_start = std::time::SystemTime::now(); let reporter = reporter .and_then(Reporter::download_reporter) .map(|r| (r, r.on_download_start(&shard_url))); - let shard_response = self - .client - .client() - .execute(shard_request) - .await - .and_then(|r| r.error_for_status().map_err(Into::into)) - .map_err(FetchRepoDataError::from)?; - - let bytes = shard_response - .bytes_with_progress(reporter) - .await - .map_err(FetchRepoDataError::from)?; - - if let Some((reporter, index)) = reporter { - reporter.on_download_complete(&shard_url, index); + + let result = async { + let shard_response = self + .client + .client() + .execute(shard_request) + .await + .and_then(|r| r.error_for_status().map_err(Into::into)) + .map_err(FetchRepoDataError::from)?; + + let bytes = shard_response + .bytes_with_progress(reporter) + .await + .map_err(FetchRepoDataError::from)?; + + if let Some((reporter, index)) = reporter { + reporter.on_download_complete(&shard_url, index); + } + + Ok::<_, GatewayError>(bytes) } + .await; - bytes + match result { + Ok(bytes) => break bytes, + Err(err) if is_transient_error(&err) => { + match retry_policy.should_retry(request_start, retry_count) { + RetryDecision::Retry { execute_after } => { + let sleep_duration = execute_after + .duration_since(std::time::SystemTime::now()) + .unwrap_or_default(); + + // Set the shared backoff deadline so other concurrent + // requests also wait instead of hammering the server. + { + let new_deadline = tokio::time::Instant::now() + sleep_duration; + let mut backoff = self.backoff_until.lock().await; + // Only push the deadline forward, never backward. + if backoff.map_or(true, |d| new_deadline > d) { + *backoff = Some(new_deadline); + } + } + + tracing::warn!( + "transient error fetching shard {}: {}. Retry #{}, sleeping {sleep_duration:?}...", + shard_url, + err, + retry_count + 1, + ); + tokio::time::sleep(sleep_duration).await; + retry_count += 1; + } + RetryDecision::DoNotRetry => return Err(err), + } + } + Err(err) => return Err(err), + } }; let shard_bytes = decode_zst_bytes_async(shard_bytes, shard_url).await?; diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/wasm/mod.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/wasm/mod.rs index 43ae35cb7..0cf5ecc03 100644 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/wasm/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/wasm/mod.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use futures::future::OptionFuture; use http::StatusCode; use rattler_conda_types::{Channel, PackageName, ShardedRepodata}; -use rattler_networking::LazyClient; +use rattler_networking::{retry_policies::default_retry_policy, LazyClient}; +use retry_policies::{RetryDecision, RetryPolicy}; use url::Url; -use super::add_trailing_slash; +use super::{add_trailing_slash, is_transient_error}; mod index; @@ -28,6 +29,9 @@ pub struct ShardedSubdir { package_base_url: Url, sharded_repodata: ShardedRepodata, concurrent_requests_semaphore: Option>, + /// Shared backoff deadline. When a 429 is received, this is set so that + /// all concurrent requests to the same host wait before retrying. + backoff_until: Arc>>, } impl ShardedSubdir { @@ -91,6 +95,7 @@ impl ShardedSubdir { package_base_url: add_trailing_slash(&package_base_url).into_owned(), sharded_repodata, concurrent_requests_semaphore, + backoff_until: Arc::default(), }) } } @@ -113,41 +118,94 @@ impl SubdirClient for ShardedSubdir { .join(&format!("{shard:x}.msgpack.zst")) .expect("invalid shard url"); - let shard_request = self - .client - .client() - .get(shard_url.clone()) - .build() - .expect("failed to build shard request"); + let retry_policy = default_retry_policy(); + let mut retry_count = 0u32; + + let shard_bytes = loop { + // If another request recently received a 429, wait for the shared + // backoff deadline before sending a new request. + { + let deadline = *self.backoff_until.lock().await; + if let Some(deadline) = deadline { + wasmtimer::tokio::sleep_until(deadline).await; + } + } + + let shard_request = self + .client + .client() + .get(shard_url.clone()) + .build() + .expect("failed to build shard request"); - let shard_bytes = { let _request_permit = OptionFuture::from( self.concurrent_requests_semaphore .as_deref() .map(tokio::sync::Semaphore::acquire), ) .await; + + let request_start = std::time::SystemTime::now(); let reporter = reporter .and_then(Reporter::download_reporter) .map(|r| (r, r.on_download_start(&shard_url))); - let shard_response = self - .client - .client() - .execute(shard_request) - .await - .and_then(|r| r.error_for_status().map_err(Into::into)) - .map_err(FetchRepoDataError::from)?; - - let bytes = shard_response - .bytes_with_progress(reporter) - .await - .map_err(FetchRepoDataError::from)?; - - if let Some((reporter, index)) = reporter { - reporter.on_download_complete(&shard_url, index); + + let result = async { + let shard_response = self + .client + .client() + .execute(shard_request) + .await + .and_then(|r| r.error_for_status().map_err(Into::into)) + .map_err(FetchRepoDataError::from)?; + + let bytes = shard_response + .bytes_with_progress(reporter) + .await + .map_err(FetchRepoDataError::from)?; + + if let Some((reporter, index)) = reporter { + reporter.on_download_complete(&shard_url, index); + } + + Ok::<_, GatewayError>(bytes) } + .await; - bytes + match result { + Ok(bytes) => break bytes, + Err(err) if is_transient_error(&err) => { + match retry_policy.should_retry(request_start, retry_count) { + RetryDecision::Retry { execute_after } => { + let sleep_duration = execute_after + .duration_since(std::time::SystemTime::now()) + .unwrap_or_default(); + + // Set the shared backoff deadline so other concurrent + // requests also wait instead of hammering the server. + { + let new_deadline = + wasmtimer::tokio::Instant::now() + sleep_duration; + let mut backoff = self.backoff_until.lock().await; + if backoff.map_or(true, |d| new_deadline > d) { + *backoff = Some(new_deadline); + } + } + + tracing::warn!( + "transient error fetching shard {}: {}. Retry #{}, sleeping {sleep_duration:?}...", + shard_url, + err, + retry_count + 1, + ); + wasmtimer::tokio::sleep(sleep_duration).await; + retry_count += 1; + } + RetryDecision::DoNotRetry => return Err(err), + } + } + Err(err) => return Err(err), + } }; let shard_bytes = decode_zst_bytes_async(shard_bytes, shard_url).await?;