Skip to content

Commit 53b6d88

Browse files
scheduler: route task cancellation through callback hook
1 parent 68afffb commit 53b6d88

2 files changed

Lines changed: 104 additions & 8 deletions

File tree

ballista/scheduler/src/config.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ pub type EndpointOverrideFn =
4949
/// rather than waiting for their next poll interval.
5050
pub type OnWorkAvailableFn = Arc<dyn Fn(&str) + Send + Sync>;
5151

52+
/// Callback invoked when running tasks should be cancelled on an executor.
53+
///
54+
/// Arguments are:
55+
/// - `executor_id`: the executor whose tasks should be cancelled
56+
/// - running tasks to cancel on that executor
57+
pub type OnCancelTasksFn =
58+
Arc<dyn Fn(&str, Vec<crate::state::execution_graph::RunningTaskInfo>) + Send + Sync>;
59+
5260
/// Command-line configuration for the scheduler binary.
5361
#[cfg(feature = "build-binary")]
5462
#[derive(clap::Parser, Debug)]
@@ -261,6 +269,8 @@ pub struct SchedulerConfig {
261269
/// Callback invoked when new work becomes available for executors.
262270
/// The string argument is a reason/description for debugging purposes.
263271
pub on_work_available: Option<OnWorkAvailableFn>,
272+
/// Callback invoked per executor to cancel running tasks; when set, it is used instead of the scheduler's cancel RPCs.
273+
pub on_cancel_tasks: Option<OnCancelTasksFn>,
264274
}
265275

266276
impl Default for SchedulerConfig {
@@ -290,6 +300,7 @@ impl Default for SchedulerConfig {
290300
override_create_grpc_client_endpoint: None,
291301
override_metrics_collector: None,
292302
on_work_available: None,
303+
on_cancel_tasks: None,
293304
}
294305
}
295306
}
@@ -542,6 +553,7 @@ impl TryFrom<Config> for SchedulerConfig {
542553
override_create_grpc_client_endpoint: None,
543554
override_metrics_collector: None,
544555
on_work_available: None,
556+
on_cancel_tasks: None,
545557
};
546558

547559
Ok(config)

ballista/scheduler/src/state/executor_manager.rs

Lines changed: 92 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,39 @@ impl ExecutorManager {
119119

120120
/// Cancels the specified running tasks: via the `on_cancel_tasks` callback when configured, otherwise by sending RPC requests to executors.
121121
pub async fn cancel_running_tasks(&self, tasks: Vec<RunningTaskInfo>) -> Result<()> {
122-
let mut tasks_to_cancel: HashMap<String, Vec<protobuf::RunningTaskInfo>> =
122+
let mut tasks_by_executor: HashMap<String, Vec<RunningTaskInfo>> =
123123
Default::default();
124124

125125
for task_info in tasks {
126-
let infos = tasks_to_cancel.entry(task_info.executor_id).or_default();
127-
infos.push(protobuf::RunningTaskInfo {
128-
task_id: task_info.task_id as u32,
129-
job_id: task_info.job_id,
130-
stage_id: task_info.stage_id as u32,
131-
partition_id: task_info.partition_id as u32,
132-
});
126+
tasks_by_executor
127+
.entry(task_info.executor_id.clone())
128+
.or_default()
129+
.push(task_info);
130+
}
131+
132+
if let Some(cancel_callback) = &self.config.on_cancel_tasks {
133+
for (executor_id, infos) in tasks_by_executor {
134+
cancel_callback(&executor_id, infos);
135+
}
136+
return Ok(());
137+
}
138+
139+
let mut tasks_to_cancel: HashMap<String, Vec<protobuf::RunningTaskInfo>> =
140+
Default::default();
141+
142+
for (executor_id, infos) in tasks_by_executor {
143+
tasks_to_cancel.insert(
144+
executor_id,
145+
infos
146+
.into_iter()
147+
.map(|task_info| protobuf::RunningTaskInfo {
148+
task_id: task_info.task_id as u32,
149+
job_id: task_info.job_id,
150+
stage_id: task_info.stage_id as u32,
151+
partition_id: task_info.partition_id as u32,
152+
})
153+
.collect(),
154+
);
133155
}
134156

135157
let executor_manager = self.clone();
@@ -485,3 +507,65 @@ impl ExecutorManager {
485507
Ok(())
486508
}
487509
}
510+
511+
#[cfg(test)]
512+
mod tests {
513+
use super::*;
514+
use crate::cluster::memory::InMemoryClusterState;
515+
516+
#[tokio::test]
517+
async fn cancel_running_tasks_uses_callback() {
518+
let captured: Arc<std::sync::Mutex<HashMap<String, Vec<RunningTaskInfo>>>> =
519+
Arc::new(std::sync::Mutex::new(HashMap::new()));
520+
let callback_capture = Arc::clone(&captured);
521+
522+
let config = SchedulerConfig {
523+
on_cancel_tasks: Some(Arc::new(move |executor_id, tasks| {
524+
callback_capture
525+
.lock()
526+
.expect("callback capture lock")
527+
.insert(executor_id.to_string(), tasks);
528+
})),
529+
..SchedulerConfig::default()
530+
};
531+
532+
let manager = ExecutorManager::new(
533+
Arc::new(InMemoryClusterState::default()),
534+
Arc::new(config),
535+
);
536+
537+
let tasks = vec![
538+
RunningTaskInfo {
539+
task_id: 1,
540+
job_id: "job-1".to_string(),
541+
stage_id: 1,
542+
partition_id: 0,
543+
executor_id: "executor-a".to_string(),
544+
},
545+
RunningTaskInfo {
546+
task_id: 2,
547+
job_id: "job-1".to_string(),
548+
stage_id: 1,
549+
partition_id: 1,
550+
executor_id: "executor-a".to_string(),
551+
},
552+
RunningTaskInfo {
553+
task_id: 3,
554+
job_id: "job-2".to_string(),
555+
stage_id: 2,
556+
partition_id: 0,
557+
executor_id: "executor-b".to_string(),
558+
},
559+
];
560+
561+
manager
562+
.cancel_running_tasks(tasks)
563+
.await
564+
.expect("cancel should succeed");
565+
566+
let captured = captured.lock().expect("capture lock");
567+
assert_eq!(captured.len(), 2);
568+
assert_eq!(captured.get("executor-a").map(std::vec::Vec::len), Some(2));
569+
assert_eq!(captured.get("executor-b").map(std::vec::Vec::len), Some(1));
570+
}
571+
}

0 commit comments

Comments
 (0)