2626import warnings
2727from typing import Callable , Optional , Union
2828
29+ from nvidia_resiliency_ext .shared_utils .log_manager import LogConfig
30+
2931from . import exception , utils
3032from .state import Mode , State
3133from .store import StoreMixin
@@ -177,7 +179,8 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
177179 active_rank = None
178180 # Log deactivation if transitioning from ACTIVE to INACTIVE
179181 if state .mode == Mode .ACTIVE :
180- log = logging .getLogger (__name__ )
182+ log = logging .getLogger (LogConfig .name )
183+
181184 log .info (
182185 f"[In-process] Rank deactivated (rank={ state .rank } ) due to max active world size limit ({ active_world_size } )"
183186 )
@@ -224,7 +227,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
224227 active_rank = None
225228 # Log deactivation if transitioning from ACTIVE to INACTIVE
226229 if state .mode == Mode .ACTIVE :
227- log = logging .getLogger (__name__ )
230+ log = logging .getLogger (LogConfig . name )
228231 log .info (
229232 f"[In-process] Rank deactivated (rank={ state .rank } ) due to divisibility requirement (active_world_size={ active_world_size } , divisor={ divisor } )"
230233 )
@@ -349,7 +352,7 @@ def __repr__(self):
349352 return f'{ type (self ).__name__ } ({ self .name = } )'
350353
351354
352- def bounded_activate (node , counter , path = None ):
355+ def bounded_activate (node , counter , path = None , current_state = None ):
353356 if path is None :
354357 path = []
355358
@@ -361,17 +364,29 @@ def bounded_activate(node, counter, path=None):
361364 for ascendant in path
362365 )
363366 ):
367+ # Log activation if this is the current rank
368+ if current_state and current_state .initial_rank == node .state .initial_rank :
369+ log = logging .getLogger (LogConfig .name )
370+ log .info (
371+ f"[In-process] Rank activated (initial_rank={ node .state .initial_rank } , active_rank={ counter } ) in topology tree"
372+ )
364373 node .activate (counter )
365374 counter += 1
366375 for ascendant in path :
367376 ascendant .active_count += 1
368377 else :
378+ # Log deactivation if this is the current rank
379+ if current_state and current_state .initial_rank == node .state .initial_rank :
380+ log = logging .getLogger (LogConfig .name )
381+ log .info (
382+ f"[In-process] Rank deactivated (initial_rank={ node .state .initial_rank } ) due to max_ranks constraint in topology layer"
383+ )
369384 node .deactivate ()
370385
371386 path .append (node )
372387
373388 for child in node .children .values ():
374- counter = bounded_activate (child , counter , path )
389+ counter = bounded_activate (child , counter , path , current_state )
375390 path .pop ()
376391 return counter
377392
@@ -574,7 +589,7 @@ def build_tree(self, state, store):
574589 def replace_with_inactive (self , terminated_active_ranks ):
575590 replaced_terminate_active_ranks = set ()
576591
577- log = logging .getLogger (__name__ )
592+ log = logging .getLogger (LogConfig . name )
578593
579594 for terminated_active_rank in sorted (terminated_active_ranks ):
580595 terminated_active_node = self .rank_map [terminated_active_rank ]
@@ -625,7 +640,7 @@ def replace_with_backfill(self, unhandled_terminations):
625640 key = lambda node : node .state .active_rank ,
626641 )
627642
628- log = logging .getLogger (__name__ )
643+ log = logging .getLogger (LogConfig . name )
629644 for backfill_node , terminated_node in itertools .zip_longest (
630645 reversed (largest_active_nodes ),
631646 terminated_nodes ,
@@ -647,7 +662,7 @@ def replace_with_backfill(self, unhandled_terminations):
647662
648663 def shift_ranks (self , replaced_active , unhandled_terminations ):
649664 sorted_replaced_active = sorted (replaced_active )
650- log = logging .getLogger (__name__ )
665+ log = logging .getLogger (LogConfig . name )
651666
652667 for n in self .rank_map .values ():
653668 n .state .active_world_size -= len (unhandled_terminations )
@@ -672,7 +687,7 @@ def filter_active_world_size(self):
672687 new_active_world_size = self .world_size_filter (active_world_size )
673688 assert new_active_world_size <= active_world_size
674689
675- log = logging .getLogger (__name__ )
690+ log = logging .getLogger (LogConfig . name )
676691 for leaf in self .tree .iter_leaves ():
677692 leaf .state .active_world_size = new_active_world_size
678693 if leaf .state .mode == Mode .ACTIVE and leaf .state .active_rank >= new_active_world_size :
@@ -722,7 +737,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
722737 if self .tree is None :
723738 self .build_tree (state , store )
724739
725- active_world_size = bounded_activate (self .tree , 0 )
740+ active_world_size = bounded_activate (self .tree , 0 , None , self . current_state )
726741 for node in self .rank_map .values ():
727742 node .state .active_world_size = active_world_size
728743
@@ -738,7 +753,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
738753 rank for rank in terminated_ranks if self .rank_map [rank ].state .mode == Mode .ACTIVE
739754 )
740755
741- log = logging .getLogger (__name__ )
756+ log = logging .getLogger (LogConfig . name )
742757 for terminated_rank in terminated_ranks :
743758 # If this rank is being terminated, log it
744759 if self .current_state .initial_rank == self .rank_map [terminated_rank ].state .initial_rank :
@@ -808,7 +823,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
808823 terminated_ranks = utils .format_rank_set (terminated_ranks )
809824 raise RankDiscarded (f'{ rank = } { terminated_ranks = } ' )
810825 elif rank >= world_size :
811- log = logging .getLogger (__name__ )
826+ log = logging .getLogger (LogConfig . name )
812827 old_rank = rank
813828 rank = ordered_terminated_ranks [rank - world_size ]
814829 log .info (
@@ -869,7 +884,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
869884 old_rank = rank
870885 rank = rank - sum (rank > terminated_rank for terminated_rank in terminated_ranks )
871886 if old_rank != rank :
872- log = logging .getLogger (__name__ )
887+ log = logging .getLogger (LogConfig . name )
873888 log .info (f"[In-process] Rank shifted (rank changed from { old_rank } to { rank } )" )
874889
875890 state = dataclasses .replace (
@@ -982,7 +997,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
982997
983998 group_count = int (store .get (prefixed_key ))
984999 if not self .condition (group_count ):
985- log = logging .getLogger (__name__ )
1000+ log = logging .getLogger (LogConfig . name )
9861001 log .info (
9871002 f"[In-process] Rank marked for termination (rank={ rank } , group_key={ key } , group_count={ group_count } ) due to failed group condition"
9881003 )
0 commit comments