@@ -379,6 +379,8 @@ def _compute_retrieval_plan(
379379 ) -> Optional [dict [int , List [Tuple [int , str ]]]]:
380380 """Computes the retrieval plan.
381381
382+ The plan assumes an even number of ranks (accelerator processes) on each node in the training cluster.
383+
382384 Args:
383385 checkpoint: The checkpoint container ID.
384386 available_objects_by_rank: Map of rank to available objects on that rank.
@@ -540,15 +542,15 @@ def get_checkpoint_objects_by_rank(
540542 `CheckpointObjectId`s available on that node.
541543 """
542544 container_path = Path (checkpoint_container_id .data )
543- local_objects = []
545+ local_objects : List [ CheckpointObjectId ] = []
544546 if not container_path .is_dir ():
545547 _LOGGER .debug (
546548 "Checkpoint container path '%s' is not a directory. Returning empty list." ,
547549 container_path ,
548550 )
549551 else :
550552 for entry in os .listdir (container_path ):
551- local_objects .append (entry )
553+ local_objects .append (CheckpointObjectId . from_container ( checkpoint_container_id , entry ) )
552554
553555 local_objects .extend (self ._get_extra_local_objects (container_path ))
554556
@@ -560,11 +562,8 @@ def get_checkpoint_objects_by_rank(
560562 if all_objects_by_rank_paths :
561563 for rank , objects in enumerate (all_objects_by_rank_paths ):
562564 if objects :
563- # Convert filenames to full paths and then to CheckpointObjectId
564- full_paths = [str (container_path / obj ) for obj in objects ]
565- checkpoint_objects = [CheckpointObjectId (p ) for p in full_paths ]
566- result [rank ] = checkpoint_objects
567- for obj in checkpoint_objects :
565+ result [rank ] = objects
566+ for obj in objects :
568567 object_locations [obj .data ].append (rank )
569568 else :
570569 result [rank ] = []
@@ -628,13 +627,13 @@ def retrieve_checkpoint(
628627 return all (all_success_list )
629628
630629 def _get_extra_local_objects (self , container_path : Path ) -> List [str ]:
631- """Hook for subclasses to provide extra local objects that are available and relevant,
630+ """Hook for subclasses to provide extra local objects that are available and relevant,
632631 which may be needed by other hosts.
633- This can be used when additional objects beyond the standard checkpoint data are needed,
632+ This can be used when additional objects beyond the standard checkpoint data are needed,
634633 such as framework-specific context data.
635-
634+
636635 This should always be implemented alongside `_get_extra_needed_objects`.
637-
636+
638637 Returns:
639638 List of additional locally available objects.
640639 """
@@ -646,11 +645,11 @@ def _get_extra_needed_objects(
646645 available_objects_by_rank : dict [int , List [CheckpointObjectId ]],
647646 ) -> Set [str ]:
648647 """Hook for subclasses to provide extra needed objects on any given host (each local rank 0).
649- This can leverage `available_objects_by_rank` to determine the set of additional objects
648+ This can leverage `available_objects_by_rank` to determine the set of additional objects
650649 each host needs.
651-
650+
652651 This should always be implemented alongside `_get_extra_local_objects`.
653-
652+
654653 Returns:
655654 Set of extra needed objects on any given node.
656655 """
0 commit comments