@@ -22,13 +22,15 @@ use arrow_flight::{
2222 HandshakeRequest , HandshakeResponse , PollInfo , PutResult , SchemaResult , Ticket ,
2323} ;
2424use ballista_core:: error:: BallistaError ;
25+ use ballista_core:: extension:: BallistaConfigGrpcEndpoint ;
2526use ballista_core:: serde:: decode_protobuf;
2627use ballista_core:: serde:: scheduler:: Action as BallistaAction ;
27- use ballista_core:: utils:: { GrpcClientConfig , create_grpc_client_connection } ;
28+ use ballista_core:: utils:: { GrpcClientConfig , create_grpc_client_endpoint } ;
2829
2930use futures:: { Stream , TryFutureExt } ;
3031use log:: debug;
3132use std:: pin:: Pin ;
33+ use std:: sync:: Arc ;
3234use tonic:: { Request , Response , Status , Streaming } ;
3335
3436/// Service implementing a proxy from scheduler to executor Apache Arrow Flight Protocol
@@ -40,16 +42,24 @@ use tonic::{Request, Response, Status, Streaming};
4042pub struct BallistaFlightProxyService {
4143 max_decoding_message_size : usize ,
4244 max_encoding_message_size : usize ,
45+ /// Whether to use TLS when connecting to executors
46+ use_tls : bool ,
47+ /// Optional function to customize gRPC endpoint configuration (e.g., for TLS)
48+ customize_endpoint : Option < Arc < BallistaConfigGrpcEndpoint > > ,
4349}
4450
4551impl BallistaFlightProxyService {
4652 pub fn new (
4753 max_decoding_message_size : usize ,
4854 max_encoding_message_size : usize ,
55+ use_tls : bool ,
56+ customize_endpoint : Option < Arc < BallistaConfigGrpcEndpoint > > ,
4957 ) -> Self {
5058 Self {
5159 max_decoding_message_size,
5260 max_encoding_message_size,
61+ use_tls,
62+ customize_endpoint,
5363 }
5464 }
5565}
@@ -120,6 +130,8 @@ impl FlightService for BallistaFlightProxyService {
120130 * port,
121131 self . max_decoding_message_size ,
122132 self . max_encoding_message_size ,
133+ self . use_tls ,
134+ self . customize_endpoint . clone ( ) ,
123135 )
124136 . map_err ( |e| from_ballista_err ( & e) )
125137 . await ?;
@@ -169,16 +181,34 @@ async fn get_flight_client(
169181 port : u16 ,
170182 max_decoding_message_size : usize ,
171183 max_encoding_message_size : usize ,
184+ use_tls : bool ,
185+ customize_endpoint : Option < Arc < BallistaConfigGrpcEndpoint > > ,
172186) -> Result < FlightServiceClient < tonic:: transport:: channel:: Channel > , BallistaError > {
173- let addr = format ! ( "http://{host}:{port}" ) ;
187+ let scheme = if use_tls { "https" } else { "http" } ;
188+ let addr = format ! ( "{scheme}://{host}:{port}" ) ;
174189 let grpc_config = GrpcClientConfig :: default ( ) ;
175- let connection = create_grpc_client_connection ( addr . clone ( ) , & grpc_config )
176- . await
190+
191+ let mut endpoint = create_grpc_client_endpoint ( addr . clone ( ) , Some ( & grpc_config ) )
177192 . map_err ( |e| {
178193 BallistaError :: GrpcConnectionError ( format ! (
179- "Error connecting to Ballista scheduler or executor at {addr}: {e:?}"
194+ "Error creating endpoint for Ballista executor at {addr}: {e:?}"
195+ ) )
196+ } ) ?;
197+
198+ if let Some ( ref customize) = customize_endpoint {
199+ endpoint = customize. configure_endpoint ( endpoint) . map_err ( |e| {
200+ BallistaError :: GrpcConnectionError ( format ! (
201+ "Error customizing endpoint for Ballista executor at {addr}: {e}"
180202 ) )
181203 } ) ?;
204+ }
205+
206+ let connection = endpoint. connect ( ) . await . map_err ( |e| {
207+ BallistaError :: GrpcConnectionError ( format ! (
208+ "Error connecting to Ballista executor at {addr}: {e:?}"
209+ ) )
210+ } ) ?;
211+
182212 let flight_client = FlightServiceClient :: new ( connection)
183213 . max_decoding_message_size ( max_decoding_message_size)
184214 . max_encoding_message_size ( max_encoding_message_size) ;
0 commit comments