diff --git a/src/nvidia_resiliency_ext/inprocess/wrap.py b/src/nvidia_resiliency_ext/inprocess/wrap.py index 27d206a2..569cd7d1 100644 --- a/src/nvidia_resiliency_ext/inprocess/wrap.py +++ b/src/nvidia_resiliency_ext/inprocess/wrap.py @@ -401,6 +401,7 @@ def __call__(self, fn, args, kwargs): wrapper = self.wrapper progress_watchdog = self.progress_watchdog + exception_recvd_time = None rank_assignment_ctx = RankAssignmentCtx(state, store, set()) reassigned_ctx = wrapper.rank_assignment(rank_assignment_ctx) self.state = state = reassigned_ctx.state @@ -463,6 +464,42 @@ def __call__(self, fn, args, kwargs): raise HealthCheckError from health_ex if state.mode == Mode.ACTIVE: + if exception_recvd_time is not None: + # Store local trigger time in the distributed store + store.set( + f"exception_recvd_time_{state.rank}", + str(exception_recvd_time), + ) + # Barrier to ensure all ranks have set their trigger time + store.completion_barrier( + ranks=[state.rank], + rendezvous_count=state.world_size, + timeout=wrapper.completion_timeout, + timeout_chunk=wrapper.progress_watchdog_interval, + ) + if state.rank == 0: + excp_recvd_times = [] + for r in range(state.world_size): + excp_recvd_times.append( + float(store.get(f"exception_recvd_time_{r}")) + ) + restart_latency_min = int( + (time.monotonic() - min(excp_recvd_times)) * 1000 + ) + restart_latency_max = int( + (time.monotonic() - max(excp_recvd_times)) * 1000 + ) + log.info( + f"global-in-process ....: ({restart_latency_min}, {restart_latency_max})" + ) + # Also log local latency for reference + local_restart_latency = int( + (time.monotonic() - exception_recvd_time) * 1000 + ) + log.debug( + f"local-in-process ....: {local_restart_latency}" + ) + exception_recvd_time = None ret = fn(*args, **kwargs) store.record_completed() elif state.mode == Mode.INACTIVE: @@ -484,6 +521,8 @@ def __call__(self, fn, args, kwargs): ) except Exception as fn_ex: try: + if exception_recvd_time is None: + exception_recvd_time = time.monotonic() log.error(log_exc(state, fn_ex, 'fn_ex')) monitor_process.record_interrupted( [InterruptionRecord(state.rank, Interruption.EXCEPTION)]