@@ -119,17 +119,39 @@ impl ExecutorManager {
119119
120120 /// Sends RPC requests to executors to cancel the specified running tasks.
121121 pub async fn cancel_running_tasks ( & self , tasks : Vec < RunningTaskInfo > ) -> Result < ( ) > {
122- let mut tasks_to_cancel : HashMap < String , Vec < protobuf :: RunningTaskInfo > > =
122+ let mut tasks_by_executor : HashMap < String , Vec < RunningTaskInfo > > =
123123 Default :: default ( ) ;
124124
125125 for task_info in tasks {
126- let infos = tasks_to_cancel. entry ( task_info. executor_id ) . or_default ( ) ;
127- infos. push ( protobuf:: RunningTaskInfo {
128- task_id : task_info. task_id as u32 ,
129- job_id : task_info. job_id ,
130- stage_id : task_info. stage_id as u32 ,
131- partition_id : task_info. partition_id as u32 ,
132- } ) ;
126+ tasks_by_executor
127+ . entry ( task_info. executor_id . clone ( ) )
128+ . or_default ( )
129+ . push ( task_info) ;
130+ }
131+
132+ if let Some ( cancel_callback) = & self . config . on_cancel_tasks {
133+ for ( executor_id, infos) in tasks_by_executor {
134+ cancel_callback ( & executor_id, infos) ;
135+ }
136+ return Ok ( ( ) ) ;
137+ }
138+
139+ let mut tasks_to_cancel: HashMap < String , Vec < protobuf:: RunningTaskInfo > > =
140+ Default :: default ( ) ;
141+
142+ for ( executor_id, infos) in tasks_by_executor {
143+ tasks_to_cancel. insert (
144+ executor_id,
145+ infos
146+ . into_iter ( )
147+ . map ( |task_info| protobuf:: RunningTaskInfo {
148+ task_id : task_info. task_id as u32 ,
149+ job_id : task_info. job_id ,
150+ stage_id : task_info. stage_id as u32 ,
151+ partition_id : task_info. partition_id as u32 ,
152+ } )
153+ . collect ( ) ,
154+ ) ;
133155 }
134156
135157 let executor_manager = self . clone ( ) ;
@@ -485,3 +507,65 @@ impl ExecutorManager {
485507 Ok ( ( ) )
486508 }
487509}
510+
511+ #[ cfg( test) ]
512+ mod tests {
513+ use super :: * ;
514+ use crate :: cluster:: memory:: InMemoryClusterState ;
515+
516+ #[ tokio:: test]
517+ async fn cancel_running_tasks_uses_callback ( ) {
518+ let captured: Arc < std:: sync:: Mutex < HashMap < String , Vec < RunningTaskInfo > > > > =
519+ Arc :: new ( std:: sync:: Mutex :: new ( HashMap :: new ( ) ) ) ;
520+ let callback_capture = Arc :: clone ( & captured) ;
521+
522+ let config = SchedulerConfig {
523+ on_cancel_tasks : Some ( Arc :: new ( move |executor_id, tasks| {
524+ callback_capture
525+ . lock ( )
526+ . expect ( "callback capture lock" )
527+ . insert ( executor_id. to_string ( ) , tasks) ;
528+ } ) ) ,
529+ ..SchedulerConfig :: default ( )
530+ } ;
531+
532+ let manager = ExecutorManager :: new (
533+ Arc :: new ( InMemoryClusterState :: default ( ) ) ,
534+ Arc :: new ( config) ,
535+ ) ;
536+
537+ let tasks = vec ! [
538+ RunningTaskInfo {
539+ task_id: 1 ,
540+ job_id: "job-1" . to_string( ) ,
541+ stage_id: 1 ,
542+ partition_id: 0 ,
543+ executor_id: "executor-a" . to_string( ) ,
544+ } ,
545+ RunningTaskInfo {
546+ task_id: 2 ,
547+ job_id: "job-1" . to_string( ) ,
548+ stage_id: 1 ,
549+ partition_id: 1 ,
550+ executor_id: "executor-a" . to_string( ) ,
551+ } ,
552+ RunningTaskInfo {
553+ task_id: 3 ,
554+ job_id: "job-2" . to_string( ) ,
555+ stage_id: 2 ,
556+ partition_id: 0 ,
557+ executor_id: "executor-b" . to_string( ) ,
558+ } ,
559+ ] ;
560+
561+ manager
562+ . cancel_running_tasks ( tasks)
563+ . await
564+ . expect ( "cancel should succeed" ) ;
565+
566+ let captured = captured. lock ( ) . expect ( "capture lock" ) ;
567+ assert_eq ! ( captured. len( ) , 2 ) ;
568+ assert_eq ! ( captured. get( "executor-a" ) . map( std:: vec:: Vec :: len) , Some ( 2 ) ) ;
569+ assert_eq ! ( captured. get( "executor-b" ) . map( std:: vec:: Vec :: len) , Some ( 1 ) ) ;
570+ }
571+ }
0 commit comments