Commit 4dce778

Merge pull request #185 from hexinw-nvidia/profiling
feat: Add Injob time profiling metrics
2 parents 73b54e1 + 119235b commit 4dce778

File tree

4 files changed: +173 -3 lines changed

- pyproject.toml
- src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py
- src/nvidia_resiliency_ext/fault_tolerance/launcher.py
- src/nvidia_resiliency_ext/shared_utils/profiling.py

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ requires = ["poetry-core>=1.0.0", "pybind11", "setuptools", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry.dependencies]
+nv-one-logger-core = ">=2.1.0"
 torch = ">=2.3.0"
 packaging = "*"
 python = ">=3.10"

src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py

Lines changed: 13 additions & 0 deletions

@@ -60,6 +60,7 @@
 from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig
 
 from ..shared_utils.health_check import GPUHealthCheck
+from ..shared_utils.profiling import ProfilingEvent, record_profiling_event
 from .data import WorkloadAction
 from .ipc_connector import IpcConnector
 from .launcher import FT_LAUNCHER_IPC_SOCKET, UnhealthyNodeException
@@ -1322,6 +1323,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
         self._record(message=msg)
         log.info(msg)
 
+        # Record rendezvous start event
+        rendezvous_start_event_id = record_profiling_event(
+            ProfilingEvent.RENDEZVOUS_STARTED,
+            node_id=self._this_node,
+        )
+
         try:
             self._stop_heartbeats()
 
@@ -1362,6 +1369,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
         self._record(message=msg, rank=rank)
         log.info(msg)
 
+        # Record rendezvous completion event
+        rendezvous_completion_event_id = record_profiling_event(
+            ProfilingEvent.RENDEZVOUS_COMPLETED,
+            node_id=self._this_node,
+        )
+
         # Use RendezvousInfo if available (newer PyTorch versions >= 2.4.0)
         # Fall back to tuple format if RendezvousInfo is not supported
         if _RENDEZVOUS_INFO_AVAILABLE:

src/nvidia_resiliency_ext/fault_tolerance/launcher.py

Lines changed: 42 additions & 3 deletions

@@ -76,6 +76,7 @@
     write_obj_to_ipc_stream,
 )
 from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
+from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event
 
 # Deprecation warning for FT_LAUNCHER_LOGLEVEL
 if os.getenv('FT_LAUNCHER_LOGLEVEL') is not None:
@@ -142,7 +143,7 @@ class LocalElasticAgent(SimpleElasticAgent):
     python multiprocessing compatible. To pass multiprocessing data structures
     to the workers you may create the data structure in the same multiprocessing
     context as the specified ``start_method`` and pass it as a function argument.
-
+
     Note: If your training script uses the nvrx logger, make sure to call
     ``setup_logger()`` at the beginning of your training function to ensure
     the logger is properly set up in each subprocess.
@@ -183,12 +184,12 @@ def trainer(args) -> str:
         # Ensure nvrx logger is set up in this subprocess
         from nvidia_resiliency_ext.shared_utils.log_manager import setup_logger
         setup_logger()
-
+
         # Use the nvrx logger
         import logging
         logger = logging.getLogger(LogConfig.name)
         logger.info("Training started")
-
+
         return "do train"
 
     def main():
@@ -255,6 +256,7 @@ def __init__(
         self._ft_cfg = fault_tol_cfg
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
+        self._node_id = self._get_fq_hostname()
 
 DEFAULT_ROLE = "default"  # FIXME
 
@@ -326,6 +328,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                 self._exit_barrier()
                 return run_result
             elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
+                # Record failure detection event
+                record_profiling_event(
+                    ProfilingEvent.FAILURE_DETECTED,
+                    node_id=self._rdzv_handler._this_node,
+                    rank=self._worker_group.group_rank,
+                )
+
                 if self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
@@ -351,6 +360,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                 num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                 group_rank = self._worker_group.group_rank
                 if num_nodes_waiting > 0:
+                    # Record failure detection event
+                    record_profiling_event(
+                        ProfilingEvent.FAILURE_DETECTED,
+                        node_id=self._rdzv_handler._this_node,
+                        rank=self._worker_group.group_rank,
+                    )
+
                     logger.info(
                         "[%s] Detected %s "
                         "new nodes from group_rank=%s; "
@@ -591,6 +607,13 @@ async def send_close_msg():
 
         self._shutdown(timeout=self._workers_stop_timeout)
 
+        # Record worker termination event after shutdown is complete
+        record_profiling_event(
+            ProfilingEvent.WORKER_TERMINATED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
     # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
     # `torch.distributed.elastic.metrics.prof`.
     @prof
@@ -600,6 +623,13 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
         assert store is not None
         restart_count = spec.max_restarts - self._remaining_restarts
 
+        # Record worker start start event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_STARTED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
         use_agent_store = spec.rdzv_handler.use_agent_store
 
         args: Dict[int, Tuple] = {}
@@ -671,8 +701,16 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
 
         self._children_pgids = {os.getpgid(p) for p in self._pcontext.pids().values()}
 
+        # Record worker start completion event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_COMPLETED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
         return self._pcontext.pids()
 
+
     def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 30) -> None:
         if self._worker_watchdog is not None:
             self._worker_watchdog.stop()
@@ -1058,6 +1096,7 @@ def launch_agent(
         )
 
         logger.info(f"Agent .run() is OK. No failures in the result. {result=}")
+
         return result.return_values
     except UnhealthyNodeException as e:
         # do not shutdown rendezvous when an unhealthy node is leaving
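Taken together, these call sites instrument a full restart cycle. The following is an illustrative summary rather than code from this commit: assuming the usual stop-workers, then rendezvous, then start-workers restart flow, one agent would be expected to emit the new events in roughly this order.

# Illustrative sketch (not part of this commit): the rough order in which one
# agent emits the new profiling events across a single failure/restart cycle,
# assuming the standard stop-workers -> rendezvous -> start-workers flow.
from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent

EXPECTED_RESTART_CYCLE = [
    ProfilingEvent.FAILURE_DETECTED,        # unhealthy/failed workers or new nodes waiting
    ProfilingEvent.WORKER_TERMINATED,       # recorded after self._shutdown(...) completes
    ProfilingEvent.RENDEZVOUS_STARTED,      # next_rendezvous() begins
    ProfilingEvent.RENDEZVOUS_COMPLETED,    # next_rendezvous() assigns the new group
    ProfilingEvent.WORKER_START_STARTED,    # _start_workers() begins
    ProfilingEvent.WORKER_START_COMPLETED,  # worker processes launched
]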
src/nvidia_resiliency_ext/shared_utils/profiling.py

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file adds time profiling capabilities using nv one logger

import logging
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

from nv_one_logger.api.one_logger_provider import OneLoggerProvider
from nv_one_logger.core.attributes import Attributes
from nv_one_logger.core.event import Event

from ..shared_utils.log_manager import LogConfig


class ProfilingEvent(Enum):
    """Enumeration of profiling events for fault tolerance metrics."""

    FAILURE_DETECTED = "failure_detected"
    WORKER_TERMINATED = "worker_terminated"
    RENDEZVOUS_STARTED = "rendezvous_started"
    RENDEZVOUS_COMPLETED = "rendezvous_completed"
    WORKER_START_STARTED = "worker_start_started"
    WORKER_START_COMPLETED = "worker_start_completed"


class FaultToleranceProfiler:
    """Profiler for measuring fault tolerance timing metrics using nv one logger."""

    def __init__(self):
        self._current_cycle = 0
        # Initialize logger as a member to avoid module-level logger issues
        self._logger = logging.getLogger(LogConfig.name)

    def _timestamp_to_utc_datetime(self, timestamp: float) -> str:
        """Convert timestamp to UTC datetime string."""
        utc_datetime = datetime.fromtimestamp(timestamp, tz=timezone.utc)
        return utc_datetime.strftime("%Y-%m-%d %H:%M:%S.%f")[
            :-3
        ]  # Remove last 3 digits for milliseconds

    def _publish_metrics(
        self, event: ProfilingEvent, timestamp: float, node_id: Optional[str], rank: Optional[int]
    ) -> None:
        """Publish metrics using nv one logger."""
        try:
            # Check if nv one logger is available and enabled
            if OneLoggerProvider.instance().one_logger_enabled:
                # Create attributes for the event
                attributes = Attributes()
                attributes.add("event_type", event.value)
                attributes.add("timestamp_ms", int(timestamp * 1000))
                attributes.add("cycle", self._current_cycle)
                if node_id:
                    attributes.add("node_id", node_id)
                if rank is not None:
                    attributes.add("rank", rank)

                # Create and record the event
                event_obj = Event.create(f"ft.{event.value}", attributes)
                OneLoggerProvider.instance().recorder.event(None, event_obj)
        except Exception as e:
            # If nv one logger fails, just log a warning and continue
            self._logger.warning(f"Failed to publish metrics to nv one logger: {e}")

    def record_event(
        self,
        event: ProfilingEvent,
        node_id: Optional[str] = None,
        rank: Optional[int] = None,
    ) -> str:
        """Record a profiling event and return a unique event ID."""
        timestamp = time.time()
        event_id = f"{event.value}_{timestamp}_{node_id or 'unknown'}_{rank or 'unknown'}"

        # Increment cycle count for failure detection events
        if event == ProfilingEvent.FAILURE_DETECTED:
            self._current_cycle += 1

        # Publish metrics using nv one logger
        self._publish_metrics(event, timestamp, node_id, rank)

        # Format log message with cycle count and UTC time
        utc_time = self._timestamp_to_utc_datetime(timestamp)
        self._logger.info(
            f" - Cycle: {self._current_cycle} Event: {event.value} Node: {node_id} Rank: {rank} "
            f"Time: {utc_time} UTC"
        )
        return event_id


# Global profiler instance
_global_profiler = FaultToleranceProfiler()


def record_profiling_event(
    event: ProfilingEvent,
    node_id: Optional[str] = None,
    rank: Optional[int] = None,
) -> str:
    """Convenience function to record a profiling event."""
    return _global_profiler.record_event(event, node_id, rank)
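For reference, here is a minimal usage sketch of the new helper. It is not part of the commit; the node name below is hypothetical, and in the launcher the node_id comes from the rendezvous handler.

# Minimal usage sketch, assuming the module lands at
# nvidia_resiliency_ext/shared_utils/profiling.py as the imports in this commit suggest.
from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event

start_id = record_profiling_event(
    ProfilingEvent.RENDEZVOUS_STARTED,
    node_id="node-0",  # hypothetical node name; the rendezvous handler passes self._this_node
)
# ... rendezvous work ...
done_id = record_profiling_event(
    ProfilingEvent.RENDEZVOUS_COMPLETED,
    node_id="node-0",
)
# Each call logs a line of the form
#   " - Cycle: 0 Event: rendezvous_started Node: node-0 Rank: None Time: <UTC time> UTC"
# and, when nv one logger is enabled, publishes an "ft.<event>" Event carrying
# event_type, timestamp_ms, cycle, node_id, and rank attributes; subtracting the
# timestamp_ms of paired started/completed events yields the rendezvous duration.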
