@@ -212,6 +212,31 @@ impl InMemoryShuffleManager {
212212 log:: debug!( "Removed all shuffle partitions for job: {job_id}" ) ;
213213 }
214214
215+ /// Removes all partitions for a given stage within a job.
216+ ///
217+ /// This is called when a stage's output has been fully consumed by the next stage,
218+ /// allowing the memory to be reclaimed immediately rather than waiting for job completion.
219+ ///
220+ /// # Arguments
221+ /// * `job_id` - The job identifier
222+ /// * `stage_id` - The stage identifier
223+ ///
224+ /// # Returns
225+ /// The number of partitions that were removed
226+ pub fn remove_stage_partitions ( & self , job_id : & str , stage_id : usize ) -> usize {
227+ let prefix = format ! ( "{job_id}/{stage_id}/" ) ;
228+ let initial_count = self . partitions . len ( ) ;
229+ self . partitions . retain ( |k, _| !k. starts_with ( & prefix) ) ;
230+ let removed = initial_count - self . partitions . len ( ) ;
231+ log:: debug!(
232+ "Removed {} shuffle partitions for stage: {}/{}" ,
233+ removed,
234+ job_id,
235+ stage_id
236+ ) ;
237+ removed
238+ }
239+
215240 /// Returns the total number of partitions stored in memory.
216241 pub fn partition_count ( & self ) -> usize {
217242 self . partitions . len ( )
@@ -334,4 +359,133 @@ mod tests {
334359 let key = InMemoryShuffleManager :: hash_partition_key ( "job1" , 1 , 2 , 3 ) ;
335360 assert_eq ! ( key, "job1/1/2/data-3" ) ;
336361 }
362+
/// Removing one stage must evict exactly that stage's partitions and leave
/// the other stages of the same job untouched.
#[test]
fn test_remove_stage_partitions() {
    let shuffle = InMemoryShuffleManager::new();
    let sample = create_test_batch();
    let schema = sample.schema();

    // Populate stages 0, 1 and 2 of "job1" with four partitions each.
    for stage_id in 0..3 {
        for partition_id in 0..4 {
            shuffle.store_partition(
                InMemoryShuffleManager::partition_key("job1", stage_id, partition_id),
                ShufflePartitionData::new(schema.clone(), vec![sample.clone()]),
            );
        }
    }
    assert_eq!(shuffle.partition_count(), 12);

    // Evicting stage 1 reports four removals and shrinks the store accordingly.
    let evicted = shuffle.remove_stage_partitions("job1", 1);
    assert_eq!(evicted, 4);
    assert_eq!(shuffle.partition_count(), 8);

    // The neighbouring stages are still present.
    let stage0_key = InMemoryShuffleManager::partition_key("job1", 0, 0);
    let stage2_key = InMemoryShuffleManager::partition_key("job1", 2, 0);
    assert!(shuffle.contains_partition(&stage0_key));
    assert!(shuffle.contains_partition(&stage2_key));

    // The evicted stage can no longer be found.
    let stage1_key = InMemoryShuffleManager::partition_key("job1", 1, 0);
    assert!(!shuffle.contains_partition(&stage1_key));
}
397+
/// Stage removal is scoped to a single job: the same stage id under another
/// job must survive.
#[test]
fn test_remove_stage_partitions_different_jobs() {
    let shuffle = InMemoryShuffleManager::new();
    let sample = create_test_batch();
    let schema = sample.schema();

    // Two jobs, each holding three partitions for stage 1.
    let job_ids = ["job1", "job2"];
    for job_id in job_ids {
        for partition_id in 0..3 {
            shuffle.store_partition(
                InMemoryShuffleManager::partition_key(job_id, 1, partition_id),
                ShufflePartitionData::new(schema.clone(), vec![sample.clone()]),
            );
        }
    }
    assert_eq!(shuffle.partition_count(), 6);

    // Only job1's stage 1 is evicted.
    let evicted = shuffle.remove_stage_partitions("job1", 1);
    assert_eq!(evicted, 3);
    assert_eq!(shuffle.partition_count(), 3);

    // job2 keeps its stage 1 output.
    let surviving_key = InMemoryShuffleManager::partition_key("job2", 1, 0);
    assert!(shuffle.contains_partition(&surviving_key));
}
425+
/// `remove_partition` hands back the stored data once, then reports `None`
/// for the same key.
#[test]
fn test_remove_partition_returns_data() {
    let shuffle = InMemoryShuffleManager::new();
    let sample = create_test_batch();
    let schema = sample.schema();
    let payload = ShufflePartitionData::new(schema.clone(), vec![sample]);

    let key = InMemoryShuffleManager::partition_key("job1", 1, 0);
    shuffle.store_partition(key.clone(), payload);
    assert!(shuffle.contains_partition(&key));

    // First removal yields the stored partition with its row/batch counts intact.
    let first = shuffle.remove_partition(&key);
    assert!(first.is_some());
    let taken = first.unwrap();
    assert_eq!(taken.num_rows, 3);
    assert_eq!(taken.num_batches, 1);

    // The key is gone from the store afterwards.
    assert!(!shuffle.contains_partition(&key));

    // Removing the same key again finds nothing.
    let second = shuffle.remove_partition(&key);
    assert!(second.is_none());
}
452+
/// Memory accounting is non-zero while partitions are stored and returns to
/// zero once the job's partitions are removed.
#[test]
fn test_total_memory_usage() {
    let shuffle = InMemoryShuffleManager::new();
    let sample = create_test_batch();
    let schema = sample.schema();

    // Three partitions under the same job and stage.
    for partition_id in 0..3 {
        shuffle.store_partition(
            InMemoryShuffleManager::partition_key("job1", 1, partition_id),
            ShufflePartitionData::new(schema.clone(), vec![sample.clone()]),
        );
    }

    // Stored data must be reflected in the reported usage.
    let bytes_in_use = shuffle.total_memory_usage();
    assert!(bytes_in_use > 0);

    // Dropping the whole job brings the reported usage back to zero.
    shuffle.remove_job_partitions("job1");
    assert_eq!(shuffle.total_memory_usage(), 0);
}
474+
/// `clear` empties the store regardless of how many partitions it holds.
#[test]
fn test_clear() {
    let shuffle = InMemoryShuffleManager::new();
    let sample = create_test_batch();
    let schema = sample.schema();

    // Five partitions under the same job and stage.
    for partition_id in 0..5 {
        shuffle.store_partition(
            InMemoryShuffleManager::partition_key("job1", 1, partition_id),
            ShufflePartitionData::new(schema.clone(), vec![sample.clone()]),
        );
    }
    assert_eq!(shuffle.partition_count(), 5);

    shuffle.clear();
    assert_eq!(shuffle.partition_count(), 0);
}
337491}
0 commit comments