@@ -38,10 +38,16 @@ use aws_sdk_s3::{
 use aws_smithy_async::rt::sleep::default_async_sleep;
 use bytes::Bytes;
 use serde::{de::DeserializeOwned, Serialize};
-use tokio::sync::RwLock;
+use tokio::{sync::RwLock, task::JoinSet};
 use tracing::instrument;
 use url::Url;
 
+/// Chunk size for parallel downloads in bytes (32MB).
+const CHUNK_SIZE: usize = 32 * 1024 * 1024;
+
+/// Default concurrency for parallel downloads.
+const DEFAULT_CONCURRENCY: usize = 32;
+
 /// S3 Clients that are cached across the entire application.
 #[allow(clippy::type_complexity)]
 static S3_CLIENTS: LazyLock<Arc<RwLock<HashMap<String, Arc<S3Client>>>>> =
@@ -164,6 +170,41 @@ impl Artifact {
         }
     }
 
+    /// Downloads the raw bytes of an artifact from a URI using parallel downloads.
+    ///
+    /// Supports both S3 URIs (s3://bucket/path) and HTTPS URLs. Large files are
+    /// downloaded in parallel 32MB chunks for improved throughput. For S3 URIs,
+    /// byte-range requests are issued through the S3 client; for HTTPS URLs,
+    /// HTTP Range headers are used when the server supports them, with a
+    /// fallback to a sequential download otherwise.
+    ///
+    /// # Arguments
+    /// * `uri` - The URI to download from (s3:// or https://)
+    /// * `s3_region` - The AWS region for S3 operations
+    /// * `artifact_type` - The type of artifact, which determines the S3 prefix
+    /// * `concurrency` - Optional concurrency limit (defaults to `DEFAULT_CONCURRENCY`, 32)
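+    ///
+    /// # Example
+    ///
+    /// A minimal sketch; the `artifact` value and the `ArtifactType::Program`
+    /// variant are assumed here for illustration:
+    /// ```ignore
+    /// let bytes = artifact
+    ///     .download_raw_from_uri_par("s3://my-bucket", "us-east-1", ArtifactType::Program, None)
+    ///     .await?;
+    /// ```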
+    #[instrument(fields(label = self.label, id = self.id), skip_all)]
+    pub async fn download_raw_from_uri_par(
+        &self,
+        uri: &str,
+        s3_region: &str,
+        artifact_type: ArtifactType,
+        concurrency: Option<usize>,
+    ) -> Result<Bytes> {
+        let parsed_url = Url::parse(uri).context("Failed to parse URI")?;
+        match parsed_url.scheme() {
+            "s3" => {
+                let bucket =
+                    parsed_url.host_str().ok_or_else(|| anyhow!("S3 URI missing bucket: {uri}"))?;
+                let s3_client = get_s3_client(s3_region).await;
+                download_s3_file_par(&s3_client, bucket, &self.id, artifact_type, concurrency)
+                    .await
+            }
+            "https" => download_https_file_par(uri, concurrency).await,
+            scheme => {
+                Err(anyhow!("Unsupported URI scheme for download_raw_from_uri_par: {scheme}"))
+            }
+        }
+    }
+
     /// Downloads and deserializes a program artifact from S3.
     ///
     /// Downloads the program artifact and deserializes it using bincode into the
@@ -431,6 +472,107 @@ async fn download_s3_file(
     Ok(bytes)
 }
 
+async fn download_s3_file_par(
+    client: &S3Client,
+    bucket: &str,
+    id: &str,
+    artifact_type: ArtifactType,
+    concurrency: Option<usize>,
+) -> Result<Bytes> {
+    let key = get_s3_key(artifact_type, id);
+    let concurrency = concurrency.unwrap_or(DEFAULT_CONCURRENCY);
+
+    let head_res = client
+        .head_object()
+        .bucket(bucket)
+        .key(&key)
+        .send()
+        .await
+        .context("Failed to get object metadata from S3")?;
+
+    let size = head_res.content_length().unwrap_or(0);
+
+    if size <= 0 {
+        return Err(anyhow!("Invalid object size: {size}"));
+    }
+
+    if size as usize <= CHUNK_SIZE {
+        return download_s3_file(client, bucket, id, artifact_type).await;
+    }
+
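+    // Enumerate each chunk's byte offset; the enumeration index is used later
+    // to place the chunk at the correct position in the output buffer.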
+    let starts: Vec<(usize, i64)> = (0..size).step_by(CHUNK_SIZE).enumerate().collect();
+
+    let num_chunks = starts.len();
+    let concurrency = std::cmp::min(concurrency, num_chunks);
+
+    let mut set = JoinSet::new();
+    let (tx, mut rx) = tokio::sync::mpsc::channel(num_chunks);
+
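+    // Partition the chunks into at most `concurrency` groups; each spawned
+    // task downloads its group sequentially while the groups run in parallel.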
+    for chunk_group in starts.chunks(num_chunks.div_ceil(concurrency)) {
+        let client = client.clone();
+        let bucket = bucket.to_string();
+        let key = key.clone();
+        let tx = tx.clone();
+        let chunk_group = chunk_group.to_vec();
+
+        set.spawn(async move {
+            for (index, start) in chunk_group {
+                let end = std::cmp::min(start + CHUNK_SIZE as i64, size) - 1;
+                let range = format!("bytes={start}-{end}");
+
+                let mut retry_count = 0;
+                let max_retries = 5;
+                let mut delay = Duration::from_secs(1);
+
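+                // Retry transient failures with exponential backoff (1s, 2s, 4s, ...).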
+                loop {
+                    match client.get_object().bucket(&bucket).key(&key).range(&range).send().await {
+                        Ok(res) => {
+                            let data = res.body.collect().await?;
+                            let bytes = data.into_bytes();
+                            tx.send((index, bytes)).await?;
+                            break;
+                        }
+                        Err(e) => {
+                            retry_count += 1;
+                            if retry_count >= max_retries {
+                                return Err(anyhow!(
+                                    "Failed to download S3 chunk after {max_retries} retries: {e}"
+                                ));
+                            }
+
+                            tracing::warn!(
+                                "retry attempt {} for downloading S3 chunk {}: {}",
+                                retry_count,
+                                index,
+                                e
+                            );
+
+                            tokio::time::sleep(delay).await;
+                            delay *= 2;
+                        }
+                    }
+                }
+            }
+            Ok::<(), anyhow::Error>(())
+        });
+    }
+
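+    // Drop the original sender so the channel closes once every worker's
+    // clone has been dropped.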
+    drop(tx);
+
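+    // Chunks can arrive in any order; write each one at its byte offset.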
+    let mut result = vec![0_u8; size as usize];
+    while let Some((index, chunk)) = rx.recv().await {
+        let start = index * CHUNK_SIZE;
+        let end = std::cmp::min(start + chunk.len(), size as usize);
+        result[start..end].copy_from_slice(&chunk);
+    }
+
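+    // Propagate panics and per-chunk failures from the worker tasks.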
+    while let Some(res) = set.join_next().await {
+        res.context("S3 download task panicked")??;
+    }
+
+    Ok(Bytes::from(result))
+}
+
 async fn download_https_file(uri: &str) -> Result<Bytes> {
     let client = reqwest::Client::new();
     let res = client
@@ -446,6 +588,140 @@ async fn download_https_file(uri: &str) -> Result<Bytes> {
     Ok(bytes)
 }
 
+#[allow(clippy::too_many_lines)]
+async fn download_https_file_par(uri: &str, concurrency: Option<usize>) -> Result<Bytes> {
+    let concurrency = concurrency.unwrap_or(DEFAULT_CONCURRENCY);
+    let client = reqwest::Client::new();
+
+    let head_res = client
+        .head(uri)
+        .timeout(Duration::from_secs(30))
+        .send()
+        .await
+        .context("Failed to HEAD HTTPS URL")?;
+
+    if !head_res.status().is_success() {
+        return Err(anyhow!(
+            "Failed to get metadata from HTTPS URL {uri}: status {}",
+            head_res.status()
+        ));
+    }
+
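+    // Parallel ranged download only works when the server advertises
+    // `Accept-Ranges: bytes` and reports a Content-Length; otherwise fall
+    // back to the sequential path.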
+    let supports_range = head_res
+        .headers()
+        .get(reqwest::header::ACCEPT_RANGES)
+        .and_then(|v| v.to_str().ok())
+        .is_some_and(|v| v == "bytes");
+
+    let size = head_res
+        .headers()
+        .get(reqwest::header::CONTENT_LENGTH)
+        .and_then(|v| v.to_str().ok())
+        .and_then(|v| v.parse::<usize>().ok());
+
+    if !supports_range || size.is_none() {
+        return download_https_file(uri).await;
+    }
+
+    let size = size.unwrap();
+
+    if size <= CHUNK_SIZE {
+        return download_https_file(uri).await;
+    }
+
+    let starts: Vec<(usize, usize)> = (0..size).step_by(CHUNK_SIZE).enumerate().collect();
+
+    let num_chunks = starts.len();
+    let concurrency = std::cmp::min(concurrency, num_chunks);
+
+    let mut set = JoinSet::new();
+    let (tx, mut rx) = tokio::sync::mpsc::channel(num_chunks);
+
+    for chunk_group in starts.chunks(num_chunks.div_ceil(concurrency)) {
+        let uri = uri.to_string();
+        let tx = tx.clone();
+        let chunk_group = chunk_group.to_vec();
+
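+        // Each worker builds its own HTTP client and fetches its chunk group
+        // sequentially; the groups themselves run concurrently.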
+        set.spawn(async move {
+            let client = reqwest::Client::new();
+            for (index, start) in chunk_group {
+                let end = std::cmp::min(start + CHUNK_SIZE, size) - 1;
+                let range = format!("bytes={start}-{end}");
+
+                let mut retry_count = 0;
+                let max_retries = 5;
+                let mut delay = Duration::from_secs(1);
+
+                loop {
+                    match client
+                        .get(&uri)
+                        .header(reqwest::header::RANGE, &range)
+                        .timeout(Duration::from_secs(90))
+                        .send()
+                        .await
+                    {
+                        Ok(res) => {
+                            if !res.status().is_success()
+                                && res.status() != reqwest::StatusCode::PARTIAL_CONTENT
+                            {
+                                retry_count += 1;
+                                if retry_count >= max_retries {
+                                    return Err(anyhow!(
+                                        "Failed to download HTTPS chunk after {max_retries} retries: status {}",
+                                        res.status()
+                                    ));
+                                }
+
+                                tracing::warn!(
+                                    "retry attempt {} for downloading HTTPS chunk {}: status {}",
+                                    retry_count,
+                                    index,
+                                    res.status()
+                                );
+
+                                tokio::time::sleep(delay).await;
+                                delay *= 2;
+                                continue;
+                            }
+
+                            let bytes = res.bytes().await?;
+                            tx.send((index, bytes)).await?;
+                            break;
+                        }
+                        Err(e) => {
+                            retry_count += 1;
+                            if retry_count >= max_retries {
+                                return Err(anyhow!(
+                                    "Failed to download HTTPS chunk after {max_retries} retries: {e}"
+                                ));
+                            }
+
+                            tracing::warn!(
+                                "retry attempt {} for downloading HTTPS chunk {}: {}",
+                                retry_count,
+                                index,
+                                e
+                            );
+
+                            tokio::time::sleep(delay).await;
+                            delay *= 2;
+                        }
+                    }
+                }
+            }
+            Ok::<(), anyhow::Error>(())
+        });
+    }
+
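+    // Drop the last sender so the receive loop below terminates once the
+    // workers finish.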
+    drop(tx);
+
+    let mut result = vec![0_u8; size];
+    while let Some((index, chunk)) = rx.recv().await {
+        let start = index * CHUNK_SIZE;
+        let end = std::cmp::min(start + chunk.len(), size);
+        result[start..end].copy_from_slice(&chunk);
+    }
+
+    while let Some(res) = set.join_next().await {
+        res.context("HTTPS download task panicked")??;
+    }
+
+    Ok(Bytes::from(result))
+}
+
 async fn upload_file(
     client: &S3Client,
     bucket: &str,