@@ -22,12 +22,21 @@ use std::any::Any;
2222use std:: collections:: HashMap ;
2323use std:: fmt:: Debug ;
2424use std:: fs:: File ;
25- use std:: io:: BufReader ;
25+ use std:: io:: { BufReader , Cursor } ;
2626use std:: pin:: Pin ;
2727use std:: result;
2828use std:: sync:: Arc ;
2929use std:: task:: { Context , Poll } ;
3030
31+ #[ cfg( feature = "build-binary" ) ]
32+ use object_store:: aws:: AmazonS3Builder ;
33+ #[ cfg( feature = "build-binary" ) ]
34+ use object_store:: azure:: MicrosoftAzureBuilder ;
35+ #[ cfg( feature = "build-binary" ) ]
36+ use object_store:: ObjectStore ;
37+ #[ cfg( feature = "build-binary" ) ]
38+ use url:: Url ;
39+
3140use crate :: client:: BallistaClient ;
3241use crate :: extension:: { BallistaConfigGrpcEndpoint , SessionConfigExt } ;
3342use crate :: serde:: scheduler:: { PartitionLocation , PartitionStats } ;
@@ -371,23 +380,34 @@ impl Stream for AbortableReceiverStream {
371380 . map_err ( |e| ArrowError :: ExternalError ( Box :: new ( e) ) )
372381 }
373382}
374- /// Splits the provided partition locations into local and remote partitions.
383+ /// Splits the provided partition locations into local, object store, and remote partitions.
375384/// Local partitions are read directly from local Arrow IPC files,
385+ /// object store partitions are read via the object store client,
376386/// while remote partitions are fetched using the Arrow Flight client.
377387/// If `force_remote_read` is true, all partitions are treated as remote.
378388fn local_remote_read_split (
379389 partition_locations : Vec < PartitionLocation > ,
380390 force_remote_read : bool ,
381- ) -> ( Vec < PartitionLocation > , Vec < PartitionLocation > ) {
391+ ) -> ( Vec < PartitionLocation > , Vec < PartitionLocation > , Vec < PartitionLocation > ) {
382392 if !force_remote_read {
383- partition_locations
393+ let ( local, non_local) : ( Vec < _ > , Vec < _ > ) = partition_locations
394+ . into_iter ( )
395+ . partition ( check_is_local_location) ;
396+ let ( object_store, remote) : ( Vec < _ > , Vec < _ > ) = non_local
384397 . into_iter ( )
385- . partition ( check_is_local_location)
398+ . partition ( check_is_object_store_location) ;
399+ ( local, object_store, remote)
386400 } else {
387- ( vec ! [ ] , partition_locations)
401+ ( vec ! [ ] , vec ! [ ] , partition_locations)
388402 }
389403}
390404
405+ /// Check if the location is an object store path (S3 or Azure).
406+ fn check_is_object_store_location ( location : & PartitionLocation ) -> bool {
407+ let path = location. path . as_str ( ) ;
408+ path. starts_with ( "s3://" ) || path. starts_with ( "abfs://" ) || path. starts_with ( "az://" )
409+ }
410+
391411fn send_fetch_partitions (
392412 partition_locations : Vec < PartitionLocation > ,
393413 max_request_num : usize ,
@@ -401,12 +421,13 @@ fn send_fetch_partitions(
401421 let semaphore = Arc :: new ( Semaphore :: new ( max_request_num) ) ;
402422 let mut spawned_tasks: Vec < SpawnedTask < ( ) > > = vec ! [ ] ;
403423
404- let ( local_locations, remote_locations) : ( Vec < _ > , Vec < _ > ) =
424+ let ( local_locations, object_store_locations , remote_locations) : ( Vec < _ > , Vec < _ > , Vec < _ > ) =
405425 local_remote_read_split ( partition_locations, force_remote_read) ;
406426
407427 debug ! (
408- "local shuffle file counts:{}, remote shuffle file count:{}." ,
428+ "local shuffle file counts:{}, object store shuffle file count:{}, remote shuffle file count:{}." ,
409429 local_locations. len( ) ,
430+ object_store_locations. len( ) ,
410431 remote_locations. len( )
411432 ) ;
412433
@@ -430,6 +451,31 @@ fn send_fetch_partitions(
430451 }
431452 } ) ) ;
432453
454+ // Handle object store partitions with concurrency control
455+ for p in object_store_locations. into_iter ( ) {
456+ let semaphore = semaphore. clone ( ) ;
457+ let response_sender = response_sender. clone ( ) ;
458+ spawned_tasks. push ( SpawnedTask :: spawn ( async move {
459+ // Block if exceeds max request number.
460+ let permit = semaphore. acquire_owned ( ) . await . unwrap ( ) ;
461+ let r = PartitionReaderEnum :: ObjectStoreRemote
462+ . fetch_partition (
463+ & p,
464+ max_message_size,
465+ false , // flight_transport not used for object store
466+ None , // customize_endpoint not used for object store
467+ false , // use_tls not used for object store
468+ )
469+ . await ;
470+ // Block if the channel buffer is full.
471+ if let Err ( e) = response_sender. send ( r) . await {
472+ error ! ( "Fail to send response event to the channel due to {e}" ) ;
473+ }
474+ // Increase semaphore by dropping existing permits.
475+ drop ( permit) ;
476+ } ) ) ;
477+ }
478+
433479 for p in remote_locations. into_iter ( ) {
434480 let semaphore = semaphore. clone ( ) ;
435481 let response_sender = response_sender. clone ( ) ;
@@ -590,14 +636,143 @@ fn fetch_partition_local_inner(
590636 Ok ( reader)
591637}
592638
639+ #[ cfg( feature = "build-binary" ) ]
640+ async fn fetch_partition_object_store (
641+ location : & PartitionLocation ,
642+ ) -> result:: Result < SendableRecordBatchStream , BallistaError > {
643+ let path = & location. path ;
644+ let metadata = & location. executor_meta ;
645+ let partition_id = & location. partition_id ;
646+
647+ debug ! ( "Fetching shuffle partition from object store: {}" , path) ;
648+
649+ let batches = fetch_partition_object_store_inner ( path) . await . map_err ( |e| {
650+ // return BallistaError::FetchFailed may let scheduler retry this task.
651+ BallistaError :: FetchFailed (
652+ metadata. id . clone ( ) ,
653+ partition_id. stage_id ,
654+ partition_id. partition_id ,
655+ e. to_string ( ) ,
656+ )
657+ } ) ?;
658+
659+ if batches. is_empty ( ) {
660+ return Err ( BallistaError :: General ( format ! (
661+ "No batches found in shuffle partition at {}" ,
662+ path
663+ ) ) ) ;
664+ }
665+
666+ let schema = batches[ 0 ] . schema ( ) ;
667+ let stream = futures:: stream:: iter ( batches. into_iter ( ) . map ( Ok ) ) ;
668+ Ok ( Box :: pin ( RecordBatchStreamAdapter :: new ( schema, stream) ) )
669+ }
670+
671+ #[ cfg( not( feature = "build-binary" ) ) ]
593672async fn fetch_partition_object_store (
594673 _location : & PartitionLocation ,
595674) -> result:: Result < SendableRecordBatchStream , BallistaError > {
596675 Err ( BallistaError :: NotImplemented (
597- "Should not use ObjectStorePartitionReader " . to_string ( ) ,
676+ "Object store support requires 'build-binary' feature " . to_string ( ) ,
598677 ) )
599678}
600679
680+ #[ cfg( feature = "build-binary" ) ]
681+ async fn fetch_partition_object_store_inner (
682+ path : & str ,
683+ ) -> result:: Result < Vec < RecordBatch > , BallistaError > {
684+ use object_store:: path:: Path as ObjectPath ;
685+
686+ let url = Url :: parse ( path) . map_err ( |e| {
687+ BallistaError :: General ( format ! ( "Failed to parse object store URL '{}': {:?}" , path, e) )
688+ } ) ?;
689+
690+ let scheme = url. scheme ( ) ;
691+ let store: Arc < dyn ObjectStore > = match scheme {
692+ "s3" => {
693+ let bucket = url. host_str ( ) . ok_or_else ( || {
694+ BallistaError :: General ( format ! ( "No bucket in S3 URL: {}" , path) )
695+ } ) ?;
696+ let builder = AmazonS3Builder :: from_env ( ) . with_bucket_name ( bucket) ;
697+ Arc :: new ( builder. build ( ) . map_err ( |e| {
698+ BallistaError :: General ( format ! ( "Failed to create S3 client: {:?}" , e) )
699+ } ) ?)
700+ }
701+ "abfs" | "az" => {
702+ // Parse Azure URL: abfs://container@account.dfs.core.windows.net/path
703+ let host = url. host_str ( ) . ok_or_else ( || {
704+ BallistaError :: General ( format ! ( "No host in Azure URL: {}" , path) )
705+ } ) ?;
706+
707+ // Extract container from username portion
708+ let container = url. username ( ) ;
709+ if container. is_empty ( ) {
710+ return Err ( BallistaError :: General ( format ! (
711+ "No container in Azure URL. Expected format: abfs://container@account.dfs.core.windows.net/path. Got: {}" ,
712+ path
713+ ) ) ) ;
714+ }
715+
716+ // Extract account from host (account.dfs.core.windows.net)
717+ let account = host. split ( '.' ) . next ( ) . ok_or_else ( || {
718+ BallistaError :: General ( format ! ( "No account in Azure URL: {}" , path) )
719+ } ) ?;
720+
721+ let builder = MicrosoftAzureBuilder :: from_env ( )
722+ . with_account ( account)
723+ . with_container_name ( container) ;
724+ Arc :: new ( builder. build ( ) . map_err ( |e| {
725+ BallistaError :: General ( format ! ( "Failed to create Azure client: {:?}" , e) )
726+ } ) ?)
727+ }
728+ _ => {
729+ return Err ( BallistaError :: General ( format ! (
730+ "Unsupported object store scheme: {}. Supported: s3, abfs, az" ,
731+ scheme
732+ ) ) ) ;
733+ }
734+ } ;
735+
736+ // Extract the object path from the URL
737+ let object_path = ObjectPath :: from ( url. path ( ) . trim_start_matches ( '/' ) ) ;
738+
739+ debug ! ( "Reading object from path: {:?}" , object_path) ;
740+
741+ let get_result = store. get ( & object_path) . await . map_err ( |e| {
742+ BallistaError :: General ( format ! (
743+ "Failed to read object from {}: {:?}" ,
744+ path, e
745+ ) )
746+ } ) ?;
747+
748+ let bytes = get_result. bytes ( ) . await . map_err ( |e| {
749+ BallistaError :: General ( format ! (
750+ "Failed to read bytes from {}: {:?}" ,
751+ path, e
752+ ) )
753+ } ) ?;
754+
755+ let cursor = Cursor :: new ( bytes. to_vec ( ) ) ;
756+ let stream_reader = StreamReader :: try_new ( cursor, None ) . map_err ( |e| {
757+ BallistaError :: General ( format ! (
758+ "Failed to create Arrow stream reader for {}: {:?}" ,
759+ path, e
760+ ) )
761+ } ) ?;
762+
763+ let mut batches = Vec :: new ( ) ;
764+ for batch_result in stream_reader {
765+ batches. push ( batch_result. map_err ( |e| {
766+ BallistaError :: General ( format ! (
767+ "Failed to read batch from {}: {:?}" ,
768+ path, e
769+ ) )
770+ } ) ?) ;
771+ }
772+
773+ Ok ( batches)
774+ }
775+
601776#[ cfg( test) ]
602777mod tests {
603778 use super :: * ;
@@ -955,14 +1130,16 @@ mod tests {
9551130 let partition_locations =
9561131 get_test_partition_locations ( 1 , file_path. to_str ( ) . unwrap ( ) . to_string ( ) ) ;
9571132
958- let ( local, remote) = local_remote_read_split ( partition_locations. clone ( ) , false ) ;
1133+ let ( local, object_store , remote) = local_remote_read_split ( partition_locations. clone ( ) , false ) ;
9591134
9601135 assert ! ( !local. is_empty( ) ) ;
1136+ assert ! ( object_store. is_empty( ) ) ;
9611137 assert ! ( remote. is_empty( ) ) ;
9621138
963- let ( local, remote) = local_remote_read_split ( partition_locations, true ) ;
1139+ let ( local, object_store , remote) = local_remote_read_split ( partition_locations, true ) ;
9641140
9651141 assert ! ( local. is_empty( ) ) ;
1142+ assert ! ( object_store. is_empty( ) ) ;
9661143 assert ! ( !remote. is_empty( ) ) ;
9671144 }
9681145
0 commit comments