@@ -233,6 +233,10 @@ class RendezvousSettings:
233233 If set to True (default), nodes from the redundancy list and new arrivals are migrated
234234 to the wait list. If set to False, new arrivals will be moved to the redundancy list
235235 and will wait there until the next rendezvous round.
236+ use_infra_group_rank:
237+ Whether to use infrastructure group rank for rank assignment instead of sorted
238+ participant-based assignment. If True, ranks are read from SLURM_PROCID (in SLURM
239+ environments) or GROUP_RANK (set by launcher) environment variables.
236240 """
237241
238242 run_id : str
@@ -242,6 +246,7 @@ class RendezvousSettings:
242246 keep_alive_interval : timedelta
243247 keep_alive_max_attempt : int
244248 upscaling_enabled : bool = True
249+ use_infra_group_rank : bool = True
245250
246251
247252@dataclass (eq = True , order = True , frozen = True )
@@ -789,8 +794,22 @@ def _add_to_participants(self) -> None:
789794 log .debug (f"Node { self ._node } was not in the wait list." )
790795
791796 # The ranks of the participants will be set once the rendezvous is
792- # complete.
793- state .participants [self ._node ] = 0
797+ # complete. If use_infra_group_rank is enabled, store the infrastructure
798+ # rank (SLURM_PROCID or GROUP_RANK) here; otherwise, use placeholder 0.
799+ if self ._settings .use_infra_group_rank :
800+ # Try SLURM_PROCID first (set by SLURM), then fall back to GROUP_RANK (set by launcher)
801+ infra_rank_str = os .getenv ('SLURM_PROCID' , os .getenv ('GROUP_RANK' , '-1' ))
802+ infra_rank = int (infra_rank_str )
803+ if infra_rank < 0 :
804+ raise ValueError (
805+ "use_infra_group_rank is enabled but neither SLURM_PROCID nor GROUP_RANK "
806+ "environment variable is set. Please set one of these environment variables "
807+ "or disable use_infra_group_rank."
808+ )
809+ state .participants [self ._node ] = infra_rank
810+ log .debug (f"Node { self ._node } stored infrastructure rank { infra_rank } from environment" )
811+ else :
812+ state .participants [self ._node ] = 0
794813
795814 self ._keep_alive ()
796815
@@ -874,16 +893,61 @@ def _remove_from_redundancy_list(self) -> None:
874893
875894 @staticmethod
876895 def _assign_ranks (
877- participants : Dict [_NodeDesc , int ], prev : Dict [_NodeDesc , int ]
896+ participants : Dict [_NodeDesc , int ],
897+ prev : Dict [_NodeDesc , int ],
898+ use_infra_group_rank : bool = False ,
878899 ) -> Dict [_NodeDesc , int ]:
879- # Assign ranks. Re-use assigment from the previous round as much as possible
900+ """
901+ Assign ranks to participants in the rendezvous.
902+
903+ Behavior depends on use_infra_group_rank:
904+
905+ 1. If use_infra_group_rank=True:
906+ - ALWAYS use the infrastructure ranks already stored in ``participants`` (read from SLURM_PROCID or GROUP_RANK when each node joined)
907+ - Previous assignments are ignored
908+ - Validates that all ranks are in range [0, world_size) and unique
909+ - Ensures consistency with infrastructure's rank assignment
910+ - Note: Hot spare/redundancy is NOT supported in this mode as dynamic
911+ rendezvous cannot guarantee lower ranks join as participants first
912+
913+ 2. If use_infra_group_rank=False:
914+ - Use deterministic assignment, preserving previous ranks when possible
915+ - Fill gaps left by failed nodes with new participants
916+
917+ Args:
918+ participants: Dict mapping node descriptors to infrastructure ranks (or the placeholder 0 when use_infra_group_rank is False)
919+ prev: Dict of previous rank assignments (empty on first rendezvous)
920+ use_infra_group_rank: If True, always use infrastructure ranks
921+
922+ Returns:
923+ Dict mapping node descriptors to assigned ranks
924+ """
925+ # If use_infra_group_rank is enabled, use the infrastructure ranks directly
926+ if use_infra_group_rank :
927+ # Validate that all participants have valid infrastructure ranks
928+ for node , rank in participants .items ():
929+ if rank < 0 or rank >= len (participants ):
930+ raise ValueError (
931+ f"Invalid infrastructure rank { rank } for node { node } . "
932+ f"Expected rank in range [0, { len (participants )} )"
933+ )
934+ # Check for duplicate ranks
935+ ranks_set = set (participants .values ())
936+ if len (ranks_set ) != len (participants ):
937+ raise ValueError (
938+ f"Duplicate infrastructure ranks detected in participants: { participants } "
939+ )
940+ log .debug (f"Using infrastructure ranks directly: { participants } " )
941+ return dict (participants )
942+
943+ # Default behavior: Assign ranks. Re-use assignment from the previous round as much as possible
880944 world_size = len (participants )
881945 sorted_keys = sorted (participants .keys ())
882946 free_ranks = set (range (world_size ))
883947 res = {}
884948 for p in sorted_keys :
885949 prev_rank = prev .get (p , - 1 )
886- if prev_rank >= 0 and prev_rank < world_size :
950+ if prev_rank >= 0 and prev_rank < world_size and prev_rank in free_ranks :
887951 # if this node can have the same rank, use it
888952 res [p ] = prev_rank
889953 free_ranks .remove (prev_rank )
@@ -920,7 +984,9 @@ def _mark_rendezvous_complete(self) -> None:
920984 state .wait_list .clear ()
921985
922986 # Will try to preserve node<->rank mapping
923- state .participants = self ._assign_ranks (state .participants , self ._prev_participants )
987+ state .participants = self ._assign_ranks (
988+ state .participants , self ._prev_participants , self ._settings .use_infra_group_rank
989+ )
924990
925991 # Set initial worker states, assume all workers are healthy at the beginning
926992 state .worker_states = {n : WorkerState .HEALTHY for n in state .participants }
@@ -1156,6 +1222,7 @@ def from_backend(
11561222 local_addr : Optional [str ] = None ,
11571223 timeout : Optional [RendezvousTimeout ] = None ,
11581224 upscaling_enabled : bool = True ,
1225+ use_infra_group_rank : bool = False ,
11591226 ):
11601227 """Create a new :py:class:`FtRendezvousHandler`.
11611228
@@ -1176,6 +1243,8 @@ def from_backend(
11761243 The timeout configuration of the rendezvous.
11771244 upscaling_enabled:
11781245 Whether to enable upscaling of a completed rendezvous with redundant or new nodes.
1246+ use_infra_group_rank:
1247+ Whether to use infrastructure group rank for rank assignment.
11791248 """
11801249 # We associate each handler instance with a unique node descriptor.
11811250 node = cls ._node_desc_generator .generate (local_addr )
@@ -1188,6 +1257,7 @@ def from_backend(
11881257 keep_alive_interval = timedelta (seconds = 5 ),
11891258 keep_alive_max_attempt = 3 ,
11901259 upscaling_enabled = upscaling_enabled ,
1260+ use_infra_group_rank = use_infra_group_rank ,
11911261 )
11921262
11931263 state_holder = _BackendRendezvousStateHolder (backend , settings )
@@ -1657,6 +1727,10 @@ def create_handler(
16571727 | | :py:meth:`RendezvousHandler.shutdown`. Defaults to |
16581728 | | 30 seconds. |
16591729 +-------------------+------------------------------------------------------+
1730+ | use_infra_group_ | Whether to always use infrastructure group rank for |
1731+ | rank | rank assignment. Previous assignments are ignored. |
1732+ | | Hot spare/redundancy NOT supported. Defaults to True.|
1733+ +-------------------+------------------------------------------------------+
16601734 """
16611735 try :
16621736 timeout = RendezvousTimeout (
@@ -1667,6 +1741,7 @@ def create_handler(
16671741
16681742 # torchrun default behaviour if not specified otherwise
16691743 upscale_completed = params .config .get ('upscaling_enabled' , True )
1744+ use_infra_group_rank = params .config .get ('use_infra_group_rank' , True )
16701745
16711746 return FtRendezvousHandler .from_backend (
16721747 params .run_id ,
@@ -1677,6 +1752,7 @@ def create_handler(
16771752 params .local_addr ,
16781753 timeout ,
16791754 upscaling_enabled = upscale_completed ,
1755+ use_infra_group_rank = use_infra_group_rank ,
16801756 )
16811757 except Exception as e :
16821758 construct_and_record_rdzv_event (
0 commit comments