Skip to content

Commit ade1ba8

Browse files
committed
add sessionconfigext customization for endpoint create
1 parent 44e768e commit ade1ba8

2 files changed

Lines changed: 88 additions & 6 deletions

File tree

ballista/core/src/execution_plans/distributed_query.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
use crate::client::BallistaClient;
1919
use crate::config::BallistaConfig;
20-
use crate::extension::BallistaGrpcMetadataInterceptor;
20+
use crate::extension::{
21+
BallistaConfigGrpcEndpoint, BallistaGrpcMetadataInterceptor, SessionConfigExt,
22+
};
2123
use crate::serde::protobuf::SuccessfulJob;
2224
use crate::serde::protobuf::{
2325
execute_query_params::Query, execute_query_result, job_status,
@@ -50,6 +52,7 @@ use std::fmt::Debug;
5052
use std::marker::PhantomData;
5153
use std::sync::Arc;
5254
use std::time::Duration;
55+
use tonic::transport::Endpoint;
5356

5457
/// This operator sends a logical plan to a Ballista scheduler for execution and
5558
/// polls the scheduler until the query is complete and then fetches the resulting
@@ -234,10 +237,11 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
234237
let metric_total_bytes =
235238
MetricBuilder::new(&self.metrics).counter("transferred_bytes", partition);
236239

237-
let interceptor = context
240+
let interceptor = context.session_config().ballista_grpc_interceptor();
241+
242+
let customize_endpoint = context
238243
.session_config()
239-
.get_extension::<BallistaGrpcMetadataInterceptor>()
240-
.unwrap_or_default();
244+
.ballista_override_create_grpc_client_endpoint();
241245

242246
let stream = futures::stream::once(
243247
execute_query(
@@ -246,6 +250,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
246250
query,
247251
self.config.clone(),
248252
interceptor,
253+
customize_endpoint,
249254
)
250255
.map_err(|e| ArrowError::ExternalError(Box::new(e))),
251256
)
@@ -283,11 +288,20 @@ async fn execute_query(
283288
query: ExecuteQueryParams,
284289
config: BallistaConfig,
285290
grpc_interceptor: Arc<BallistaGrpcMetadataInterceptor>,
291+
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
286292
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send> {
287293
info!("Connecting to Ballista scheduler at {scheduler_url}");
288294
// TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
289-
let connection = create_grpc_client_endpoint(scheduler_url)
290-
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
295+
let mut endpoint = create_grpc_client_endpoint(scheduler_url)
296+
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
297+
298+
if let Some(customize) = customize_endpoint {
299+
endpoint = customize
300+
.configure_endpoint(endpoint)
301+
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
302+
}
303+
304+
let connection = endpoint
291305
.connect()
292306
.await
293307
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

ballista/core/src/extension.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ use datafusion_proto::logical_plan::LogicalExtensionCodec;
3030
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
3131
use datafusion_proto::protobuf::LogicalPlanNode;
3232
use std::collections::HashMap;
33+
use std::error::Error;
3334
use std::sync::Arc;
3435
use tonic::codegen::http::HeaderName;
3536
use tonic::metadata::MetadataMap;
3637
use tonic::service::Interceptor;
38+
use tonic::transport::Endpoint;
3739
use tonic::{Request, Status};
3840

3941
/// Provides methods which adapt [SessionState]
@@ -146,7 +148,24 @@ pub trait SessionConfigExt {
146148
prefer_flight: bool,
147149
) -> Self;
148150

151+
/// Set user defined metadata keys in Ballista gRPC requests
149152
fn with_ballista_grpc_metadata(self, metadata: HashMap<String, String>) -> Self;
153+
154+
/// Get a `tonic` interceptor configured to decorate the provided metadata keys
155+
fn ballista_grpc_interceptor(&self) -> Arc<BallistaGrpcMetadataInterceptor>;
156+
157+
fn with_ballista_override_create_grpc_client_endpoint(
158+
self,
159+
override_f: Arc<
160+
dyn Fn(Endpoint) -> Result<Endpoint, Box<dyn Error + Send + Sync>>
161+
+ Send
162+
+ Sync,
163+
>,
164+
) -> Self;
165+
166+
fn ballista_override_create_grpc_client_endpoint(
167+
&self,
168+
) -> Option<Arc<BallistaConfigGrpcEndpoint>>;
150169
}
151170

152171
/// [SessionConfigHelperExt] is set of [SessionConfig] extension methods
@@ -398,6 +417,29 @@ impl SessionConfigExt for SessionConfig {
398417
let extension = BallistaGrpcMetadataInterceptor::new(metadata);
399418
self.with_extension(Arc::new(extension))
400419
}
420+
421+
fn ballista_grpc_interceptor(&self) -> Arc<BallistaGrpcMetadataInterceptor> {
422+
self.get_extension::<BallistaGrpcMetadataInterceptor>()
423+
.unwrap_or_default()
424+
}
425+
426+
fn with_ballista_override_create_grpc_client_endpoint(
427+
self,
428+
override_f: Arc<
429+
dyn Fn(Endpoint) -> Result<Endpoint, Box<dyn Error + Send + Sync>>
430+
+ Send
431+
+ Sync,
432+
>,
433+
) -> Self {
434+
let extension = BallistaConfigGrpcEndpoint::new(override_f);
435+
self.with_extension(Arc::new(extension))
436+
}
437+
438+
fn ballista_override_create_grpc_client_endpoint(
439+
&self,
440+
) -> Option<Arc<BallistaConfigGrpcEndpoint>> {
441+
self.get_extension::<BallistaConfigGrpcEndpoint>()
442+
}
401443
}
402444

403445
impl SessionConfigHelperExt for SessionConfig {
@@ -570,6 +612,32 @@ impl Interceptor for BallistaGrpcMetadataInterceptor {
570612
}
571613
}
572614

615+
#[derive(Clone)]
616+
pub struct BallistaConfigGrpcEndpoint {
617+
override_f: Arc<
618+
dyn Fn(Endpoint) -> Result<Endpoint, Box<dyn Error + Send + Sync>> + Send + Sync,
619+
>,
620+
}
621+
622+
impl BallistaConfigGrpcEndpoint {
623+
pub fn new(
624+
override_f: Arc<
625+
dyn Fn(Endpoint) -> Result<Endpoint, Box<dyn Error + Send + Sync>>
626+
+ Send
627+
+ Sync,
628+
>,
629+
) -> Self {
630+
Self { override_f }
631+
}
632+
633+
pub fn configure_endpoint(
634+
&self,
635+
endpoint: Endpoint,
636+
) -> Result<Endpoint, Box<dyn Error + Send + Sync>> {
637+
(self.override_f)(endpoint)
638+
}
639+
}
640+
573641
#[cfg(test)]
574642
mod test {
575643
use datafusion::{

0 commit comments

Comments
 (0)