Skip to content

Commit d988d43

Browse files
committed
Moved rank assignment logging to distributed logger
1 parent e50e391 commit d988d43

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

src/nvidia_resiliency_ext/inprocess/rank_assignment.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@
2626
import warnings
2727
from typing import Callable, Optional, Union
2828

29+
from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig
30+
2931
from . import exception, utils
3032
from .state import Mode, State
3133
from .store import StoreMixin
@@ -177,7 +179,8 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
177179
active_rank = None
178180
# Log deactivation if transitioning from ACTIVE to INACTIVE
179181
if state.mode == Mode.ACTIVE:
180-
log = logging.getLogger(__name__)
182+
log = logging.getLogger(LogConfig.name)
183+
181184
log.info(
182185
f"[In-process] Rank deactivated (rank={state.rank}) due to max active world size limit ({active_world_size})"
183186
)
@@ -224,7 +227,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
224227
active_rank = None
225228
# Log deactivation if transitioning from ACTIVE to INACTIVE
226229
if state.mode == Mode.ACTIVE:
227-
log = logging.getLogger(__name__)
230+
log = logging.getLogger(LogConfig.name)
228231
log.info(
229232
f"[In-process] Rank deactivated (rank={state.rank}) due to divisibility requirement (active_world_size={active_world_size}, divisor={divisor})"
230233
)
@@ -574,7 +577,7 @@ def build_tree(self, state, store):
574577
def replace_with_inactive(self, terminated_active_ranks):
575578
replaced_terminate_active_ranks = set()
576579

577-
log = logging.getLogger(__name__)
580+
log = logging.getLogger(LogConfig.name)
578581

579582
for terminated_active_rank in sorted(terminated_active_ranks):
580583
terminated_active_node = self.rank_map[terminated_active_rank]
@@ -625,7 +628,7 @@ def replace_with_backfill(self, unhandled_terminations):
625628
key=lambda node: node.state.active_rank,
626629
)
627630

628-
log = logging.getLogger(__name__)
631+
log = logging.getLogger(LogConfig.name)
629632
for backfill_node, terminated_node in itertools.zip_longest(
630633
reversed(largest_active_nodes),
631634
terminated_nodes,
@@ -647,7 +650,7 @@ def replace_with_backfill(self, unhandled_terminations):
647650

648651
def shift_ranks(self, replaced_active, unhandled_terminations):
649652
sorted_replaced_active = sorted(replaced_active)
650-
log = logging.getLogger(__name__)
653+
log = logging.getLogger(LogConfig.name)
651654

652655
for n in self.rank_map.values():
653656
n.state.active_world_size -= len(unhandled_terminations)
@@ -672,7 +675,7 @@ def filter_active_world_size(self):
672675
new_active_world_size = self.world_size_filter(active_world_size)
673676
assert new_active_world_size <= active_world_size
674677

675-
log = logging.getLogger(__name__)
678+
log = logging.getLogger(LogConfig.name)
676679
for leaf in self.tree.iter_leaves():
677680
leaf.state.active_world_size = new_active_world_size
678681
if leaf.state.mode == Mode.ACTIVE and leaf.state.active_rank >= new_active_world_size:
@@ -738,7 +741,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
738741
rank for rank in terminated_ranks if self.rank_map[rank].state.mode == Mode.ACTIVE
739742
)
740743

741-
log = logging.getLogger(__name__)
744+
log = logging.getLogger(LogConfig.name)
742745
for terminated_rank in terminated_ranks:
743746
# If this rank is being terminated, log it
744747
if self.current_state.initial_rank == self.rank_map[terminated_rank].state.initial_rank:
@@ -808,7 +811,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
808811
terminated_ranks = utils.format_rank_set(terminated_ranks)
809812
raise RankDiscarded(f'{rank=} {terminated_ranks=}')
810813
elif rank >= world_size:
811-
log = logging.getLogger(__name__)
814+
log = logging.getLogger(LogConfig.name)
812815
old_rank = rank
813816
rank = ordered_terminated_ranks[rank - world_size]
814817
log.info(
@@ -869,7 +872,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
869872
old_rank = rank
870873
rank = rank - sum(rank > terminated_rank for terminated_rank in terminated_ranks)
871874
if old_rank != rank:
872-
log = logging.getLogger(__name__)
875+
log = logging.getLogger(LogConfig.name)
873876
log.info(f"[In-process] Rank shifted (rank changed from {old_rank} to {rank})")
874877

875878
state = dataclasses.replace(
@@ -982,7 +985,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
982985

983986
group_count = int(store.get(prefixed_key))
984987
if not self.condition(group_count):
985-
log = logging.getLogger(__name__)
988+
log = logging.getLogger(LogConfig.name)
986989
log.info(
987990
f"[In-process] Rank marked for termination (rank={rank}, group_key={key}, group_count={group_count}) due to failed group condition"
988991
)

0 commit comments

Comments
 (0)