Skip to content

Commit 2a72065

Browse files
Add TLS support to scheduler flight proxy service
- Update BallistaFlightProxyService to accept use_tls and customize_endpoint parameters - Add use_tls field to SchedulerConfig with with_use_tls() builder method - Unify EndpointOverrideFn type across crates to use ballista_core::extension definition - Update flight proxy to use https/http scheme based on TLS configuration - Apply custom endpoint configuration for TLS certificate setup
1 parent 21c4684 commit 2a72065

6 files changed

Lines changed: 69 additions & 26 deletions

File tree

ballista/executor/src/executor_process.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use datafusion::execution::runtime_env::RuntimeEnvBuilder;
4242

4343
use ballista_core::config::{LogRotationPolicy, TaskSchedulingPolicy};
4444
use ballista_core::error::BallistaError;
45-
use ballista_core::extension::SessionConfigExt;
45+
use ballista_core::extension::{EndpointOverrideFn, SessionConfigExt};
4646
use ballista_core::serde::protobuf::executor_resource::Resource;
4747
use ballista_core::serde::protobuf::executor_status::Status;
4848
use ballista_core::serde::protobuf::{
@@ -57,11 +57,6 @@ use ballista_core::utils::{
5757
default_config_producer, get_time_before,
5858
};
5959
use ballista_core::{BALLISTA_VERSION, ConfigProducer, RuntimeProducer};
60-
use tonic::transport::{Endpoint, Error as TonicTransportError};
61-
62-
/// Type alias for the endpoint override function used in gRPC client configuration
63-
pub type EndpointOverrideFn =
64-
Arc<dyn Fn(Endpoint) -> Result<Endpoint, TonicTransportError> + Send + Sync>;
6560

6661
use crate::execution_engine::ExecutionEngine;
6762
use crate::executor::{Executor, TasksDrainedFuture};

ballista/executor/src/executor_server.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,11 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
3030
use tokio::sync::mpsc;
3131

3232
use log::{debug, error, info, warn};
33-
use tonic::transport::{Channel, Endpoint, Error as TonicTransportError};
33+
use tonic::transport::Channel;
3434
use tonic::{Request, Response, Status};
3535

36-
/// Type alias for the endpoint override function used in gRPC client configuration
37-
pub type EndpointOverrideFn =
38-
Arc<dyn Fn(Endpoint) -> Result<Endpoint, TonicTransportError> + Send + Sync>;
39-
4036
use ballista_core::error::BallistaError;
37+
use ballista_core::extension::EndpointOverrideFn;
4138
use ballista_core::serde::BallistaCodec;
4239
use ballista_core::serde::protobuf::{
4340
CancelTasksParams, CancelTasksResult, ExecutorMetric, ExecutorStatus,
@@ -276,7 +273,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
276273
let mut endpoint = create_grpc_client_endpoint(scheduler_url, None)?;
277274

278275
if let Some(ref override_fn) = self.override_create_grpc_client_endpoint {
279-
endpoint = override_fn(endpoint)?;
276+
endpoint = override_fn(endpoint).map_err(|e| {
277+
BallistaError::GrpcConnectionError(format!(
278+
"Failed to customize endpoint for scheduler {scheduler_id}: {e}"
279+
))
280+
})?;
280281
}
281282

282283
let connection = endpoint.connect().await?;

ballista/scheduler/src/config.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,12 @@
2727

2828
use crate::SessionBuilder;
2929
use crate::cluster::DistributionPolicy;
30+
use ballista_core::extension::EndpointOverrideFn;
3031
use ballista_core::{ConfigProducer, config::TaskSchedulingPolicy};
3132
use datafusion_proto::logical_plan::LogicalExtensionCodec;
3233
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
3334
use std::fmt::Display;
3435
use std::sync::Arc;
35-
use tonic::transport::{Endpoint, Error as TonicTransportError};
36-
37-
/// Type alias for the endpoint override function used in gRPC client configuration
38-
pub type EndpointOverrideFn =
39-
Arc<dyn Fn(Endpoint) -> Result<Endpoint, TonicTransportError> + Send + Sync>;
4036

4137
/// Command-line configuration for the scheduler binary.
4238
#[cfg(feature = "build-binary")]
@@ -247,6 +243,8 @@ pub struct SchedulerConfig {
247243
pub override_physical_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
248244
/// Override function for customizing gRPC client endpoints before they are used
249245
pub override_create_grpc_client_endpoint: Option<EndpointOverrideFn>,
246+
/// Whether to use TLS when connecting to executors (for flight proxy)
247+
pub use_tls: bool,
250248
}
251249

252250
impl Default for SchedulerConfig {
@@ -274,6 +272,7 @@ impl Default for SchedulerConfig {
274272
override_logical_codec: None,
275273
override_physical_codec: None,
276274
override_create_grpc_client_endpoint: None,
275+
use_tls: false,
277276
}
278277
}
279278
}
@@ -398,13 +397,17 @@ impl SchedulerConfig {
398397
/// This allows configuring TLS, timeouts, and other transport settings.
399398
pub fn with_override_create_grpc_client_endpoint(
400399
mut self,
401-
override_fn: Arc<
402-
dyn Fn(Endpoint) -> Result<Endpoint, TonicTransportError> + Send + Sync,
403-
>,
400+
override_fn: EndpointOverrideFn,
404401
) -> Self {
405402
self.override_create_grpc_client_endpoint = Some(override_fn);
406403
self
407404
}
405+
406+
/// Sets whether TLS should be used when connecting to executors (for flight proxy).
407+
pub fn with_use_tls(mut self, use_tls: bool) -> Self {
408+
self.use_tls = use_tls;
409+
self
410+
}
408411
}
409412

410413
/// Policy of distributing tasks to available executor slots
@@ -516,6 +519,7 @@ impl TryFrom<Config> for SchedulerConfig {
516519
override_physical_codec: None,
517520
override_session_builder: None,
518521
override_create_grpc_client_endpoint: None,
522+
use_tls: false,
519523
};
520524

521525
Ok(config)

ballista/scheduler/src/flight_proxy_service.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ use arrow_flight::{
2222
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
2323
};
2424
use ballista_core::error::BallistaError;
25+
use ballista_core::extension::BallistaConfigGrpcEndpoint;
2526
use ballista_core::serde::decode_protobuf;
2627
use ballista_core::serde::scheduler::Action as BallistaAction;
27-
use ballista_core::utils::{GrpcClientConfig, create_grpc_client_connection};
28+
use ballista_core::utils::{GrpcClientConfig, create_grpc_client_endpoint};
2829

2930
use futures::{Stream, TryFutureExt};
3031
use log::debug;
3132
use std::pin::Pin;
33+
use std::sync::Arc;
3234
use tonic::{Request, Response, Status, Streaming};
3335

3436
/// Service implementing a proxy from scheduler to executor Apache Arrow Flight Protocol
@@ -40,16 +42,24 @@ use tonic::{Request, Response, Status, Streaming};
4042
pub struct BallistaFlightProxyService {
4143
max_decoding_message_size: usize,
4244
max_encoding_message_size: usize,
45+
/// Whether to use TLS when connecting to executors
46+
use_tls: bool,
47+
/// Optional function to customize gRPC endpoint configuration (e.g., for TLS)
48+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
4349
}
4450

4551
impl BallistaFlightProxyService {
4652
pub fn new(
4753
max_decoding_message_size: usize,
4854
max_encoding_message_size: usize,
55+
use_tls: bool,
56+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
4957
) -> Self {
5058
Self {
5159
max_decoding_message_size,
5260
max_encoding_message_size,
61+
use_tls,
62+
customize_endpoint,
5363
}
5464
}
5565
}
@@ -120,6 +130,8 @@ impl FlightService for BallistaFlightProxyService {
120130
*port,
121131
self.max_decoding_message_size,
122132
self.max_encoding_message_size,
133+
self.use_tls,
134+
self.customize_endpoint.clone(),
123135
)
124136
.map_err(|e| from_ballista_err(&e))
125137
.await?;
@@ -169,16 +181,34 @@ async fn get_flight_client(
169181
port: u16,
170182
max_decoding_message_size: usize,
171183
max_encoding_message_size: usize,
184+
use_tls: bool,
185+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
172186
) -> Result<FlightServiceClient<tonic::transport::channel::Channel>, BallistaError> {
173-
let addr = format!("http://{host}:{port}");
187+
let scheme = if use_tls { "https" } else { "http" };
188+
let addr = format!("{scheme}://{host}:{port}");
174189
let grpc_config = GrpcClientConfig::default();
175-
let connection = create_grpc_client_connection(addr.clone(), &grpc_config)
176-
.await
190+
191+
let mut endpoint = create_grpc_client_endpoint(addr.clone(), Some(&grpc_config))
177192
.map_err(|e| {
178193
BallistaError::GrpcConnectionError(format!(
179-
"Error connecting to Ballista scheduler or executor at {addr}: {e:?}"
194+
"Error creating endpoint for Ballista executor at {addr}: {e:?}"
195+
))
196+
})?;
197+
198+
if let Some(ref customize) = customize_endpoint {
199+
endpoint = customize.configure_endpoint(endpoint).map_err(|e| {
200+
BallistaError::GrpcConnectionError(format!(
201+
"Error customizing endpoint for Ballista executor at {addr}: {e}"
180202
))
181203
})?;
204+
}
205+
206+
let connection = endpoint.connect().await.map_err(|e| {
207+
BallistaError::GrpcConnectionError(format!(
208+
"Error connecting to Ballista executor at {addr}: {e:?}"
209+
))
210+
})?;
211+
182212
let flight_client = FlightServiceClient::new(connection)
183213
.max_decoding_message_size(max_decoding_message_size)
184214
.max_encoding_message_size(max_encoding_message_size);

ballista/scheduler/src/scheduler_process.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::flight_proxy_service::BallistaFlightProxyService;
2020
use arrow_flight::flight_service_server::FlightServiceServer;
2121
use ballista_core::BALLISTA_VERSION;
2222
use ballista_core::error::BallistaError;
23+
use ballista_core::extension::BallistaConfigGrpcEndpoint;
2324
use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer;
2425
use ballista_core::serde::{
2526
BallistaCodec, BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec,
@@ -102,9 +103,17 @@ pub async fn start_grpc_service<
102103
match &config.advertise_flight_sql_endpoint {
103104
Some(proxy) if proxy.is_empty() => {
104105
info!("Adding embedded flight proxy service on scheduler");
106+
// Wrap the endpoint override function in BallistaConfigGrpcEndpoint
107+
let customize_endpoint = config
108+
.override_create_grpc_client_endpoint
109+
.clone()
110+
.map(|f| Arc::new(BallistaConfigGrpcEndpoint::new(f)));
111+
105112
let flight_proxy = FlightServiceServer::new(BallistaFlightProxyService::new(
106113
config.grpc_server_max_encoding_message_size as usize,
107114
config.grpc_server_max_decoding_message_size as usize,
115+
config.use_tls,
116+
customize_endpoint,
108117
))
109118
.max_decoding_message_size(
110119
config.grpc_server_max_decoding_message_size as usize,

ballista/scheduler/src/state/executor_manager.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,11 @@ impl ExecutorManager {
481481
if let Some(ref override_fn) =
482482
self.config.override_create_grpc_client_endpoint
483483
{
484-
endpoint = override_fn(endpoint)?;
484+
endpoint = override_fn(endpoint).map_err(|e| {
485+
BallistaError::GrpcConnectionError(format!(
486+
"Failed to customize endpoint for executor {executor_id}: {e}"
487+
))
488+
})?;
485489
}
486490

487491
let connection = endpoint.connect().await?;

0 commit comments

Comments
 (0)