Skip to content

Commit f064477

Browse files
Cluster RPC customizations to support TLS and custom headers
1 parent 76ca657 commit f064477

16 files changed

Lines changed: 1118 additions & 70 deletions

File tree

Cargo.lock

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ballista/core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ prost = { workspace = true }
6262
prost-types = { workspace = true }
6363
rand = { workspace = true }
6464
serde = { workspace = true, features = ["derive"] }
65-
tokio = { workspace = true }
65+
tokio = { workspace = true, features = ["rt-multi-thread"] }
6666
tokio-stream = { workspace = true, features = ["net"] }
6767
tonic = { workspace = true }
6868
tonic-prost = { workspace = true }

ballista/core/src/client.rs

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,11 @@ use datafusion::arrow::{
4444
use datafusion::error::DataFusionError;
4545
use datafusion::error::Result;
4646

47+
use crate::extension::BallistaConfigGrpcEndpoint;
4748
use crate::serde::protobuf;
48-
use crate::utils::{GrpcClientConfig, create_grpc_client_connection};
49+
50+
use crate::utils::create_grpc_client_endpoint;
51+
4952
use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
5053
use futures::{Stream, StreamExt};
5154
use log::{debug, warn};
@@ -69,17 +72,37 @@ impl BallistaClient {
6972
host: &str,
7073
port: u16,
7174
max_message_size: usize,
75+
use_tls: bool,
76+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
7277
) -> BResult<Self> {
73-
let addr = format!("http://{host}:{port}");
74-
let grpc_config = GrpcClientConfig::default();
78+
let scheme = if use_tls { "https" } else { "http" };
79+
80+
let addr = format!("{scheme}://{host}:{port}");
7581
debug!("BallistaClient connecting to {addr}");
76-
let connection = create_grpc_client_connection(addr.clone(), &grpc_config)
77-
.await
82+
83+
let mut endpoint = create_grpc_client_endpoint(addr.clone(), None)
7884
.map_err(|e| {
7985
BallistaError::GrpcConnectionError(format!(
80-
"Error connecting to Ballista scheduler or executor at {addr}: {e:?}"
86+
"Error creating endpoint to Ballista scheduler or executor at {addr}: {e:?}"
8187
))
8288
})?;
89+
90+
if let Some(customize) = customize_endpoint {
91+
endpoint = customize
92+
.configure_endpoint(endpoint)
93+
.map_err(|e| {
94+
BallistaError::GrpcConnectionError(format!(
95+
"Error creating endpoint to Ballista scheduler or executor at {addr}: {e:?}"
96+
))
97+
})?;
98+
}
99+
100+
let connection = endpoint.connect().await.map_err(|e| {
101+
BallistaError::GrpcConnectionError(format!(
102+
"Error connecting to Ballista scheduler or executor at {addr}: {e:?}"
103+
))
104+
})?;
105+
83106
let flight_client = FlightServiceClient::new(connection)
84107
.max_decoding_message_size(max_message_size)
85108
.max_encoding_message_size(max_message_size);

ballista/core/src/execution_plans/distributed_query.rs

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
use crate::client::BallistaClient;
1919
use crate::config::BallistaConfig;
20+
use crate::extension::{BallistaConfigGrpcEndpoint, SessionConfigExt};
2021
use crate::serde::protobuf::get_job_status_result::FlightProxy;
2122
use crate::serde::protobuf::{
2223
ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, KeyValuePair,
2324
PartitionLocation, execute_query_params::Query, execute_query_result, job_status,
2425
scheduler_grpc_client::SchedulerGrpcClient,
2526
};
2627
use crate::serde::protobuf::{ExecutorMetadata, SuccessfulJob};
27-
use crate::utils::{GrpcClientConfig, create_grpc_client_connection};
28+
use crate::utils::{GrpcClientConfig, create_grpc_client_endpoint};
2829
use datafusion::arrow::datatypes::SchemaRef;
2930
use datafusion::arrow::error::ArrowError;
3031
use datafusion::arrow::record_batch::RecordBatch;
@@ -40,6 +41,7 @@ use datafusion::physical_plan::{
4041
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
4142
SendableRecordBatchStream, Statistics,
4243
};
44+
use datafusion::prelude::SessionConfig;
4345
use datafusion_proto::logical_plan::{
4446
AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec,
4547
};
@@ -243,6 +245,8 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
243245
let metric_total_bytes =
244246
MetricBuilder::new(&self.metrics).counter("transferred_bytes", partition);
245247

248+
let session_config = context.session_config().clone();
249+
246250
let stream = futures::stream::once(
247251
execute_query(
248252
self.scheduler_url.clone(),
@@ -252,6 +256,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
252256
GrpcClientConfig::from(&self.config),
253257
Arc::new(self.metrics.clone()),
254258
partition,
259+
session_config,
255260
)
256261
.map_err(|e| ArrowError::ExternalError(Box::new(e))),
257262
)
@@ -283,6 +288,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
283288
}
284289
}
285290

291+
#[allow(clippy::too_many_arguments)]
286292
async fn execute_query(
287293
scheduler_url: String,
288294
session_id: String,
@@ -291,19 +297,39 @@ async fn execute_query(
291297
grpc_config: GrpcClientConfig,
292298
metrics: Arc<ExecutionPlanMetricsSet>,
293299
partition: usize,
300+
session_config: SessionConfig,
294301
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send> {
302+
let grpc_interceptor = session_config.ballista_grpc_interceptor();
303+
let customize_endpoint =
304+
session_config.ballista_override_create_grpc_client_endpoint();
305+
let use_tls = session_config.ballista_use_tls();
306+
295307
// Capture query submission time for total_query_time_ms
296308
let query_start_time = std::time::Instant::now();
297309

298310
info!("Connecting to Ballista scheduler at {scheduler_url}");
299311
// TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
300-
let connection = create_grpc_client_connection(scheduler_url.clone(), &grpc_config)
312+
let mut endpoint =
313+
create_grpc_client_endpoint(scheduler_url.clone(), Some(&grpc_config))
314+
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
315+
316+
if let Some(ref customize) = customize_endpoint {
317+
endpoint = customize
318+
.configure_endpoint(endpoint)
319+
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
320+
}
321+
322+
let connection = endpoint
323+
.connect()
301324
.await
302325
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
303326

304-
let mut scheduler = SchedulerGrpcClient::new(connection)
305-
.max_encoding_message_size(max_message_size)
306-
.max_decoding_message_size(max_message_size);
327+
let mut scheduler = SchedulerGrpcClient::with_interceptor(
328+
connection,
329+
grpc_interceptor.as_ref().clone(),
330+
)
331+
.max_encoding_message_size(max_message_size)
332+
.max_decoding_message_size(max_message_size);
307333

308334
let query_result = scheduler
309335
.execute_query(query)
@@ -414,6 +440,8 @@ async fn execute_query(
414440
true,
415441
scheduler_url.clone(),
416442
flight_proxy.clone(),
443+
customize_endpoint.clone(),
444+
use_tls,
417445
)
418446
.map_err(|e| ArrowError::ExternalError(Box::new(e)));
419447

@@ -477,6 +505,8 @@ async fn fetch_partition(
477505
flight_transport: bool,
478506
scheduler_url: String,
479507
flight_proxy: Option<FlightProxy>,
508+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
509+
use_tls: bool,
480510
) -> Result<SendableRecordBatchStream> {
481511
let metadata = location.executor_meta.ok_or_else(|| {
482512
DataFusionError::Internal("Received empty executor metadata".to_owned())
@@ -491,10 +521,15 @@ async fn fetch_partition(
491521
let (client_host, client_port) =
492522
get_client_host_port(&metadata, &scheduler_url, &flight_proxy)?;
493523

494-
let mut ballista_client =
495-
BallistaClient::try_new(client_host.as_str(), client_port, max_message_size)
496-
.await
497-
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
524+
let mut ballista_client = BallistaClient::try_new(
525+
client_host.as_str(),
526+
client_port,
527+
max_message_size,
528+
use_tls,
529+
customize_endpoint,
530+
)
531+
.await
532+
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
498533
ballista_client
499534
.fetch_partition(
500535
&metadata.id,

ballista/core/src/execution_plans/shuffle_reader.rs

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use crate::client::BallistaClient;
3333
use crate::execution_plans::sort_shuffle::{
3434
get_index_path, is_sort_shuffle_output, stream_sort_shuffle_partition,
3535
};
36-
use crate::extension::SessionConfigExt;
36+
use crate::extension::{BallistaConfigGrpcEndpoint, SessionConfigExt};
3737
use crate::serde::scheduler::{PartitionLocation, PartitionStats};
3838

3939
use datafusion::arrow::datatypes::SchemaRef;
@@ -169,6 +169,8 @@ impl ExecutionPlan for ShuffleReaderExec {
169169
let force_remote_read = config.ballista_shuffle_reader_force_remote_read();
170170
let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
171171
let batch_size = config.batch_size();
172+
let customize_endpoint = config.ballista_override_create_grpc_client_endpoint();
173+
let use_tls = config.ballista_use_tls();
172174

173175
if force_remote_read {
174176
debug!(
@@ -202,6 +204,8 @@ impl ExecutionPlan for ShuffleReaderExec {
202204
max_message_size,
203205
force_remote_read,
204206
prefer_flight,
207+
customize_endpoint,
208+
use_tls,
205209
);
206210

207211
let input_stream = Box::pin(RecordBatchStreamAdapter::new(
@@ -404,6 +408,8 @@ fn send_fetch_partitions(
404408
max_message_size: usize,
405409
force_remote_read: bool,
406410
flight_transport: bool,
411+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
412+
use_tls: bool,
407413
) -> AbortableReceiverStream {
408414
let (response_sender, response_receiver) = mpsc::channel(max_request_num);
409415
let semaphore = Arc::new(Semaphore::new(max_request_num));
@@ -420,10 +426,17 @@ fn send_fetch_partitions(
420426

421427
// Read local shuffle files serially to keep memory usage under control.
422428
let response_sender_c = response_sender.clone();
429+
let customize_endpoint_c = customize_endpoint.clone();
423430
spawned_tasks.push(SpawnedTask::spawn(async move {
424431
for p in local_locations {
425432
let r = PartitionReaderEnum::Local
426-
.fetch_partition(&p, max_message_size, flight_transport)
433+
.fetch_partition(
434+
&p,
435+
max_message_size,
436+
flight_transport,
437+
customize_endpoint_c.clone(),
438+
use_tls,
439+
)
427440
.await;
428441
if let Err(e) = response_sender_c.send(r).await {
429442
error!("Fail to send response event to the channel due to {e}");
@@ -434,11 +447,18 @@ fn send_fetch_partitions(
434447
for p in remote_locations.into_iter() {
435448
let semaphore = semaphore.clone();
436449
let response_sender = response_sender.clone();
450+
let customize_endpoint_c = customize_endpoint.clone();
437451
spawned_tasks.push(SpawnedTask::spawn(async move {
438452
// Block if the max request number is exceeded.
439453
let permit = semaphore.acquire_owned().await.unwrap();
440454
let r = PartitionReaderEnum::FlightRemote
441-
.fetch_partition(&p, max_message_size, flight_transport)
455+
.fetch_partition(
456+
&p,
457+
max_message_size,
458+
flight_transport,
459+
customize_endpoint_c,
460+
use_tls,
461+
)
442462
.await;
443463
// Block if the channel buffer is full.
444464
if let Err(e) = response_sender.send(r).await {
@@ -465,6 +485,8 @@ trait PartitionReader: Send + Sync + Clone {
465485
location: &PartitionLocation,
466486
max_message_size: usize,
467487
flight_transport: bool,
488+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
489+
use_tls: bool,
468490
) -> result::Result<SendableRecordBatchStream, BallistaError>;
469491
}
470492

@@ -484,10 +506,19 @@ impl PartitionReader for PartitionReaderEnum {
484506
location: &PartitionLocation,
485507
max_message_size: usize,
486508
flight_transport: bool,
509+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
510+
use_tls: bool,
487511
) -> result::Result<SendableRecordBatchStream, BallistaError> {
488512
match self {
489513
PartitionReaderEnum::FlightRemote => {
490-
fetch_partition_remote(location, max_message_size, flight_transport).await
514+
fetch_partition_remote(
515+
location,
516+
max_message_size,
517+
flight_transport,
518+
customize_endpoint,
519+
use_tls,
520+
)
521+
.await
491522
}
492523
PartitionReaderEnum::Local => fetch_partition_local(location).await,
493524
PartitionReaderEnum::ObjectStoreRemote => {
@@ -501,25 +532,33 @@ async fn fetch_partition_remote(
501532
location: &PartitionLocation,
502533
max_message_size: usize,
503534
flight_transport: bool,
535+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
536+
use_tls: bool,
504537
) -> result::Result<SendableRecordBatchStream, BallistaError> {
505538
let metadata = &location.executor_meta;
506539
let partition_id = &location.partition_id;
507540
// TODO for shuffle client connections, we should avoid creating new connections again and again.
508541
// And we should also avoid keeping too many connections alive for a long time.
509542
let host = metadata.host.as_str();
510543
let port = metadata.port;
511-
let mut ballista_client = BallistaClient::try_new(host, port, max_message_size)
512-
.await
513-
.map_err(|error| match error {
514-
// map grpc connection error to partition fetch error.
515-
BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed(
516-
metadata.id.clone(),
517-
partition_id.stage_id,
518-
partition_id.partition_id,
519-
msg,
520-
),
521-
other => other,
522-
})?;
544+
let mut ballista_client = BallistaClient::try_new(
545+
host,
546+
port,
547+
max_message_size,
548+
use_tls,
549+
customize_endpoint,
550+
)
551+
.await
552+
.map_err(|error| match error {
553+
// map grpc connection error to partition fetch error.
554+
BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed(
555+
metadata.id.clone(),
556+
partition_id.stage_id,
557+
partition_id.partition_id,
558+
msg,
559+
),
560+
other => other,
561+
})?;
523562

524563
ballista_client
525564
.fetch_partition(
@@ -1087,6 +1126,8 @@ mod tests {
10871126
4 * 1024 * 1024,
10881127
false,
10891128
true,
1129+
None,
1130+
false,
10901131
);
10911132

10921133
let stream = RecordBatchStreamAdapter::new(

0 commit comments

Comments
 (0)