Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions megatron/training/config/common_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,23 @@ class DistributedInitConfig:
distributed_timeout_seconds_after_init: int | None = None
"""Timeout in seconds for process groups after initialization. This timeout is applied to all process groups after initialization and the first iteration completes."""

# --- NCCL flight recorder configuration ---
# Each field below maps onto a TORCH_* environment variable (named in its
# docstring). They only take effect when `flight_recorder_dump_path` is set:
# the init code exports the corresponding env vars before distributed init,
# and (except for the dump path itself) a value already present in the
# environment wins over the config value.

flight_recorder_dump_path: str | None = None
"""Path for NCCL flight recorder trace dumps. Sets TORCH_FR_DUMP_TEMP_FILE and TORCH_NCCL_DEBUG_INFO_TEMP_FILE env variables before distributed init."""

flight_recorder_trace_buffer_size: int = 2000
"""Size of the NCCL flight recorder trace buffer (TORCH_NCCL_TRACE_BUFFER_SIZE)."""

flight_recorder_dump_on_timeout: bool = True
"""Dump flight recorder traces on NCCL timeout (TORCH_NCCL_DUMP_ON_TIMEOUT)."""

flight_recorder_include_stack_trace: bool = False
"""Include stack traces in flight recorder dumps (TORCH_INCLUDE_STACK_TRACE)."""

flight_recorder_include_only_active: bool = True
"""Include only active operations in flight recorder dumps (TORCH_INCLUDE_ONLY_ACTIVE)."""

flight_recorder_extra_dump_on_exec: bool = True
"""Enable extra flight recorder dump on execution (TORCH_NCCL_EXTRA_DUMP_ON_EXEC)."""

disable_jit_fuser: bool = False
"""Disable the JIT fuser."""
29 changes: 29 additions & 0 deletions megatron/training/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,35 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
# When Transformer Engine's CUDA-graph implementation is selected, make a
# fresh CUDA stream current.
# NOTE(review): presumably so subsequent work/graph capture does not run on
# the default stream — confirm against Transformer Engine requirements.
if args.cuda_graph_impl == "transformer_engine":
    torch.cuda.set_stream(torch.cuda.Stream())

# Export NCCL flight recorder environment variables before distributed init,
# but only when a dump path was configured.
if args.flight_recorder_dump_path is not None:
    # The dump path always comes from the config and overwrites any
    # pre-existing environment value.
    dump_path = args.flight_recorder_dump_path
    os.environ['TORCH_FR_DUMP_TEMP_FILE'] = dump_path
    os.environ['TORCH_NCCL_DEBUG_INFO_TEMP_FILE'] = dump_path

    # The remaining knobs are treated as defaults: a value already present
    # in the environment wins, and we warn that the config value is ignored.
    recorder_defaults = {
        'TORCH_NCCL_TRACE_BUFFER_SIZE': str(args.flight_recorder_trace_buffer_size),
        'TORCH_NCCL_DUMP_ON_TIMEOUT': str(int(args.flight_recorder_dump_on_timeout)),
        'TORCH_INCLUDE_STACK_TRACE': str(int(args.flight_recorder_include_stack_trace)),
        'TORCH_INCLUDE_ONLY_ACTIVE': str(int(args.flight_recorder_include_only_active)),
        'TORCH_NCCL_EXTRA_DUMP_ON_EXEC': str(int(args.flight_recorder_extra_dump_on_exec)),
    }
    for env_name, config_value in recorder_defaults.items():
        if env_name not in os.environ:
            os.environ[env_name] = config_value
        else:
            warn_rank_0(
                f"Flight recorder: environment variable {env_name} is already set to "
                f"'{os.environ[env_name]}'; ignoring config value '{config_value}'."
            )

    # Log the effective settings (rank 0 only) for debuggability.
    effective_vars = {
        'TORCH_FR_DUMP_TEMP_FILE': os.environ['TORCH_FR_DUMP_TEMP_FILE'],
        'TORCH_NCCL_DEBUG_INFO_TEMP_FILE': os.environ['TORCH_NCCL_DEBUG_INFO_TEMP_FILE'],
    }
    effective_vars.update({key: os.environ[key] for key in recorder_defaults})
    print_rank_0(
        "Flight recorder env vars:\n"
        + "\n".join(f"  {k}={v}" for k, v in effective_vars.items())
    )

# Call the init process
init_process_group_kwargs = {
'backend': args.distributed_backend,
Expand Down