@@ -363,11 +363,101 @@ def test_generate_plan_fallback_on_rehydration_failure(
363363 dist_wrapper ,
364364 plan_cache = plan_cache ,
365365 )
366-
367366 # Then
368367 # Check that reduce_scatter WAS called (fallback occurred)
369368 assert mock_reduce_scatter .call_count == 1
370369
370+ def test_plan_cache_lru_behavior (self , mock_storage_writer , mock_save_planner , dist_wrapper , mocker ):
371+ """Tests that the plan cache respects the LRU policy."""
372+ # Given
373+ # Mock _MAX_PLAN_CACHE_SIZE to a small number for testing
374+ mocker .patch .object (custom_state_dict_saver , "_MAX_PLAN_CACHE_SIZE" , 2 )
375+ state_dict = {"model" : "test" }
376+ global_metadata = Metadata (state_dict_metadata = {})
377+
378+ mock_save_planner .create_global_plan .return_value = ([], global_metadata )
379+ mock_storage_writer .prepare_global_plan .return_value = []
380+ mock_save_planner .finish_plan .side_effect = lambda x : x
381+ mocker .patch .object (dist_wrapper , "broadcast_object" , side_effect = lambda x : x )
382+
383+ # We need distinct ReduceScatter results for each call to distinguish them
384+ def reduce_scatter_side_effect (tag , local_fn , global_fn ):
385+ return local_fn ()
386+
387+ mocker .patch .object (dist_wrapper , "reduce_scatter" , side_effect = reduce_scatter_side_effect )
388+
389+ plan_cache = {}
390+
391+ # 1. Insert Item A
392+ plan_a = SavePlan ([WriteItem (index = MetadataIndex ("A" ), type = WriteItemType .TENSOR )])
393+ mock_save_planner .create_local_plan .return_value = plan_a
394+ mock_storage_writer .prepare_local_plan .return_value = plan_a
395+
396+ custom_state_dict_saver .generate_plan (
397+ CheckpointContainerId ("/ckpt_a" ),
398+ state_dict ,
399+ mock_storage_writer ,
400+ mock_save_planner ,
401+ dist_wrapper ,
402+ plan_cache ,
403+ )
404+ assert len (plan_cache ) == 1
405+ hash_a = custom_state_dict_saver ._compute_plan_structure_hash (plan_a )
406+
407+ # 2. Insert Item B
408+ plan_b = SavePlan ([WriteItem (index = MetadataIndex ("B" ), type = WriteItemType .TENSOR )])
409+ mock_save_planner .create_local_plan .return_value = plan_b
410+ mock_storage_writer .prepare_local_plan .return_value = plan_b
411+
412+ custom_state_dict_saver .generate_plan (
413+ CheckpointContainerId ("/ckpt_b" ),
414+ state_dict ,
415+ mock_storage_writer ,
416+ mock_save_planner ,
417+ dist_wrapper ,
418+ plan_cache ,
419+ )
420+ assert len (plan_cache ) == 2
421+ hash_b = custom_state_dict_saver ._compute_plan_structure_hash (plan_b )
422+
423+ # 3. Access Item A (Mark as recently used)
424+ # We need to simulate a hit
425+ # The generate_plan logic computes hash based on local plan.
426+ # So we pass plan_a again.
427+ mock_save_planner .create_local_plan .return_value = plan_a
428+ mock_storage_writer .prepare_local_plan .return_value = plan_a
429+
430+ custom_state_dict_saver .generate_plan (
431+ CheckpointContainerId ("/ckpt_a_2" ),
432+ state_dict ,
433+ mock_storage_writer ,
434+ mock_save_planner ,
435+ dist_wrapper ,
436+ plan_cache ,
437+ )
438+ # Verify A is now at the end (most recently used)
439+ keys = list (plan_cache .keys ())
440+ assert keys [- 1 ] == hash_a
441+
442+ # 4. Insert Item C (Should evict oldest, which is now B because A was accessed)
443+ plan_c = SavePlan ([WriteItem (index = MetadataIndex ("C" ), type = WriteItemType .TENSOR )])
444+ mock_save_planner .create_local_plan .return_value = plan_c
445+ mock_storage_writer .prepare_local_plan .return_value = plan_c
446+
447+ custom_state_dict_saver .generate_plan (
448+ CheckpointContainerId ("/ckpt_c" ),
449+ state_dict ,
450+ mock_storage_writer ,
451+ mock_save_planner ,
452+ dist_wrapper ,
453+ plan_cache ,
454+ )
455+
456+ assert len (plan_cache ) == 2
457+ assert hash_a in plan_cache
458+ assert hash_b not in plan_cache # B should be evicted
459+ assert custom_state_dict_saver ._compute_plan_structure_hash (plan_c ) in plan_cache
460+
371461
372462class TestWriteData :
373463 """Tests for the write_data function."""
0 commit comments