39 changes: 39 additions & 0 deletions src/nvidia_resiliency_ext/inprocess/wrap.py
@@ -401,6 +401,7 @@ def __call__(self, fn, args, kwargs):
        wrapper = self.wrapper
        progress_watchdog = self.progress_watchdog

        # Timestamp of the exception that triggered the previous restart;
        # set in the except handler below and cleared after reporting.
        exception_recvd_time = None
        rank_assignment_ctx = RankAssignmentCtx(state, store, set())
        reassigned_ctx = wrapper.rank_assignment(rank_assignment_ctx)
        self.state = state = reassigned_ctx.state
@@ -463,6 +464,42 @@ def __call__(self, fn, args, kwargs):
                raise HealthCheckError from health_ex

            if state.mode == Mode.ACTIVE:
                if exception_recvd_time is not None:
                    # Store local trigger time in the distributed store
                    store.set(
                        f"exception_recvd_time_{state.rank}",
                        str(exception_recvd_time),
                    )
                    # Barrier to ensure all ranks have set their trigger time
                    store.completion_barrier(

[Contributor review comment on the store.completion_barrier call]
The barriers defined in Store can only be used once. To add a barrier here,
we'll need to add another to the StoreMixin.
(A sketch of one possible reusable store barrier follows the diff below.)

                        ranks=[state.rank],
                        rendezvous_count=state.world_size,
                        timeout=wrapper.completion_timeout,
                        timeout_chunk=wrapper.progress_watchdog_interval,
                    )
                    if state.rank == 0:
                        excp_recvd_times = []
                        for r in range(state.world_size):
                            excp_recvd_times.append(
                                float(store.get(f"exception_recvd_time_{r}"))
                            )
                        # Latencies are in milliseconds. The latest trigger
                        # time yields the smallest elapsed latency, the
                        # earliest yields the largest.
                        restart_latency_min = int(
                            (time.monotonic() - max(excp_recvd_times)) * 1000
                        )
                        restart_latency_max = int(
                            (time.monotonic() - min(excp_recvd_times)) * 1000
                        )
                        log.info(
                            f"global-in-process ....: ({restart_latency_min}, {restart_latency_max})"
                        )
                    # Also log local latency for reference
                    local_restart_latency = int(
                        (time.monotonic() - exception_recvd_time) * 1000
                    )
                    log.debug(
                        f"local-in-process ....: {local_restart_latency}"
                    )
                    exception_recvd_time = None
                ret = fn(*args, **kwargs)
                store.record_completed()
            elif state.mode == Mode.INACTIVE:
@@ -484,6 +521,8 @@
                )
            except Exception as fn_ex:
                try:
                    # Record when this rank first observed the failure; used
                    # above to report restart latency after the next restart.
                    if exception_recvd_time is None:
                        exception_recvd_time = time.monotonic()
                    log.error(log_exc(state, fn_ex, 'fn_ex'))
                    monitor_process.record_interrupted(
                        [InterruptionRecord(state.rank, Interruption.EXCEPTION)]
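Following up on the review comment about single-use barriers: below is a minimal sketch of a reusable, store-backed barrier of the kind that could be added to StoreMixin. This is an illustration under assumptions, not the project's actual API: the helper name store_barrier, the barrier/{barrier_id}/... key layout, the default timeout, and the standalone TCPStore in the usage stub are all hypothetical. It relies only on c10d store primitives (add, set, wait).

import datetime

from torch.distributed import TCPStore


def store_barrier(store, world_size, barrier_id,
                  timeout=datetime.timedelta(seconds=300)):
    # Hypothetical reusable counting barrier on top of a c10d key-value store.
    # Every call derives its keys from a distinct barrier_id, so no key is
    # ever reused across invocations.
    count_key = f"barrier/{barrier_id}/count"
    release_key = f"barrier/{barrier_id}/release"

    # add() atomically increments the counter and returns the new value.
    if store.add(count_key, 1) == world_size:
        # The last rank to arrive publishes the release key.
        store.set(release_key, "1")
    # All ranks (including the last one) block until the release key exists.
    store.wait([release_key], timeout)


if __name__ == "__main__":
    # Single-process smoke test; host and port are placeholders. In the
    # wrapper, this would be the store already held by StoreMixin.
    store = TCPStore("127.0.0.1", 29512, world_size=1, is_master=True)
    store_barrier(store, world_size=1, barrier_id="exception-latency-0")

Because each invocation derives its keys from a fresh barrier_id, the keys are never reused, which is what lifts the single-use restriction the reviewer points out; wiring this into StoreMixin and aligning its timeouts with completion_timeout is left to the PR discussion.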