
Commit 475fc60

parallel downloads

1 parent e9d21a3 commit 475fc60

File tree

  • crates/network/artifacts/src

1 file changed: crates/network/artifacts/src/lib.rs (277 additions, 1 deletion)
@@ -38,10 +38,16 @@ use aws_sdk_s3::{
 use aws_smithy_async::rt::sleep::default_async_sleep;
 use bytes::Bytes;
 use serde::{de::DeserializeOwned, Serialize};
-use tokio::sync::RwLock;
+use tokio::{sync::RwLock, task::JoinSet};
 use tracing::instrument;
 use url::Url;
 
+/// Chunk size for parallel downloads in bytes (32MB).
+const CHUNK_SIZE: usize = 32 * 1024 * 1024;
+
+/// Default concurrency for parallel downloads.
+const DEFAULT_CONCURRENCY: usize = 32;
+
 /// S3 Clients that are cached across the entire application.
 #[allow(clippy::type_complexity)]
 static S3_CLIENTS: LazyLock<Arc<RwLock<HashMap<String, Arc<S3Client>>>>> =
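
A note on the new constants: with 32MB chunks, an object is split into ceil(size / CHUNK_SIZE) byte ranges, and each range becomes one ranged request. A minimal sketch of that arithmetic (the 100MB size is a made-up example, not from the commit):

const CHUNK_SIZE: usize = 32 * 1024 * 1024;

fn main() {
    // Hypothetical 100MB object: splits into 4 ranges (three full, one partial).
    let size: usize = 100 * 1024 * 1024;
    assert_eq!(size.div_ceil(CHUNK_SIZE), 4);

    // Mirrors the (0..size).step_by(CHUNK_SIZE).enumerate() logic in the diff.
    for (index, start) in (0..size).step_by(CHUNK_SIZE).enumerate() {
        let end = (start + CHUNK_SIZE).min(size) - 1;
        println!("chunk {index}: bytes={start}-{end}");
    }
}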
@@ -164,6 +170,41 @@ impl Artifact {
         }
     }
 
+    /// Downloads raw bytes of an artifact from a URI using parallel downloads.
+    ///
+    /// Supports both S3 URIs (s3://bucket/path) and HTTPS URLs. Downloads large
+    /// files in parallel chunks (32MB each) for improved performance. For S3 URIs,
+    /// uses byte-range requests via the S3 client. For HTTPS URLs, uses HTTP Range
+    /// headers if supported by the server, otherwise falls back to a sequential download.
+    ///
+    /// # Arguments
+    /// * `uri` - The URI to download from (s3:// or https://)
+    /// * `s3_region` - The AWS region for S3 operations
+    /// * `artifact_type` - The type of artifact determining the S3 prefix
+    /// * `concurrency` - Optional concurrency limit (default: 32)
+    #[instrument(fields(label = self.label, id = self.id), skip_all)]
+    pub async fn download_raw_from_uri_par(
+        &self,
+        uri: &str,
+        s3_region: &str,
+        artifact_type: ArtifactType,
+        concurrency: Option<usize>,
+    ) -> Result<Bytes> {
+        let parsed_url = Url::parse(uri).context("Failed to parse URI")?;
+        match parsed_url.scheme() {
+            "s3" => {
+                let bucket =
+                    parsed_url.host_str().ok_or_else(|| anyhow!("S3 URI missing bucket: {uri}"))?;
+                let s3_client = get_s3_client(s3_region).await;
+                download_s3_file_par(&s3_client, bucket, &self.id, artifact_type, concurrency).await
+            }
+            "https" => download_https_file_par(uri, concurrency).await,
+            scheme => {
+                Err(anyhow!("Unsupported URI scheme for download_raw_from_uri_par: {scheme}"))
+            }
+        }
+    }
+
     /// Downloads and deserializes a program artifact from S3.
     ///
     /// Downloads the program artifact and deserializes it using bincode into the
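
One subtlety worth calling out: for s3:// URIs the object key is derived from `artifact_type` and `self.id` via `get_s3_key`, so only the host (bucket) portion of the URI is used and the path is ignored. A hedged usage sketch of the new method, assuming an `Artifact` value named `artifact` and an `ArtifactType::Program` variant (both names are assumptions, not confirmed by this diff):

// Sketch only: `artifact` and `ArtifactType::Program` are assumed names.
let bytes = artifact
    .download_raw_from_uri_par(
        "s3://my-bucket/anything", // only the bucket host matters for s3://
        "us-east-1",
        ArtifactType::Program,
        None, // None => DEFAULT_CONCURRENCY (32)
    )
    .await?;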
@@ -431,6 +472,107 @@ async fn download_s3_file(
     Ok(bytes)
 }
 
+async fn download_s3_file_par(
+    client: &S3Client,
+    bucket: &str,
+    id: &str,
+    artifact_type: ArtifactType,
+    concurrency: Option<usize>,
+) -> Result<Bytes> {
+    let key = get_s3_key(artifact_type, id);
+    let concurrency = concurrency.unwrap_or(DEFAULT_CONCURRENCY);
+
+    let head_res = client
+        .head_object()
+        .bucket(bucket)
+        .key(&key)
+        .send()
+        .await
+        .context("Failed to get object metadata from S3")?;
+
+    let size = head_res.content_length().unwrap_or(0);
+
+    if size <= 0 {
+        return Err(anyhow!("Invalid object size: {size}"));
+    }
+
+    if size as usize <= CHUNK_SIZE {
+        return download_s3_file(client, bucket, id, artifact_type).await;
+    }
+
+    let starts: Vec<(usize, i64)> = (0..size).step_by(CHUNK_SIZE).enumerate().collect();
+
+    let num_chunks = starts.len();
+    let concurrency = std::cmp::min(concurrency, num_chunks);
+
+    let mut set = JoinSet::new();
+    let (tx, mut rx) = tokio::sync::mpsc::channel(num_chunks);
+
+    for chunk_group in starts.chunks(std::cmp::max(num_chunks / concurrency, 1)) {
+        let client = client.clone();
+        let bucket = bucket.to_string();
+        let key = key.clone();
+        let tx = tx.clone();
+        let chunk_group = chunk_group.to_vec();
+
+        set.spawn(async move {
+            for (index, start) in chunk_group {
+                let end = std::cmp::min(start + CHUNK_SIZE as i64, size) - 1;
+                let range = format!("bytes={start}-{end}");
+
+                let mut retry_count = 0;
+                let max_retries = 5;
+                let mut delay = Duration::from_secs(1);
+
+                loop {
+                    match client.get_object().bucket(&bucket).key(&key).range(&range).send().await {
+                        Ok(res) => {
+                            let data = res.body.collect().await?;
+                            let bytes = data.into_bytes();
+                            tx.send((index, bytes)).await?;
+                            break;
+                        }
+                        Err(e) => {
+                            retry_count += 1;
+                            if retry_count >= max_retries {
+                                return Err(anyhow!(
+                                    "Failed to download S3 chunk after {max_retries} retries: {e}"
+                                ));
+                            }
+
+                            tracing::warn!(
+                                "retry attempt {} for downloading S3 chunk {}: {}",
+                                retry_count,
+                                index,
+                                e
+                            );
+
+                            tokio::time::sleep(delay).await;
+                            delay *= 2;
+                        }
+                    }
+                }
+            }
+            Ok::<(), anyhow::Error>(())
+        });
+    }
+
+    drop(tx);
+
+    let mut result = vec![0_u8; size as usize];
+    while let Some((index, chunk)) = rx.recv().await {
+        let start = index * CHUNK_SIZE;
+        let end = std::cmp::min(start + chunk.len(), size as usize);
+        result[start..end].copy_from_slice(&chunk);
+    }
+
+    while let Some(res) = set.join_next().await {
+        res.context("S3 download task panicked")??;
+    }
+
+    Ok(Bytes::from(result))
+}
+
 async fn download_https_file(uri: &str) -> Result<Bytes> {
     let client = reqwest::Client::new();
     let res = client
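
The reassembly pattern here is worth noting: workers send `(chunk_index, bytes)` over an mpsc channel, and the receiver writes each chunk at offset `index * CHUNK_SIZE` into a preallocated buffer, so chunks can arrive in any order. Also note that `starts.chunks(max(num_chunks / concurrency, 1))` uses integer division, so the number of spawned groups can slightly exceed the requested concurrency (e.g. 33 chunks at concurrency 32 yield 33 single-chunk groups). A self-contained toy sketch of the same JoinSet-plus-channel pattern, with plain byte vectors instead of S3:

use tokio::{sync::mpsc, task::JoinSet};

#[tokio::main]
async fn main() {
    const CHUNK_SIZE: usize = 4; // toy size for illustration
    let data: Vec<u8> = (0u8..10).collect();
    let size = data.len();

    let mut set = JoinSet::new();
    let (tx, mut rx) = mpsc::channel(8);

    for (index, start) in (0..size).step_by(CHUNK_SIZE).enumerate() {
        let tx = tx.clone();
        let chunk = data[start..(start + CHUNK_SIZE).min(size)].to_vec();
        set.spawn(async move {
            // Chunks may complete out of order; the index restores placement.
            tx.send((index, chunk)).await.unwrap();
        });
    }
    drop(tx); // close the channel so rx.recv() terminates

    let mut result = vec![0u8; size];
    while let Some((index, chunk)) = rx.recv().await {
        let start = index * CHUNK_SIZE;
        result[start..start + chunk.len()].copy_from_slice(&chunk);
    }
    while let Some(res) = set.join_next().await {
        res.expect("task panicked");
    }
    assert_eq!(result, data);
}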
@@ -446,6 +588,140 @@ async fn download_https_file(uri: &str) -> Result<Bytes> {
     Ok(bytes)
 }
 
+#[allow(clippy::too_many_lines)]
+async fn download_https_file_par(uri: &str, concurrency: Option<usize>) -> Result<Bytes> {
+    let concurrency = concurrency.unwrap_or(DEFAULT_CONCURRENCY);
+    let client = reqwest::Client::new();
+
+    let head_res = client
+        .head(uri)
+        .timeout(Duration::from_secs(30))
+        .send()
+        .await
+        .context("Failed to HEAD HTTPS URL")?;
+
+    if !head_res.status().is_success() {
+        return Err(anyhow!(
+            "Failed to get metadata from HTTPS URL {uri}: status {}",
+            head_res.status()
+        ));
+    }
+
+    let supports_range = head_res
+        .headers()
+        .get(reqwest::header::ACCEPT_RANGES)
+        .and_then(|v| v.to_str().ok())
+        .is_some_and(|v| v == "bytes");
+
+    let size = head_res
+        .headers()
+        .get(reqwest::header::CONTENT_LENGTH)
+        .and_then(|v| v.to_str().ok())
+        .and_then(|v| v.parse::<usize>().ok());
+
+    if !supports_range || size.is_none() {
+        return download_https_file(uri).await;
+    }
+
+    let size = size.unwrap();
+
+    if size <= CHUNK_SIZE {
+        return download_https_file(uri).await;
+    }
+
+    let starts: Vec<(usize, usize)> = (0..size).step_by(CHUNK_SIZE).enumerate().collect();
+
+    let num_chunks = starts.len();
+    let concurrency = std::cmp::min(concurrency, num_chunks);
+
+    let mut set = JoinSet::new();
+    let (tx, mut rx) = tokio::sync::mpsc::channel(num_chunks);
+
+    for chunk_group in starts.chunks(std::cmp::max(num_chunks / concurrency, 1)) {
+        let uri = uri.to_string();
+        let tx = tx.clone();
+        let chunk_group = chunk_group.to_vec();
+
+        set.spawn(async move {
+            let client = reqwest::Client::new();
+            for (index, start) in chunk_group {
+                let end = std::cmp::min(start + CHUNK_SIZE, size) - 1;
+                let range = format!("bytes={start}-{end}");
+
+                let mut retry_count = 0;
+                let max_retries = 5;
+                let mut delay = Duration::from_secs(1);
+
+                loop {
+                    match client
+                        .get(&uri)
+                        .header(reqwest::header::RANGE, &range)
+                        .timeout(Duration::from_secs(90))
+                        .send()
+                        .await
+                    {
+                        Ok(res) => {
+                            if !res.status().is_success() && res.status() != reqwest::StatusCode::PARTIAL_CONTENT {
+                                retry_count += 1;
+                                if retry_count >= max_retries {
+                                    return Err(anyhow!("Failed to download HTTPS chunk after {max_retries} retries: status {}", res.status()));
+                                }
+
+                                tracing::warn!(
+                                    "retry attempt {} for downloading HTTPS chunk {}: status {}",
+                                    retry_count,
+                                    index,
+                                    res.status()
+                                );
+
+                                tokio::time::sleep(delay).await;
+                                delay *= 2;
+                                continue;
+                            }
+
+                            let bytes = res.bytes().await?;
+                            tx.send((index, bytes)).await?;
+                            break;
+                        }
+                        Err(e) => {
+                            retry_count += 1;
+                            if retry_count >= max_retries {
+                                return Err(anyhow!("Failed to download HTTPS chunk after {max_retries} retries: {e}"));
+                            }
+
+                            tracing::warn!(
+                                "retry attempt {} for downloading HTTPS chunk {}: {}",
+                                retry_count,
+                                index,
+                                e
+                            );
+
+                            tokio::time::sleep(delay).await;
+                            delay *= 2;
+                        }
+                    }
+                }
+            }
+            Ok::<(), anyhow::Error>(())
+        });
+    }
+
+    drop(tx);
+
+    let mut result = vec![0_u8; size];
+    while let Some((index, chunk)) = rx.recv().await {
+        let start = index * CHUNK_SIZE;
+        let end = std::cmp::min(start + chunk.len(), size);
+        result[start..end].copy_from_slice(&chunk);
+    }
+
+    while let Some(res) = set.join_next().await {
+        res.context("HTTPS download task panicked")??;
+    }
+
+    Ok(Bytes::from(result))
+}
+
 async fn upload_file(
     client: &S3Client,
     bucket: &str,

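Both parallel paths hand-roll the same retry policy: at most 5 attempts, with a delay that starts at 1 second and doubles after each failure. A generic sketch of that backoff loop, factored into a reusable helper (the helper name and signature are mine, not part of the commit):

use std::time::Duration;

use anyhow::{anyhow, Result};

// Hypothetical helper mirroring the retry policy used in this commit:
// max 5 attempts, 1s initial delay, exponential doubling between attempts.
async fn with_backoff<T, E, F, Fut>(mut op: F) -> Result<T>
where
    E: std::fmt::Display,
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let max_retries = 5;
    let mut delay = Duration::from_secs(1);
    let mut retry_count = 0;
    loop {
        match op().await {
            Ok(v) => return Ok(v),
            Err(e) => {
                retry_count += 1;
                if retry_count >= max_retries {
                    return Err(anyhow!("failed after {max_retries} retries: {e}"));
                }
                tokio::time::sleep(delay).await;
                delay *= 2;
            }
        }
    }
}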