Skip to content

Commit d7b0bcf

Browse files
Add shuffle locality metrics to ExecutorMetricsCollector, SchedulerMetricsCollector, and ShuffleReaderExec
- Add record_shuffle_read_local/remote methods to ExecutorMetricsCollector trait
- Add record_task_shuffle_affinity_hit/miss methods to SchedulerMetricsCollector trait
- Add ShuffleReadMetricsCallback trait in ballista-core for tracking local vs remote reads
- Instrument shuffle_reader.rs to call metrics callback during partition fetches
- Add SessionConfigExt methods to pass metrics callback via session config
1 parent 379cf36 commit d7b0bcf

5 files changed

Lines changed: 374 additions & 10 deletions

File tree

ballista/core/src/execution_plans/distributed_query.rs

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
use crate::client::BallistaClient;
1919
use crate::config::BallistaConfig;
2020
use crate::extension::{
21-
BallistaConfigGrpcEndpoint, BallistaGrpcMetadataInterceptor, SessionConfigExt,
21+
BallistaConfigGrpcEndpoint, BallistaGrpcMetadataInterceptor,
22+
ResultFetchMetricsCallback, SessionConfigExt,
2223
};
2324
use crate::serde::protobuf::SuccessfulJob;
2425
use crate::serde::protobuf::{
@@ -248,8 +249,6 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
248249
let metric_total_bytes =
249250
MetricBuilder::new(&self.metrics).counter("transferred_bytes", partition);
250251

251-
252-
253252
let interceptor = context.session_config().ballista_grpc_interceptor();
254253

255254
let customize_endpoint = context
@@ -258,23 +257,24 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
258257

259258
let use_tls = context.session_config().ballista_use_tls();
260259

260+
let result_fetch_callback = context
261+
.session_config()
262+
.ballista_result_fetch_metrics_callback();
261263

262264
let stream = futures::stream::once(
263265
execute_query(
264266
self.scheduler_url.clone(),
265267
self.session_id.clone(),
266268
query,
267-
268269
self.config.default_grpc_client_max_message_size(),
269270
GrpcClientConfig::from(&self.config),
270271
Arc::new(self.metrics.clone()),
271272
partition,
272-
273273
self.config.clone(),
274274
interceptor,
275275
customize_endpoint,
276276
use_tls,
277-
277+
result_fetch_callback,
278278
)
279279
.map_err(|e| ArrowError::ExternalError(Box::new(e))),
280280
)
@@ -320,7 +320,7 @@ async fn execute_query(
320320
grpc_interceptor: Arc<BallistaGrpcMetadataInterceptor>,
321321
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
322322
use_tls: bool,
323-
323+
result_fetch_callback: Option<Arc<dyn ResultFetchMetricsCallback>>,
324324
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send> {
325325
// Capture query submission time for total_query_time_ms
326326
let query_start_time = std::time::Instant::now();
@@ -450,12 +450,14 @@ async fn execute_query(
450450
// This could be added in a future enhancement by wrapping the stream.
451451

452452
let streams = partition_location.into_iter().map(move |partition| {
453+
let callback = result_fetch_callback.clone();
453454
let f = fetch_partition(
454455
partition,
455456
max_message_size,
456457
true,
457458
customize_endpoint.clone(),
458459
use_tls,
460+
callback,
459461
)
460462
.map_err(|e| ArrowError::ExternalError(Box::new(e)));
461463

@@ -474,13 +476,29 @@ async fn fetch_partition(
474476
flight_transport: bool,
475477
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
476478
use_tls: bool,
479+
metrics_callback: Option<Arc<dyn ResultFetchMetricsCallback>>,
477480
) -> Result<SendableRecordBatchStream> {
481+
let start_time = std::time::Instant::now();
482+
478483
let metadata = location.executor_meta.ok_or_else(|| {
479484
DataFusionError::Internal("Received empty executor metadata".to_owned())
480485
})?;
481486
let partition_id = location.partition_id.ok_or_else(|| {
482487
DataFusionError::Internal("Received empty partition id".to_owned())
483488
})?;
489+
490+
// Extract stats before consuming location
491+
let stats = location.partition_stats.as_ref();
492+
#[expect(clippy::cast_sign_loss)]
493+
let expected_bytes = stats.map(|s| s.num_bytes as u64).unwrap_or(0);
494+
#[expect(clippy::cast_sign_loss)]
495+
let expected_rows = stats.map(|s| s.num_rows as u64).unwrap_or(0);
496+
497+
let job_id = partition_id.job_id.clone();
498+
let stage_id = partition_id.stage_id as usize;
499+
let partition = partition_id.partition_id as usize;
500+
let executor_id = metadata.id.clone();
501+
484502
let host = metadata.host.as_str();
485503
let port = metadata.port as u16;
486504
let mut ballista_client = BallistaClient::try_new(
@@ -492,7 +510,8 @@ async fn fetch_partition(
492510
)
493511
.await
494512
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
495-
ballista_client
513+
514+
let stream = ballista_client
496515
.fetch_partition(
497516
&metadata.id,
498517
&partition_id.into(),
@@ -502,5 +521,21 @@ async fn fetch_partition(
502521
flight_transport,
503522
)
504523
.await
505-
.map_err(|e| DataFusionError::External(Box::new(e)))
524+
.map_err(|e| DataFusionError::External(Box::new(e)))?;
525+
526+
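// Record metrics once the stream is opened: sizes are the expected values from partition stats, and the duration excludes stream consumption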
527+
if let Some(callback) = metrics_callback {
528+
let duration_ms = start_time.elapsed().as_millis() as u64;
529+
callback.record_result_fetch(
530+
&job_id,
531+
stage_id,
532+
partition,
533+
&executor_id,
534+
expected_bytes,
535+
expected_rows,
536+
duration_ms,
537+
);
538+
}
539+
540+
Ok(stream)
506541
}

ballista/core/src/execution_plans/shuffle_reader.rs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ use std::sync::Arc;
2929
use std::task::{Context, Poll};
3030

3131
use crate::client::BallistaClient;
32-
use crate::extension::{BallistaConfigGrpcEndpoint, SessionConfigExt};
32+
use crate::extension::{
33+
BallistaConfigGrpcEndpoint, SessionConfigExt, ShuffleReadMetricsCallback,
34+
};
3335
use crate::serde::scheduler::{PartitionLocation, PartitionStats};
3436

3537
use datafusion::arrow::datatypes::SchemaRef;
@@ -164,6 +166,7 @@ impl ExecutionPlan for ShuffleReaderExec {
164166
let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
165167
let customize_endpoint = config.ballista_override_create_grpc_client_endpoint();
166168
let use_tls = config.ballista_use_tls();
169+
let metrics_callback = config.ballista_shuffle_read_metrics_callback();
167170

168171
if force_remote_read {
169172
debug!(
@@ -199,6 +202,7 @@ impl ExecutionPlan for ShuffleReaderExec {
199202
prefer_flight,
200203
customize_endpoint,
201204
use_tls,
205+
metrics_callback,
202206
);
203207

204208
let result = RecordBatchStreamAdapter::new(
@@ -396,6 +400,7 @@ fn send_fetch_partitions(
396400
flight_transport: bool,
397401
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
398402
use_tls: bool,
403+
metrics_callback: Option<Arc<dyn ShuffleReadMetricsCallback>>,
399404
) -> AbortableReceiverStream {
400405
let (response_sender, response_receiver) = mpsc::channel(max_request_num);
401406
let semaphore = Arc::new(Semaphore::new(max_request_num));
@@ -413,8 +418,10 @@ fn send_fetch_partitions(
413418
// keep local shuffle files reading in serial order for memory control.
414419
let response_sender_c = response_sender.clone();
415420
let customize_endpoint_c = customize_endpoint.clone();
421+
let metrics_callback_c = metrics_callback.clone();
416422
spawned_tasks.push(SpawnedTask::spawn(async move {
417423
for p in local_locations {
424+
let start_time = std::time::Instant::now();
418425
let r = PartitionReaderEnum::Local
419426
.fetch_partition(
420427
&p,
@@ -424,6 +431,25 @@ fn send_fetch_partitions(
424431
use_tls,
425432
)
426433
.await;
434+
435+
// Record local read metrics if callback is set and read succeeded
436+
if r.is_ok() {
437+
if let Some(ref callback) = metrics_callback_c {
438+
let duration_ms = start_time.elapsed().as_millis() as u64;
439+
let bytes = p.partition_stats.num_bytes().unwrap_or(0);
440+
let rows = p.partition_stats.num_rows().unwrap_or(0);
441+
callback.record_local_read(
442+
&p.partition_id.job_id,
443+
p.partition_id.stage_id,
444+
p.partition_id.partition_id,
445+
&p.executor_meta.id,
446+
bytes,
447+
rows,
448+
duration_ms,
449+
);
450+
}
451+
}
452+
427453
if let Err(e) = response_sender_c.send(r).await {
428454
error!("Fail to send response event to the channel due to {e}");
429455
}
@@ -434,9 +460,11 @@ fn send_fetch_partitions(
434460
let semaphore = semaphore.clone();
435461
let response_sender = response_sender.clone();
436462
let customize_endpoint_c = customize_endpoint.clone();
463+
let metrics_callback_c = metrics_callback.clone();
437464
spawned_tasks.push(SpawnedTask::spawn(async move {
438465
// Block if exceeds max request number.
439466
let permit = semaphore.acquire_owned().await.unwrap();
467+
let start_time = std::time::Instant::now();
440468
let r = PartitionReaderEnum::FlightRemote
441469
.fetch_partition(
442470
&p,
@@ -446,6 +474,25 @@ fn send_fetch_partitions(
446474
use_tls,
447475
)
448476
.await;
477+
478+
// Record remote read metrics if callback is set and read succeeded
479+
if r.is_ok() {
480+
if let Some(ref callback) = metrics_callback_c {
481+
let duration_ms = start_time.elapsed().as_millis() as u64;
482+
let bytes = p.partition_stats.num_bytes().unwrap_or(0);
483+
let rows = p.partition_stats.num_rows().unwrap_or(0);
484+
callback.record_remote_read(
485+
&p.partition_id.job_id,
486+
p.partition_id.stage_id,
487+
p.partition_id.partition_id,
488+
&p.executor_meta.id,
489+
bytes,
490+
rows,
491+
duration_ms,
492+
);
493+
}
494+
}
495+
449496
// Block if the channel buffer is full.
450497
if let Err(e) = response_sender.send(r).await {
451498
error!("Fail to send response event to the channel due to {e}");
@@ -992,6 +1039,7 @@ mod tests {
9921039
true,
9931040
None,
9941041
false,
1042+
None, // No metrics callback in tests
9951043
);
9961044

9971045
let stream = RecordBatchStreamAdapter::new(

0 commit comments

Comments
 (0)