
Commit cfac2d7

Merge pull request #177 from NVIDIA/rankmonitorserver_inherit_env
Configure RankMonitorServer to inherit env vars from launcher
2 parents 2b02733 + d263df3 commit cfac2d7

File tree

5 files changed: +14 lines added, -6 lines removed

src/nvidia_resiliency_ext/fault_tolerance/launcher.py

Lines changed: 1 addition & 0 deletions
@@ -478,6 +478,7 @@ def setup_rank_monitors(self, envs: Dict[int, Dict[str, str]]) -> None:
                 ipc_socket_path=rmon_ipc_socket,
                 is_restarter_logger=is_restarter_logger,
                 mp_ctx=fork_mp_ctx,
+                env=worker_env,
             )
 
     def shutdown_rank_monitors(self):
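
The launcher change above forwards the worker environment into the rank-monitor subprocess via the new env= keyword. As a rough sketch of the idea only (not the library's actual API): a dict captured by the parent can be handed to a fork-context child, which merges it into os.environ before doing any work. The _monitor_main name and the MY_FLAG variable below are illustrative.

import multiprocessing as mp
import os
from typing import Dict, Optional

def _monitor_main(env: Optional[Dict[str, str]] = None) -> None:
    # Hypothetical child entry point: apply the environment captured by the
    # launcher so the monitor sees the same variables as the worker it watches.
    if env is not None:
        os.environ.update(env)
    print("monitor sees MY_FLAG =", os.environ.get("MY_FLAG"))

if __name__ == "__main__":
    worker_env = {"MY_FLAG": "1"}      # stand-in for the launcher's per-rank env
    fork_ctx = mp.get_context("fork")  # the diff uses a fork multiprocessing context
    proc = fork_ctx.Process(target=_monitor_main, kwargs={"env": worker_env})
    proc.start()
    proc.join()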

src/nvidia_resiliency_ext/fault_tolerance/rank_monitor_server.py

Lines changed: 10 additions & 5 deletions
@@ -126,7 +126,8 @@ def __init__(
         cfg: FaultToleranceConfig,
         ipc_socket_path: str,
         rank_monitor_ready_event,
-        logger: RankMonitorLogger,
+        logger: logging.Logger,
+        is_restarter_logger: bool,
     ):
         """
         Initializes the RankMonitorServer object.
@@ -151,7 +152,10 @@ def __init__(
         self.connection_lock = asyncio.Lock()
         self.rank_monitor_ready_event = rank_monitor_ready_event
         self.logger = logger
-        self.state_machine = RankMonitorStateMachine(logger)
+        self.rmlogger = RankMonitorLogger(
+            level=logger.level, is_restarter_logger=is_restarter_logger
+        )
+        self.state_machine = RankMonitorStateMachine(self.rmlogger)
         self._periodic_restart_task = None
         self.health_checker = GPUHealthCheck(
             interval=self.cfg.node_health_check_interval, on_failure=self._handle_unhealthy_node
@@ -264,7 +268,7 @@ async def _handle_init_msg(self, msg, writer):
         # Update NIC health checker on the rank to monitor.
         if self.nic_health_checker is not None:
             self.nic_health_checker.set_nic_device(local_rank=self.rank_info.local_rank)
-        self.logger.set_connected_rank(msg.rank_info.global_rank)
+        self.rmlogger.set_connected_rank(msg.rank_info.global_rank)
         await write_obj_to_ipc_stream(OkMsg(cfg=self.cfg), writer)
 
     async def _handle_heartbeat_msg(self, msg, writer):
@@ -318,7 +322,7 @@ def _handle_ipc_connection_lost(self):
                 f"Section(s) {open_section_names} were still open. you can use`.end_all_sections` to avoid this warning"
             )
             self.open_sections.clear()
-        self.logger.set_connected_rank(None)
+        self.rmlogger.set_connected_rank(None)
         if self.connection_lock.locked():
             self.connection_lock.release()
 
@@ -546,7 +550,8 @@ def run(
             cfg,
             ipc_socket_path,
             rank_monitor_ready_event,
-            rmlogger,
+            logger,
+            is_restarter_logger,
         )
         asyncio.run(inst._rank_monitor_loop())
         logger.debug("Leaving RankMonitorServer process")

src/nvidia_resiliency_ext/inprocess/wrap.py

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ def __init__(
         self.finalize = finalize
         self.health_check = health_check
 
+        setup_logger(node_local_tmp_prefix="wrapper")
         # Construct internal restart_health_check by chaining user's health_check with GPU and NVL checks
         self._construct_restart_health_check()

src/nvidia_resiliency_ext/shared_utils/health_check.py

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@ def __init__(
             on_failure (Optional[Callable]): Callback function to handle health check failures.
         """
         super().__init__()
+        logger = logging.getLogger(LogConfig.name)
         self.device_index = device_index
         self.interval = interval
         self.on_failure = on_failure

src/nvidia_resiliency_ext/shared_utils/log_node_local_tmp.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def _get_backup_files(self):
     def _log_file_namer(self):
         backup_files = self._get_backup_files()
         if self.fname is None and backup_files:
-            return backup_files[0]
+            return backup_files[-1]
         rank_str = str(self.rank_id) if self.rank_id is not None else "unknown"
         file_prefix = f"rank_{rank_str}_{self.proc_name}.msg."
         return f"{file_prefix}{int(time.time()*1000)}"
