@@ -33,7 +33,7 @@ use crate::client::BallistaClient;
3333use crate :: execution_plans:: sort_shuffle:: {
3434 get_index_path, is_sort_shuffle_output, stream_sort_shuffle_partition,
3535} ;
36- use crate :: extension:: SessionConfigExt ;
36+ use crate :: extension:: { BallistaConfigGrpcEndpoint , SessionConfigExt } ;
3737use crate :: serde:: scheduler:: { PartitionLocation , PartitionStats } ;
3838
3939use datafusion:: arrow:: datatypes:: SchemaRef ;
@@ -169,6 +169,8 @@ impl ExecutionPlan for ShuffleReaderExec {
169169 let force_remote_read = config. ballista_shuffle_reader_force_remote_read ( ) ;
170170 let prefer_flight = config. ballista_shuffle_reader_remote_prefer_flight ( ) ;
171171 let batch_size = config. batch_size ( ) ;
172+ let customize_endpoint = config. ballista_override_create_grpc_client_endpoint ( ) ;
173+ let use_tls = config. ballista_use_tls ( ) ;
172174
173175 if force_remote_read {
174176 debug ! (
@@ -202,6 +204,8 @@ impl ExecutionPlan for ShuffleReaderExec {
202204 max_message_size,
203205 force_remote_read,
204206 prefer_flight,
207+ customize_endpoint,
208+ use_tls,
205209 ) ;
206210
207211 let input_stream = Box :: pin ( RecordBatchStreamAdapter :: new (
@@ -404,6 +408,8 @@ fn send_fetch_partitions(
404408 max_message_size : usize ,
405409 force_remote_read : bool ,
406410 flight_transport : bool ,
411+ customize_endpoint : Option < Arc < BallistaConfigGrpcEndpoint > > ,
412+ use_tls : bool ,
407413) -> AbortableReceiverStream {
408414 let ( response_sender, response_receiver) = mpsc:: channel ( max_request_num) ;
409415 let semaphore = Arc :: new ( Semaphore :: new ( max_request_num) ) ;
@@ -420,10 +426,17 @@ fn send_fetch_partitions(
420426
421427 // keep local shuffle files reading in serial order for memory control.
422428 let response_sender_c = response_sender. clone ( ) ;
429+ let customize_endpoint_c = customize_endpoint. clone ( ) ;
423430 spawned_tasks. push ( SpawnedTask :: spawn ( async move {
424431 for p in local_locations {
425432 let r = PartitionReaderEnum :: Local
426- . fetch_partition ( & p, max_message_size, flight_transport)
433+ . fetch_partition (
434+ & p,
435+ max_message_size,
436+ flight_transport,
437+ customize_endpoint_c. clone ( ) ,
438+ use_tls,
439+ )
427440 . await ;
428441 if let Err ( e) = response_sender_c. send ( r) . await {
429442 error ! ( "Fail to send response event to the channel due to {e}" ) ;
@@ -434,11 +447,18 @@ fn send_fetch_partitions(
434447 for p in remote_locations. into_iter ( ) {
435448 let semaphore = semaphore. clone ( ) ;
436449 let response_sender = response_sender. clone ( ) ;
450+ let customize_endpoint_c = customize_endpoint. clone ( ) ;
437451 spawned_tasks. push ( SpawnedTask :: spawn ( async move {
438452 // Block if exceeds max request number.
439453 let permit = semaphore. acquire_owned ( ) . await . unwrap ( ) ;
440454 let r = PartitionReaderEnum :: FlightRemote
441- . fetch_partition ( & p, max_message_size, flight_transport)
455+ . fetch_partition (
456+ & p,
457+ max_message_size,
458+ flight_transport,
459+ customize_endpoint_c,
460+ use_tls,
461+ )
442462 . await ;
443463 // Block if the channel buffer is full.
444464 if let Err ( e) = response_sender. send ( r) . await {
@@ -465,6 +485,8 @@ trait PartitionReader: Send + Sync + Clone {
465485 location : & PartitionLocation ,
466486 max_message_size : usize ,
467487 flight_transport : bool ,
488+ customize_endpoint : Option < Arc < BallistaConfigGrpcEndpoint > > ,
489+ use_tls : bool ,
468490 ) -> result:: Result < SendableRecordBatchStream , BallistaError > ;
469491}
470492
@@ -484,10 +506,19 @@ impl PartitionReader for PartitionReaderEnum {
484506 location : & PartitionLocation ,
485507 max_message_size : usize ,
486508 flight_transport : bool ,
509+ customize_endpoint : Option < Arc < BallistaConfigGrpcEndpoint > > ,
510+ use_tls : bool ,
487511 ) -> result:: Result < SendableRecordBatchStream , BallistaError > {
488512 match self {
489513 PartitionReaderEnum :: FlightRemote => {
490- fetch_partition_remote ( location, max_message_size, flight_transport) . await
514+ fetch_partition_remote (
515+ location,
516+ max_message_size,
517+ flight_transport,
518+ customize_endpoint,
519+ use_tls,
520+ )
521+ . await
491522 }
492523 PartitionReaderEnum :: Local => fetch_partition_local ( location) . await ,
493524 PartitionReaderEnum :: ObjectStoreRemote => {
@@ -501,25 +532,33 @@ async fn fetch_partition_remote(
501532 location : & PartitionLocation ,
502533 max_message_size : usize ,
503534 flight_transport : bool ,
535+ customize_endpoint : Option < Arc < BallistaConfigGrpcEndpoint > > ,
536+ use_tls : bool ,
504537) -> result:: Result < SendableRecordBatchStream , BallistaError > {
505538 let metadata = & location. executor_meta ;
506539 let partition_id = & location. partition_id ;
507540 // TODO: for shuffle client connections, we should avoid creating new connections again and again.
508541 // We should also avoid keeping too many connections alive for a long time.
509542 let host = metadata. host . as_str ( ) ;
510543 let port = metadata. port ;
511- let mut ballista_client = BallistaClient :: try_new ( host, port, max_message_size)
512- . await
513- . map_err ( |error| match error {
514- // map grpc connection error to partition fetch error.
515- BallistaError :: GrpcConnectionError ( msg) => BallistaError :: FetchFailed (
516- metadata. id . clone ( ) ,
517- partition_id. stage_id ,
518- partition_id. partition_id ,
519- msg,
520- ) ,
521- other => other,
522- } ) ?;
544+ let mut ballista_client = BallistaClient :: try_new (
545+ host,
546+ port,
547+ max_message_size,
548+ use_tls,
549+ customize_endpoint,
550+ )
551+ . await
552+ . map_err ( |error| match error {
553+ // map grpc connection error to partition fetch error.
554+ BallistaError :: GrpcConnectionError ( msg) => BallistaError :: FetchFailed (
555+ metadata. id . clone ( ) ,
556+ partition_id. stage_id ,
557+ partition_id. partition_id ,
558+ msg,
559+ ) ,
560+ other => other,
561+ } ) ?;
523562
524563 ballista_client
525564 . fetch_partition (
@@ -1087,6 +1126,8 @@ mod tests {
10871126 4 * 1024 * 1024 ,
10881127 false ,
10891128 true ,
1129+ None ,
1130+ false ,
10901131 ) ;
10911132
10921133 let stream = RecordBatchStreamAdapter :: new (
0 commit comments