Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions megatron/training/config/common_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,23 @@ class DistributedInitConfig:
distributed_timeout_seconds_after_init: int | None = None
"""Timeout in seconds for process groups after initialization. This timeout is applied to all process groups after initialization and the first iteration completes."""

# --- NCCL flight recorder settings ---------------------------------------
# Each field below maps 1:1 onto a TORCH_* environment variable named in its
# docstring. They are applied with setdefault semantics before distributed
# init: a variable already present in the environment wins over the config
# value. The whole group is only applied when a dump path is available
# (either via flight_recorder_dump_path or a pre-set path env variable).
flight_recorder_dump_path: str | None = None
"""Path for NCCL flight recorder trace dumps. Sets TORCH_FR_DUMP_TEMP_FILE and TORCH_NCCL_DEBUG_INFO_TEMP_FILE env variables before distributed init."""

# NOTE(review): unit is presumably the number of recorded collective entries,
# not bytes — confirm against PyTorch's flight recorder documentation.
flight_recorder_trace_buffer_size: int = 2000
"""Size of the NCCL flight recorder trace buffer (TORCH_NCCL_TRACE_BUFFER_SIZE)."""

flight_recorder_dump_on_timeout: bool = True
"""Dump flight recorder traces on NCCL timeout (TORCH_NCCL_DUMP_ON_TIMEOUT)."""

flight_recorder_include_stack_trace: bool = False
"""Include stack traces in flight recorder dumps (TORCH_INCLUDE_STACK_TRACE)."""

flight_recorder_include_only_active: bool = True
"""Include only active operations in flight recorder dumps (TORCH_INCLUDE_ONLY_ACTIVE)."""

flight_recorder_extra_dump_on_exec: bool = True
"""Enable extra flight recorder dump on execution (TORCH_NCCL_EXTRA_DUMP_ON_EXEC)."""

disable_jit_fuser: bool = False
"""Disable the JIT fuser."""
35 changes: 35 additions & 0 deletions megatron/training/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,41 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
# TransformerEngine CUDA graphs require work to run on a non-default stream,
# so switch the current stream before any capture can happen.
if args.cuda_graph_impl == "transformer_engine":
    torch.cuda.set_stream(torch.cuda.Stream())

# Set flight recorder env vars if specified.
# Priority: pre-existing environment variable > MLM argument.
# All vars follow the same setdefault semantics: if already set in the
# environment we warn and keep the user's value; otherwise we apply the
# value derived from the MLM argument / flag.
# The block is also triggered when either path env var is already set
# so that the remaining defaults are applied consistently.
_fr_path = (
    args.flight_recorder_dump_path
    or os.environ.get('TORCH_FR_DUMP_TEMP_FILE')
    or os.environ.get('TORCH_NCCL_DEBUG_INFO_TEMP_FILE')
)
if _fr_path is not None:
    # Both path variables intentionally alias the same dump path; boolean
    # flags are serialized as '0'/'1' via str(int(...)).
    _fr_env_defaults = {
        'TORCH_FR_DUMP_TEMP_FILE': _fr_path,
        'TORCH_NCCL_DEBUG_INFO_TEMP_FILE': _fr_path,
        'TORCH_NCCL_TRACE_BUFFER_SIZE': str(args.flight_recorder_trace_buffer_size),
        'TORCH_NCCL_DUMP_ON_TIMEOUT': str(int(args.flight_recorder_dump_on_timeout)),
        'TORCH_INCLUDE_STACK_TRACE': str(int(args.flight_recorder_include_stack_trace)),
        'TORCH_INCLUDE_ONLY_ACTIVE': str(int(args.flight_recorder_include_only_active)),
        'TORCH_NCCL_EXTRA_DUMP_ON_EXEC': str(int(args.flight_recorder_extra_dump_on_exec)),
    }
    for _var, _default in _fr_env_defaults.items():
        if _var in os.environ:
            # NOTE(review): this warns even when the pre-set env value equals
            # the config-derived one — harmless, but slightly noisy.
            warn_rank_0(
                f"Flight recorder: environment variable {_var} is already set to "
                f"'{os.environ[_var]}'; ignoring config value '{_default}'."
            )
        else:
            os.environ[_var] = _default
    # Every key in _fr_env_defaults is guaranteed present in os.environ here
    # (either pre-existing or just set), so this lookup cannot raise.
    print_rank_0(
        "Flight recorder env vars:\n"
        + "\n".join(f"  {k}={os.environ[k]}" for k in _fr_env_defaults)
    )

# Call the init process
init_process_group_kwargs = {
'backend': args.distributed_backend,
Expand Down
Loading