Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,20 @@ Using Slurm can be slightly more involved. Like with MPI, you must add the follo
```
If you do not have ssh access to the compute nodes in your Slurm cluster you need to add `{"no_ssh_check": true}`

##### Torchrun Based Slurm Launching

If you prefers `torchrun` over DeepSpeed's multinode launcher, use `deepy_torchrun.py`. It keeps the same command-line interface as `deepy.py`, parses the same YAML config files, and forwards the args directly to `train.py`.

This launcher currently assumes a Slurm allocation. It expects `MASTER_ADDR` and `MASTER_PORT` to be exported ahead of time, and it derives `--nnodes`, `--nproc-per-node`, and `--node-rank` from Slurm environment variables (`SLURM_JOB_NUM_NODES`, `SLURM_GPUS_ON_NODE`, and `RANK`).

Launch one `deepy_torchrun.py` process per node, for example with `srun --ntasks-per-node=1`:

```bash
srun --ntasks-per-node=1 python3 deepy_torchrun.py train.py /path/to/configs/my_model.yml
```

See `examples/slurm_torchrun/slurm_torchrun_usage.sh` for a complete Slurm-based example.

#### (Advanced) Custom Launching

There are many cases where the above default launching options are not sufficient
Expand Down
90 changes: 90 additions & 0 deletions deepy_torchrun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#!/usr/bin/env python
# Copyright (c) 2025, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Launch GPT-NeoX with torchrun while keeping the deepy.py CLI."""

import logging
import os
import subprocess
import sys


def main(input_args=None):
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

from megatron.neox_arguments import NeoXArgs
from megatron.utils import get_wandb_api_key

neox_args = NeoXArgs.consume_deepy_args(input_args)
deepspeed_main_args = neox_args.get_deepspeed_main_args()

# Extract wandb API key and inject into worker environments
wandb_token = get_wandb_api_key(neox_args=neox_args)
if wandb_token is not None:
os.environ["WANDB_API_KEY"] = wandb_token

slurm_env = {
"nnodes": os.environ.get("SLURM_JOB_NUM_NODES"),
"nproc_per_node": os.environ.get("SLURM_GPUS_ON_NODE"),
"master_addr": os.environ.get("MASTER_ADDR"),
"master_port": os.environ.get("MASTER_PORT"),
"node_rank": os.environ.get("RANK")
}
missing = [name for name, value in slurm_env.items() if not value]
if missing:
raise RuntimeError(
"deepy_torchrun.py expects a Slurm-style launch environment and is "
f"missing: {', '.join(missing)}"
)

# DeepSpeed launcher args come first; torchrun only needs the target script and
# the arguments that would normally be forwarded to it.
user_script_idx = deepspeed_main_args.index(neox_args.user_script)
cmd = [
"torchrun",
"--nnodes",
slurm_env["nnodes"],
"--nproc-per-node",
slurm_env["nproc_per_node"],
"--master-addr",
slurm_env["master_addr"],
"--master-port",
slurm_env["master_port"],
"--node-rank",
slurm_env["node_rank"],
"--log-dir",
os.path.join(os.getcwd(), "logs"),
*deepspeed_main_args[user_script_idx:],
]

env = os.environ.copy()
curr_path = os.path.abspath(".")
if "PYTHONPATH" in env:
env["PYTHONPATH"] = curr_path + os.pathsep + env["PYTHONPATH"]
else:
env["PYTHONPATH"] = curr_path

logging.info("Running command: %s", " ".join(cmd))
result = subprocess.run(cmd, env=env, check=False)

# In case of failure must propagate the error-condition back to the caller (usually shell). The
# actual error and traceback should have been printed in the subprocess, so in order to avoid
# unnecessary noise we just quietly exit here with the same code as the subprocess
if result.returncode != 0:
sys.exit(result.returncode)


if __name__ == "__main__":
main()
67 changes: 67 additions & 0 deletions examples/slurm_torchrun/slurm_torchrun_usage.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#! /bin/bash
#SBATCH --job-name=slurm_torchrun_job
#SBATCH --output=slurm_torchrun_job-%j.out
#SBATCH --time=1-00:00:00
#SBATCH --gres=gpu:8
#SBATCH --nodes=2
#SBATCH --cpus-per-task=48
#SBATCH --ntasks-per-node=1
#SBATCH --mem=1000000
#SBATCH --no-requeue

# --- Load necessary modules (might not be required depending on your HPC environment) ---
module load python-waterboa apptainer gcc openmpi ...

# --- Set up directories and container image ---
export SINGULARITY_TMPDIR="Artifacts/TEMP"
export CONTAINER_IMAGE=Artifacts/image.sif

# --- Set up distributed environment variables ---
export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=30001
export COUNT_NODE=$SLURM_NNODES

echo "JOB ID: $SLURM_JOBID"
echo "NODES: $COUNT_NODE"
echo "HOSTNAMES: $HOSTNAMES"
echo "MASTER_ADDR: $MASTER_ADDR"
echo "MASTER_PORT: $MASTER_PORT"

# --- Create DeepSpeed hostfile ---
# This script should generate the hostfile using the SLURM_JOBID
bash write_hostfile.sh
# Tell DeepSpeed where to find our generated hostfile
export DLTS_HOSTFILE=hostfiles/hosts_$SLURM_JOBID

# --- Set WANDB API Key ---
export WANDB_API_KEY=""

# --- Execute the distributed training job using srun ---
srun -l apptainer exec --nv --bind /:/ $CONTAINER_IMAGE \
bash -c '
set -ex # Exit on error and print commands
# --- Environment setup inside the container ---
# Optional: Create a unique cache directory for each job run
export TRITON_CACHE_DIR="/tmp/TRITON_TEMP_$SLURM_JOBID"
mkdir -p $TRITON_CACHE_DIR
export OMP_NUM_THREADS=10 # Might not be needed/might need adjustment based on your CPU setup
# Map Slurm variables to standard distributed training variables
export RANK=$SLURM_PROCID
export WORLD_SIZE=$SLURM_NTASKS
# --- Log environment for debugging ---
echo "--------------------------------------------------"
echo "Node ID: $SLURM_NODEID | Rank: $RANK | World Size: $WORLD_SIZE"
echo "MASTER_ADDR: $MASTER_ADDR | MASTER_PORT: $MASTER_PORT"
echo "DLTS_HOSTFILE: $DLTS_HOSTFILE"
echo "WANDB API Key is set." # Avoid printing the key to logs
echo "--------------------------------------------------"

# Optional: Display installed packages
# pip freeze --all

# --- Run the training script ---
cd /path_to_gpt_neox_codebase

python deepy_torchrun.py train.py config.yml
'