Skip to content

Commit d41d2a2

Browse files
committed
resolve gemini comments
1 parent d0453fd commit d41d2a2

File tree

2 files changed

+113
-2
lines changed

2 files changed

+113
-2
lines changed

src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@
4141

4242

4343
# Type for the plan cache: hash -> (SavePlan, Metadata)
44-
# The SavePlan stored here is a "template" plan with empty tensor_data to avoid memory leaks.
44+
# The SavePlan stored here is a "template" plan with tensor_data set to None to avoid memory leaks.
4545
PlanCache = dict[int, tuple[SavePlan, torchdistsaver.Metadata]]
4646

47+
_MAX_PLAN_CACHE_SIZE = 16
48+
4749

4850
def _compute_plan_structure_hash(plan: SavePlan) -> int:
4951
"""Computes a hash of the plan structure (FQN, type, shape, dtype).
@@ -161,6 +163,12 @@ def generate_plan(
161163
cached_template_plan, global_metadata = cached_entry
162164
updated_local_plan = _rehydrate_plan(cached_template_plan, local_plan)
163165
_LOGGER.info("Plan cache HIT for hash %d. Skipping global planning (reduce_scatter).", plan_hash)
166+
167+
# Move to end to mark as recently used
168+
if plan_cache is not None:
169+
plan_cache.pop(plan_hash)
170+
plan_cache[plan_hash] = cached_entry
171+
164172
except ValueError:
165173
_LOGGER.warning(
166174
"Plan cache HIT for hash %d but rehydration failed. Falling back to global planning.",
@@ -200,6 +208,19 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]:
200208

201209
# Cache the result
202210
if plan_cache is not None:
211+
# TODO: Revisit this, ideally only one plan be cached in each training.
212+
# Check size and evict if needed (LRU policy)
213+
if len(plan_cache) >= _MAX_PLAN_CACHE_SIZE:
214+
# Remove the first item inserted (which will be the least recently used)
215+
# Note: Python 3.7+ dicts preserve insertion order.
216+
oldest_key = next(iter(plan_cache))
217+
plan_cache.pop(oldest_key)
218+
_LOGGER.debug(
219+
"Evicted oldest plan cache entry with hash %d. Cache size is now %d.",
220+
oldest_key,
221+
len(plan_cache),
222+
)
223+
203224
# Sanitize to avoid memory leaks
204225
sanitized_plan = _sanitize_plan_for_cache(updated_local_plan)
205226
plan_cache[plan_hash] = (sanitized_plan, global_metadata)

tests/adapter/pytorch/test_custom_state_dict_saver.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,101 @@ def test_generate_plan_fallback_on_rehydration_failure(
363363
dist_wrapper,
364364
plan_cache=plan_cache,
365365
)
366-
367366
# Then
368367
# Check that reduce_scatter WAS called (fallback occurred)
369368
assert mock_reduce_scatter.call_count == 1
370369

370+
def test_plan_cache_lru_behavior(self, mock_storage_writer, mock_save_planner, dist_wrapper, mocker):
371+
"""Tests that the plan cache respects the LRU policy."""
372+
# Given
373+
# Mock _MAX_PLAN_CACHE_SIZE to a small number for testing
374+
mocker.patch.object(custom_state_dict_saver, "_MAX_PLAN_CACHE_SIZE", 2)
375+
state_dict = {"model": "test"}
376+
global_metadata = Metadata(state_dict_metadata={})
377+
378+
mock_save_planner.create_global_plan.return_value = ([], global_metadata)
379+
mock_storage_writer.prepare_global_plan.return_value = []
380+
mock_save_planner.finish_plan.side_effect = lambda x: x
381+
mocker.patch.object(dist_wrapper, "broadcast_object", side_effect=lambda x: x)
382+
383+
# We need distinct ReduceScatter results for each call to distinguish them
384+
def reduce_scatter_side_effect(tag, local_fn, global_fn):
385+
return local_fn()
386+
387+
mocker.patch.object(dist_wrapper, "reduce_scatter", side_effect=reduce_scatter_side_effect)
388+
389+
plan_cache = {}
390+
391+
# 1. Insert Item A
392+
plan_a = SavePlan([WriteItem(index=MetadataIndex("A"), type=WriteItemType.TENSOR)])
393+
mock_save_planner.create_local_plan.return_value = plan_a
394+
mock_storage_writer.prepare_local_plan.return_value = plan_a
395+
396+
custom_state_dict_saver.generate_plan(
397+
CheckpointContainerId("/ckpt_a"),
398+
state_dict,
399+
mock_storage_writer,
400+
mock_save_planner,
401+
dist_wrapper,
402+
plan_cache,
403+
)
404+
assert len(plan_cache) == 1
405+
hash_a = custom_state_dict_saver._compute_plan_structure_hash(plan_a)
406+
407+
# 2. Insert Item B
408+
plan_b = SavePlan([WriteItem(index=MetadataIndex("B"), type=WriteItemType.TENSOR)])
409+
mock_save_planner.create_local_plan.return_value = plan_b
410+
mock_storage_writer.prepare_local_plan.return_value = plan_b
411+
412+
custom_state_dict_saver.generate_plan(
413+
CheckpointContainerId("/ckpt_b"),
414+
state_dict,
415+
mock_storage_writer,
416+
mock_save_planner,
417+
dist_wrapper,
418+
plan_cache,
419+
)
420+
assert len(plan_cache) == 2
421+
hash_b = custom_state_dict_saver._compute_plan_structure_hash(plan_b)
422+
423+
# 3. Access Item A (Mark as recently used)
424+
# We need to simulate a hit
425+
# The generate_plan logic computes hash based on local plan.
426+
# So we pass plan_a again.
427+
mock_save_planner.create_local_plan.return_value = plan_a
428+
mock_storage_writer.prepare_local_plan.return_value = plan_a
429+
430+
custom_state_dict_saver.generate_plan(
431+
CheckpointContainerId("/ckpt_a_2"),
432+
state_dict,
433+
mock_storage_writer,
434+
mock_save_planner,
435+
dist_wrapper,
436+
plan_cache,
437+
)
438+
# Verify A is now at the end (most recently used)
439+
keys = list(plan_cache.keys())
440+
assert keys[-1] == hash_a
441+
442+
# 4. Insert Item C (Should evict oldest, which is now B because A was accessed)
443+
plan_c = SavePlan([WriteItem(index=MetadataIndex("C"), type=WriteItemType.TENSOR)])
444+
mock_save_planner.create_local_plan.return_value = plan_c
445+
mock_storage_writer.prepare_local_plan.return_value = plan_c
446+
447+
custom_state_dict_saver.generate_plan(
448+
CheckpointContainerId("/ckpt_c"),
449+
state_dict,
450+
mock_storage_writer,
451+
mock_save_planner,
452+
dist_wrapper,
453+
plan_cache,
454+
)
455+
456+
assert len(plan_cache) == 2
457+
assert hash_a in plan_cache
458+
assert hash_b not in plan_cache # B should be evicted
459+
assert custom_state_dict_saver._compute_plan_structure_hash(plan_c) in plan_cache
460+
371461

372462
class TestWriteData:
373463
"""Tests for the write_data function."""

0 commit comments

Comments
 (0)