2222import struct
2323from collections import defaultdict
2424from pathlib import Path
25- from typing import IO , List , Optional , Set , Tuple , TypeVar , cast
25+ from typing import IO , Callable , List , Optional , Set , Tuple , TypeVar , cast
2626
2727import torch
2828import torch .distributed as dist
@@ -128,6 +128,12 @@ def __init__(
128128 self ,
129129 checkpoint_object_manager : CheckpointObjectManager ,
130130 replication_manager : ReplicationManager ,
131+ * ,
132+ global_rank_getter : Callable [[], int ],
133+ local_rank_getter : Callable [[], int ],
134+ broadcast_object_list_func : Callable [..., None ],
135+ all_gather_object_func : Callable [..., None ],
136+ world_size_getter : Callable [[], int ],
131137 ):
132138 """Initializes the DefaultMLFlashpointCheckpointLoader.
133139
@@ -136,9 +142,21 @@ def __init__(
136142 reading data.
137143 replication_manager: The replication manager to use for retrieving
138144 missing checkpoint objects from peer nodes.
145+ global_rank_getter: A callable that returns the global rank.
146+ local_rank_getter: A callable that returns the node-local rank.
147+ broadcast_object_list_func: A callable with the same signature as
148+ ``torch.distributed.broadcast_object_list``.
149+ all_gather_object_func: A callable with the same signature as
150+ ``torch.distributed.all_gather_object``.
151+ world_size_getter: A callable that returns the world size.
139152 """
140153 self ._checkpoint_object_manager = checkpoint_object_manager
141154 self ._replication_manager = replication_manager
155+ self ._global_rank_getter = global_rank_getter
156+ self ._local_rank_getter = local_rank_getter
157+ self ._broadcast_object_list_func = broadcast_object_list_func
158+ self ._all_gather_object_func = all_gather_object_func
159+ self ._world_size_getter = world_size_getter
142160 # Cache for available objects: CheckpointContainerId -> dict[object_path, list[rank]]
143161 self ._available_objects_cache : dict [CheckpointContainerId , dict [str , List [int ]]] = {}
144162
@@ -337,8 +355,7 @@ def get_latest_complete_checkpoint(
337355 else continue to the next candidate checkpoint
338356 - return the checkpoint container id of the latest complete checkpoint
339357 """
340- # TODO: use global_rank_getter and local_rank_getter.
341- rank = dist .get_rank ()
358+ rank = self ._global_rank_getter ()
342359 _LOGGER .debug (
343360 "Rank %s: Getting latest complete checkpoint for '%s'" ,
344361 rank ,
@@ -382,7 +399,7 @@ def get_latest_complete_checkpoint(
382399 retrieval_plan = self ._compute_retrieval_plan (checkpoint , available_objects_by_rank )
383400 # Broadcast the retrieval plan to all ranks.
384401 plan_container = [retrieval_plan ]
385- dist . broadcast_object_list (plan_container , src = planner_rank )
402+ self . _broadcast_object_list_func (plan_container , src = planner_rank )
386403 retrieval_plan = plan_container [0 ]
387404
388405 if retrieval_plan is None :
@@ -451,7 +468,7 @@ def _compute_retrieval_plan(
451468
452469 objects_needed_by_local_rank_0 .update (self ._get_extra_needed_objects (checkpoint , available_objects_by_rank ))
453470
454- world_size = dist . get_world_size ()
471+ world_size = self . _world_size_getter ()
455472 num_nodes = get_num_of_nodes ()
456473 ranks_per_node = world_size // num_nodes
457474
@@ -507,8 +524,8 @@ def get_candidate_checkpoints(
507524
508525 # Scan locally only on the first rank of each node
509526 base_path = Path (checkpoint_base_container .data )
510- rank = dist . get_rank ()
511- local_rank = dist . get_node_local_rank ()
527+ rank = self . _global_rank_getter ()
528+ local_rank = self . _local_rank_getter ()
512529
513530 local_candidate_ckpt_ids = []
514531
@@ -532,8 +549,8 @@ def get_candidate_checkpoints(
532549 else :
533550 _LOGGER .debug ("Rank %s: Base path '%s' is not a directory or does not exist." , rank , base_path )
534551
535- all_checkpoint_container_path_lists = [None for _ in range (dist . get_world_size ())]
536- dist . all_gather_object (all_checkpoint_container_path_lists , local_candidate_ckpt_ids )
552+ all_checkpoint_container_path_lists = [None for _ in range (self . _world_size_getter ())]
553+ self . _all_gather_object_func (all_checkpoint_container_path_lists , local_candidate_ckpt_ids )
537554 _LOGGER .debug (
538555 "Rank %s: Gathered checkpoint container paths from all ranks: '%s'" ,
539556 rank ,
@@ -589,8 +606,8 @@ def get_checkpoint_objects_by_rank(
589606
590607 local_objects .extend (self ._get_extra_local_objects (container_path ))
591608
592- all_objects_by_rank_paths = [None for _ in range (dist . get_world_size ())]
593- dist . all_gather_object (all_objects_by_rank_paths , local_objects )
609+ all_objects_by_rank_paths = [None for _ in range (self . _world_size_getter ())]
610+ self . _all_gather_object_func (all_objects_by_rank_paths , local_objects )
594611
595612 result = {}
596613 object_locations = defaultdict (list )
@@ -620,7 +637,7 @@ def retrieve_checkpoint(
620637 If empty for this rank, no retrieval is needed.
621638 """
622639
623- rank = dist . get_rank ()
640+ rank = self . _global_rank_getter ()
624641 all_success = True
625642
626643 # Only proceed with retrieval if we have items to retrieve
@@ -656,8 +673,8 @@ def retrieve_checkpoint(
656673
657674 # Gather success status from all ranks
658675 _LOGGER .debug ("Gathering success status from all ranks" )
659- all_success_list = [None for _ in range (dist . get_world_size ())]
660- dist . all_gather_object (all_success_list , all_success )
676+ all_success_list = [None for _ in range (self . _world_size_getter ())]
677+ self . _all_gather_object_func (all_success_list , all_success )
661678 _LOGGER .debug ("All success list: '%s'" , all_success_list )
662679 return all (all_success_list )
663680
0 commit comments