2626import warnings
2727from typing import Callable , Optional , Union
2828
29+ from nvidia_resiliency_ext .shared_utils .log_manager import LogConfig
30+
2931from . import exception , utils
3032from .state import Mode , State
3133from .store import StoreMixin
@@ -177,7 +179,8 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
177179 active_rank = None
178180 # Log deactivation if transitioning from ACTIVE to INACTIVE
179181 if state .mode == Mode .ACTIVE :
180- log = logging .getLogger (__name__ )
182+ log = logging .getLogger (LogConfig .name )
183+
181184 log .info (
182185 f"[In-process] Rank deactivated (rank={ state .rank } ) due to max active world size limit ({ active_world_size } )"
183186 )
@@ -224,7 +227,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
224227 active_rank = None
225228 # Log deactivation if transitioning from ACTIVE to INACTIVE
226229 if state .mode == Mode .ACTIVE :
227- log = logging .getLogger (__name__ )
230+ log = logging .getLogger (LogConfig . name )
228231 log .info (
229232 f"[In-process] Rank deactivated (rank={ state .rank } ) due to divisibility requirement (active_world_size={ active_world_size } , divisor={ divisor } )"
230233 )
@@ -574,7 +577,7 @@ def build_tree(self, state, store):
574577 def replace_with_inactive (self , terminated_active_ranks ):
575578 replaced_terminate_active_ranks = set ()
576579
577- log = logging .getLogger (__name__ )
580+ log = logging .getLogger (LogConfig . name )
578581
579582 for terminated_active_rank in sorted (terminated_active_ranks ):
580583 terminated_active_node = self .rank_map [terminated_active_rank ]
@@ -625,7 +628,7 @@ def replace_with_backfill(self, unhandled_terminations):
625628 key = lambda node : node .state .active_rank ,
626629 )
627630
628- log = logging .getLogger (__name__ )
631+ log = logging .getLogger (LogConfig . name )
629632 for backfill_node , terminated_node in itertools .zip_longest (
630633 reversed (largest_active_nodes ),
631634 terminated_nodes ,
@@ -647,7 +650,7 @@ def replace_with_backfill(self, unhandled_terminations):
647650
648651 def shift_ranks (self , replaced_active , unhandled_terminations ):
649652 sorted_replaced_active = sorted (replaced_active )
650- log = logging .getLogger (__name__ )
653+ log = logging .getLogger (LogConfig . name )
651654
652655 for n in self .rank_map .values ():
653656 n .state .active_world_size -= len (unhandled_terminations )
@@ -672,7 +675,7 @@ def filter_active_world_size(self):
672675 new_active_world_size = self .world_size_filter (active_world_size )
673676 assert new_active_world_size <= active_world_size
674677
675- log = logging .getLogger (__name__ )
678+ log = logging .getLogger (LogConfig . name )
676679 for leaf in self .tree .iter_leaves ():
677680 leaf .state .active_world_size = new_active_world_size
678681 if leaf .state .mode == Mode .ACTIVE and leaf .state .active_rank >= new_active_world_size :
@@ -738,7 +741,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
738741 rank for rank in terminated_ranks if self .rank_map [rank ].state .mode == Mode .ACTIVE
739742 )
740743
741- log = logging .getLogger (__name__ )
744+ log = logging .getLogger (LogConfig . name )
742745 for terminated_rank in terminated_ranks :
743746 # If this rank is being terminated, log it
744747 if self .current_state .initial_rank == self .rank_map [terminated_rank ].state .initial_rank :
@@ -808,7 +811,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
808811 terminated_ranks = utils .format_rank_set (terminated_ranks )
809812 raise RankDiscarded (f'{ rank = } { terminated_ranks = } ' )
810813 elif rank >= world_size :
811- log = logging .getLogger (__name__ )
814+ log = logging .getLogger (LogConfig . name )
812815 old_rank = rank
813816 rank = ordered_terminated_ranks [rank - world_size ]
814817 log .info (
@@ -869,7 +872,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
869872 old_rank = rank
870873 rank = rank - sum (rank > terminated_rank for terminated_rank in terminated_ranks )
871874 if old_rank != rank :
872- log = logging .getLogger (__name__ )
875+ log = logging .getLogger (LogConfig . name )
873876 log .info (f"[In-process] Rank shifted (rank changed from { old_rank } to { rank } )" )
874877
875878 state = dataclasses .replace (
@@ -982,7 +985,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
982985
983986 group_count = int (store .get (prefixed_key ))
984987 if not self .condition (group_count ):
985- log = logging .getLogger (__name__ )
988+ log = logging .getLogger (LogConfig . name )
986989 log .info (
987990 f"[In-process] Rank marked for termination (rank={ rank } , group_key={ key } , group_count={ group_count } ) due to failed group condition"
988991 )
0 commit comments