Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 152 additions & 4 deletions crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,34 @@ use crate::{
GatewayError,
};

/// Returns `true` if the error is transient and the request should be retried.
/// This includes 429 (Too Many Requests), 5xx server errors, and connection
/// errors.
/// Returns `true` if the error is transient and the request should be retried.
/// Transient errors are 429 (Too Many Requests), 5xx server errors, and
/// connection/timeout failures.
fn is_transient_error(err: &GatewayError) -> bool {
    match err {
        // An HTTP error wrapped inside a repodata fetch failure.
        GatewayError::FetchRepoDataError(FetchRepoDataError::HttpError(http_err)) => {
            is_transient_reqwest_middleware_error(http_err)
        }
        // A bare reqwest error.
        GatewayError::ReqwestError(reqwest_err) => is_transient_reqwest_error(reqwest_err),
        // Everything else (parse errors, cancellation, ...) is permanent.
        _ => false,
    }
}

fn is_transient_reqwest_middleware_error(err: &reqwest_middleware::Error) -> bool {
match err {
reqwest_middleware::Error::Reqwest(err) => is_transient_reqwest_error(err),
_ => false,
}
}

/// Returns `true` for `reqwest` errors worth retrying: a 429 or 5xx status,
/// a connection failure, or a timeout.
fn is_transient_reqwest_error(err: &reqwest::Error) -> bool {
    // A response status of 429 or any 5xx counts as transient.
    let retryable_status = matches!(
        err.status(),
        Some(status) if status == http::StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
    );
    retryable_status || err.is_connect() || err.is_timeout()
}

cfg_if! {
if #[cfg(target_arch = "wasm32")] {
mod wasm;
Expand Down Expand Up @@ -168,23 +196,34 @@ mod tests {
use crate::gateway::subdir::SubdirClient;
use axum::{
body::Body,
extract::State,
http::{Response, StatusCode},
routing::get,
Router,
};
use rattler_conda_types::{Channel, ShardedRepodata, ShardedSubdirInfo};
use rattler_conda_types::{Channel, Shard, ShardedRepodata, ShardedSubdirInfo};
use rattler_digest::{parse_digest_from_hex, Sha256};
use std::future::IntoFuture;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::sync::oneshot;
use url::Url;

use super::ShardedSubdir;

/// Shared state for the mock server to track request counts.
///
/// Cloned into the axum router via `with_state`; the `Arc` keeps the counter
/// shared between the router's handlers and the owning `MockShardedServer`.
#[derive(Clone)]
struct MockState {
    // How the shard endpoint should respond to incoming requests.
    shard_response: MockShardResponse,
    // Incremented once per shard request so tests can assert how many
    // requests the client actually made.
    request_count: Arc<AtomicU32>,
}

/// A mock server that serves a sharded repodata index but returns
/// configurable responses for shard requests.
struct MockShardedServer {
    // The address the server is actually bound to (bound to port 0, so the
    // OS picks a free port).
    local_addr: SocketAddr,
    // Shared with the router's `MockState`; counts shard requests served.
    request_count: Arc<AtomicU32>,
    // Held only to keep the shutdown channel open; dropping the server
    // drops the sender and signals the server task to stop.
    _shutdown_sender: oneshot::Sender<()>,
}

Expand Down Expand Up @@ -213,6 +252,12 @@ mod tests {
let index_bytes = rmp_serde::to_vec(&sharded_index).unwrap();
let compressed_index = zstd::encode_all(index_bytes.as_slice(), 3).unwrap();

let request_count = Arc::new(AtomicU32::new(0));
let state = MockState {
shard_response,
request_count: request_count.clone(),
};

let app = Router::new()
.route(
"/linux-64/repodata_shards.msgpack.zst",
Expand All @@ -226,8 +271,9 @@ mod tests {
)
.route(
"/linux-64/shards/{shard_file}",
get(move || async move {
match shard_response {
get(|State(state): State<MockState>| async move {
let count = state.request_count.fetch_add(1, Ordering::SeqCst);
match state.shard_response {
MockShardResponse::Empty => Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
Expand All @@ -239,9 +285,33 @@ mod tests {
.body(Body::from(vec![0x28, 0xb5, 0x2f, 0xfd]))
.unwrap()
}
MockShardResponse::TooManyRequests { fail_count } => {
if count < fail_count {
Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Body::empty())
.unwrap()
} else {
// Return a valid shard
let shard = Shard {
packages: Default::default(),
conda_packages: Default::default(),
removed: Default::default(),
experimental_v3: Default::default(),
};
let shard_bytes = rmp_serde::to_vec(&shard).unwrap();
let compressed =
zstd::encode_all(shard_bytes.as_slice(), 3).unwrap();
Response::builder()
.status(StatusCode::OK)
.body(Body::from(compressed))
.unwrap()
}
}
}
}),
);
)
.with_state(state);

let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
Expand All @@ -258,6 +328,7 @@ mod tests {

Self {
local_addr,
request_count,
_shutdown_sender: tx,
}
}
Expand All @@ -269,12 +340,20 @@ mod tests {
/// Returns a [`Channel`] pointing at this mock server.
fn channel(&self) -> Channel {
    let base_url = self.url();
    Channel::from_url(base_url)
}

fn request_count(&self) -> u32 {
self.request_count.load(Ordering::SeqCst)
}
}

/// Controls how the mock server's shard endpoint responds.
#[derive(Clone, Copy)]
enum MockShardResponse {
    /// Respond 200 OK with an empty body (not a valid zstd stream).
    Empty,
    /// Respond 200 OK with only a few leading bytes of a zstd stream,
    /// simulating a truncated download.
    Truncated,
    /// Return 429 for the first `fail_count` requests, then succeed.
    TooManyRequests {
        fail_count: u32,
    },
}

#[tokio::test]
Expand Down Expand Up @@ -346,4 +425,73 @@ mod tests {

insta::assert_snapshot!("truncated_shard_response_error", err_string);
}

#[tokio::test]
async fn test_429_retry_succeeds() {
    // The mock server answers the first two shard requests with 429 and
    // serves a valid shard on the third attempt.
    let server =
        MockShardedServer::new(MockShardResponse::TooManyRequests { fail_count: 2 }).await;
    let cache_dir = tempfile::tempdir().unwrap();

    let subdir = ShardedSubdir::new(
        server.channel(),
        "linux-64".to_string(),
        rattler_networking::LazyClient::default(),
        cache_dir.path().to_path_buf(),
        CacheAction::NoCache,
        None,
        None,
    )
    .await
    .unwrap();

    let package_name = "test-package".parse().unwrap();
    let result = subdir.fetch_package_records(&package_name, None).await;

    // The transient 429s must be retried until the shard is fetched.
    assert!(
        result.is_ok(),
        "expected success after retries, got: {result:?}"
    );
    // Two 429 responses plus the final successful response.
    assert_eq!(server.request_count(), 3);
}

#[tokio::test]
async fn test_429_retry_exhausted() {
    // `fail_count` far exceeds the retry budget, so the server effectively
    // always answers 429.
    let server =
        MockShardedServer::new(MockShardResponse::TooManyRequests { fail_count: 100 }).await;
    let cache_dir = tempfile::tempdir().unwrap();

    let subdir = ShardedSubdir::new(
        server.channel(),
        "linux-64".to_string(),
        rattler_networking::LazyClient::default(),
        cache_dir.path().to_path_buf(),
        CacheAction::NoCache,
        None,
        None,
    )
    .await
    .unwrap();

    let package_name = "test-package".parse().unwrap();
    let err = subdir
        .fetch_package_records(&package_name, None)
        .await
        .expect_err("should fail after exhausting retries");

    // The surfaced error should carry the HTTP status that caused it.
    assert!(
        err.to_string().contains("429"),
        "error should mention 429: {err}"
    );
    // default_retry_policy retries 3 times, so 4 total requests (1 initial + 3 retries)
    assert_eq!(server.request_count(), 4);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{

use rattler_conda_types::Platform;

use super::{add_trailing_slash, decode_zst_bytes_async, parse_records};
use super::{add_trailing_slash, decode_zst_bytes_async, is_transient_error, parse_records};
use crate::{
fetch::{CacheAction, FetchRepoDataError},
gateway::{
Expand All @@ -22,7 +22,8 @@ use fs_err::tokio as tokio_fs;
use futures::future::OptionFuture;
use http::{header::CACHE_CONTROL, HeaderValue, StatusCode};
use rattler_conda_types::{Channel, PackageName, ShardedRepodata};
use rattler_networking::LazyClient;
use rattler_networking::{retry_policies::default_retry_policy, LazyClient};
use retry_policies::{RetryDecision, RetryPolicy};
use simple_spawn_blocking::tokio::run_blocking_task;
use url::Url;

Expand All @@ -38,6 +39,9 @@ pub struct ShardedSubdir {
concurrent_requests_semaphore: Option<Arc<tokio::sync::Semaphore>>,
cache_dir: PathBuf,
cache_action: CacheAction,
/// Shared backoff deadline. When a 429 is received, this is set so that
/// all concurrent requests to the same host wait before retrying.
backoff_until: Arc<tokio::sync::Mutex<Option<tokio::time::Instant>>>,
}

impl ShardedSubdir {
Expand Down Expand Up @@ -113,6 +117,7 @@ impl ShardedSubdir {
cache_dir,
cache_action,
concurrent_requests_semaphore,
backoff_until: Arc::default(),
})
}

Expand Down Expand Up @@ -210,42 +215,95 @@ impl SubdirClient for ShardedSubdir {
.join(&format!("{shard:x}.msgpack.zst"))
.expect("invalid shard url");

let shard_request = self
.client
.client()
.get(shard_url.clone())
.header(CACHE_CONTROL, HeaderValue::from_static("no-store"))
.build()
.expect("failed to build shard request");
let retry_policy = default_retry_policy();
let mut retry_count = 0u32;

let shard_bytes = loop {
// If another request recently received a 429, wait for the shared
// backoff deadline before sending a new request.
{
let deadline = *self.backoff_until.lock().await;
if let Some(deadline) = deadline {
tokio::time::sleep_until(deadline).await;
}
}

let shard_request = self
.client
.client()
.get(shard_url.clone())
.header(CACHE_CONTROL, HeaderValue::from_static("no-store"))
.build()
.expect("failed to build shard request");

let shard_bytes = {
let _request_permit = OptionFuture::from(
self.concurrent_requests_semaphore
.as_deref()
.map(tokio::sync::Semaphore::acquire),
)
.await;

let request_start = std::time::SystemTime::now();
let reporter = reporter
.and_then(Reporter::download_reporter)
.map(|r| (r, r.on_download_start(&shard_url)));
let shard_response = self
.client
.client()
.execute(shard_request)
.await
.and_then(|r| r.error_for_status().map_err(Into::into))
.map_err(FetchRepoDataError::from)?;

let bytes = shard_response
.bytes_with_progress(reporter)
.await
.map_err(FetchRepoDataError::from)?;

if let Some((reporter, index)) = reporter {
reporter.on_download_complete(&shard_url, index);

let result = async {
let shard_response = self
.client
.client()
.execute(shard_request)
.await
.and_then(|r| r.error_for_status().map_err(Into::into))
.map_err(FetchRepoDataError::from)?;

let bytes = shard_response
.bytes_with_progress(reporter)
.await
.map_err(FetchRepoDataError::from)?;

if let Some((reporter, index)) = reporter {
reporter.on_download_complete(&shard_url, index);
}

Ok::<_, GatewayError>(bytes)
}
.await;

bytes
match result {
Ok(bytes) => break bytes,
Err(err) if is_transient_error(&err) => {
match retry_policy.should_retry(request_start, retry_count) {
RetryDecision::Retry { execute_after } => {
let sleep_duration = execute_after
.duration_since(std::time::SystemTime::now())
.unwrap_or_default();

// Set the shared backoff deadline so other concurrent
// requests also wait instead of hammering the server.
{
let new_deadline = tokio::time::Instant::now() + sleep_duration;
let mut backoff = self.backoff_until.lock().await;
// Only push the deadline forward, never backward.
if backoff.map_or(true, |d| new_deadline > d) {
*backoff = Some(new_deadline);
}
}

tracing::warn!(
"transient error fetching shard {}: {}. Retry #{}, sleeping {sleep_duration:?}...",
shard_url,
err,
retry_count + 1,
);
tokio::time::sleep(sleep_duration).await;
retry_count += 1;
}
RetryDecision::DoNotRetry => return Err(err),
}
}
Err(err) => return Err(err),
}
};

let shard_bytes = decode_zst_bytes_async(shard_bytes, shard_url).await?;
Expand Down
Loading
Loading