109 changes: 66 additions & 43 deletions docs/examples/distributed/multi_gpu/README.rst
@@ -25,56 +25,51 @@ Click here to see `the code for this example

# distributed/single_gpu/job.sh -> distributed/multi_gpu/job.sh
#!/bin/bash
#SBATCH --gpus-per-task=rtx8000:1
-#SBATCH --gres=gpu:1
+#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=4
-#SBATCH --ntasks-per-node=1
+#SBATCH --ntasks-per-node=4
+#SBATCH --nodes=1
#SBATCH --mem=16G
#SBATCH --time=00:15:00


# Echo time and hostname into log
set -e # exit on error.
echo "Date: $(date)"
echo "Hostname: $(hostname)"


# Ensure only anaconda/3 module loaded.
module --quiet purge
# This example uses Conda to manage package dependencies.
# See https://docs.mila.quebec/Userguide.html#conda for more information.
module load anaconda/3
module load cuda/11.7

# Creating the environment for the first time:
# conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \
# pytorch-cuda=11.7 -c pytorch -c nvidia
# Other conda packages:
# conda install -y -n pytorch -c conda-forge rich tqdm

# Activate pre-existing environment.
conda activate pytorch


# Stage dataset into $SLURM_TMPDIR
mkdir -p $SLURM_TMPDIR/data
-cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
-# General-purpose alternatives combining copy and unpack:
-# unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
-# tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/
+ln -s /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
# General-purpose alternatives combining copy and unpack:
# unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
# tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/

-# Execute Python script
+# Get a unique port for this job based on the job ID
+export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
+export MASTER_ADDR="127.0.0.1"

# Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
unset CUDA_VISIBLE_DEVICES

-# Execute Python script
-python main.py
+
+# Execute Python script in each task (one per GPU)
+srun python main.py
# Use `uv run --offline` on clusters without internet access on compute nodes.
-uv run python main.py
+srun uv run python main.py

**pyproject.toml**

.. code:: toml

[project]
name = "multi-gpu-example"
version = "0.1.0"
description = "Add your description here"
readme = "README.rst"
requires-python = ">=3.12"
dependencies = [
"numpy>=2.3.1",
"rich>=14.0.0",
"torch>=2.7.1",
"torchvision>=0.22.1",
"tqdm>=4.67.1",
]

**main.py**

@@ -83,10 +78,12 @@ Click here to see `the code for this example
# distributed/single_gpu/main.py -> distributed/multi_gpu/main.py
-"""Single-GPU training example."""
+"""Multi-GPU Training example."""

import argparse
import logging
import os
from pathlib import Path
import sys

import rich.logging
import torch
@@ -125,10 +122,20 @@ Click here to see `the code for this example
+ device = torch.device("cuda", rank % torch.cuda.device_count())

# Setup logging (optional, but much better than using print statements)
# Uses the `rich` package to make logs pretty.
logging.basicConfig(
level=logging.INFO,
- format="%(message)s",
+ format=f"[{rank}/{world_size}] %(name)s - %(message)s ",
handlers=[rich.logging.RichHandler(markup=True)], # Very pretty, uses the `rich` package.
handlers=[
rich.logging.RichHandler(
markup=True,
console=rich.console.Console(
# Allow wider log lines in sbatch output files than on the terminal.
width=120 if not sys.stdout.isatty() else None
),
)
],
)

logger = logging.getLogger(__name__)
@@ -140,9 +147,13 @@ Click here to see `the code for this example

+ # Wrap the model with DistributedDataParallel
+ # (See https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel)
+ model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
+ model = nn.parallel.DistributedDataParallel(
+ model, device_ids=[rank], output_device=rank
+ )
+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=weight_decay
)

# Setup CIFAR10
num_workers = get_num_workers()
@@ -201,7 +212,9 @@ Click here to see `the code for this example
progress_bar = tqdm(
total=len(train_dataloader),
desc=f"Train epoch {epoch}",
+ disable=not is_master,
- disable=not sys.stdout.isatty(), # Disable progress bar in non-interactive environments.
+ # Disable progress bar in non-interactive environments.
+ disable=not (sys.stdout.isatty() and is_master),
)

# Training loop
@@ -260,10 +273,14 @@ Click here to see `the code for this example
progress_bar.close()

val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
- logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")
- logger.info(
- f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
- )
+ # NOTE: This would log the same values in all workers. Only logging on master:
+ if is_master:
+ logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")
+ logger.info(
+ f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
+ )

print("Done!")

@@ -351,11 +368,17 @@ Click here to see `the code for this example
+ torch.distributed.barrier()
train_dataset = CIFAR10(
- root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
+ root=dataset_path, transform=transforms.ToTensor(), download=is_master, train=True
+ root=dataset_path,
+ transform=transforms.ToTensor(),
+ download=is_master,
+ train=True,
)
test_dataset = CIFAR10(
- root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
+ root=dataset_path, transform=transforms.ToTensor(), download=is_master, train=False
+ root=dataset_path,
+ transform=transforms.ToTensor(),
+ download=is_master,
+ train=False,
)
+ if is_master:
+ # Join the workers waiting in the barrier above. They can now load the datasets from disk.
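
The hunks above stop short of the dataloader setup, but with one process per GPU each rank normally trains on its own shard of CIFAR10. Below is a minimal sketch of that piece using ``torch.utils.data.distributed.DistributedSampler``; it is an assumption about how the dataloader side is handled (it is not visible in the truncated hunks), and the stand-in dataset, batch size, and worker count are illustrative only.

.. code:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # Stand-in dataset; in the example this would be the CIFAR10 train split
    # returned by make_datasets() above.
    train_dataset = TensorDataset(
        torch.randn(512, 3, 32, 32), torch.randint(0, 10, (512,))
    )

    # Shard the training set so each rank sees a disjoint subset each epoch.
    # (Requires torch.distributed to already be initialized, as in main.py.)
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=128,         # per-GPU batch size (illustrative)
        sampler=train_sampler,  # mutually exclusive with shuffle=True on the DataLoader
        num_workers=4,
    )

    for epoch in range(10):
        # Reseed the sampler so the shuffling differs between epochs.
        train_sampler.set_epoch(epoch)
        for images, labels in train_dataloader:
            ...  # forward / backward / optimizer step as in the training loop above
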
5 changes: 5 additions & 0 deletions docs/examples/distributed/multi_gpu/index.rst
@@ -20,6 +20,11 @@ Click here to see `the code for this example
.. literalinclude:: job.sh.diff
:language: diff

**pyproject.toml**

.. literalinclude:: pyproject.toml
:language: toml

**main.py**

.. literalinclude:: main.py.diff
35 changes: 8 additions & 27 deletions docs/examples/distributed/multi_gpu/job.sh
@@ -1,44 +1,25 @@
#!/bin/bash
#SBATCH --gpus-per-task=rtx8000:1
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=4
#SBATCH --ntasks-per-node=4
#SBATCH --nodes=1
#SBATCH --mem=16G
#SBATCH --time=00:15:00


# Echo time and hostname into log
set -e # exit on error.
echo "Date: $(date)"
echo "Hostname: $(hostname)"


# Ensure only anaconda/3 module loaded.
module --quiet purge
# This example uses Conda to manage package dependencies.
# See https://docs.mila.quebec/Userguide.html#conda for more information.
module load anaconda/3
module load cuda/11.7

# Creating the environment for the first time:
# conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \
# pytorch-cuda=11.7 -c pytorch -c nvidia
# Other conda packages:
# conda install -y -n pytorch -c conda-forge rich tqdm

# Activate pre-existing environment.
conda activate pytorch


# Stage dataset into $SLURM_TMPDIR
mkdir -p $SLURM_TMPDIR/data
ln -s /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
# General-purpose alternatives combining copy and unpack:
# unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
# tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/

# Get a unique port for this job based on the job ID
export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
export MASTER_ADDR="127.0.0.1"

# Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
unset CUDA_VISIBLE_DEVICES

# Execute Python script in each task (one per GPU)
srun python main.py
# Use `uv run --offline` on clusters without internet access on compute nodes.
srun uv run python main.py
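
The ``MASTER_PORT`` line derives a port from the job ID so that concurrent jobs on the same machine do not collide: for a hypothetical job ID ``4237561``, ``tail -c 4`` keeps ``7561`` and the resulting port is ``17561``. Each of the four tasks launched by ``srun`` then has to discover its own rank and the world size before joining the process group. The following is a minimal sketch of that setup, assuming the standard SLURM environment variables (``SLURM_PROCID``, ``SLURM_NTASKS``) and the NCCL backend; the exact code in ``main.py`` is not fully shown in the hunks on this page.

.. code:: python

    import os

    import torch
    import torch.distributed as dist

    # Sketch only: `srun` sets SLURM_PROCID and SLURM_NTASKS for each task, and
    # job.sh exports MASTER_ADDR and MASTER_PORT as shown above.
    rank = int(os.environ["SLURM_PROCID"])        # global rank, 0 .. world_size - 1
    world_size = int(os.environ["SLURM_NTASKS"])  # 4 tasks here, one process per GPU

    dist.init_process_group(
        backend="nccl",        # the usual backend for CUDA tensors
        init_method="env://",  # rendezvous through MASTER_ADDR / MASTER_PORT
        rank=rank,
        world_size=world_size,
    )

    is_master = rank == 0  # by convention, rank 0 handles downloads and logging
    device = torch.device("cuda", rank % torch.cuda.device_count())  # as in main.py above
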
38 changes: 31 additions & 7 deletions docs/examples/distributed/multi_gpu/main.py
@@ -1,8 +1,10 @@
"""Multi-GPU Training example."""

import argparse
import logging
import os
from pathlib import Path
import sys

import rich.logging
import torch
@@ -40,10 +42,19 @@ def main():
device = torch.device("cuda", rank % torch.cuda.device_count())

# Setup logging (optional, but much better than using print statements)
# Uses the `rich` package to make logs pretty.
logging.basicConfig(
level=logging.INFO,
format=f"[{rank}/{world_size}] %(name)s - %(message)s ",
handlers=[rich.logging.RichHandler(markup=True)], # Very pretty, uses the `rich` package.
handlers=[
rich.logging.RichHandler(
markup=True,
console=rich.console.Console(
# Allow wider log lines in sbatch output files than on the terminal.
width=120 if not sys.stdout.isatty() else None
),
)
],
)

logger = logging.getLogger(__name__)
@@ -55,9 +66,13 @@ def main():

# Wrap the model with DistributedDataParallel
# (See https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel)
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank
)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=weight_decay
)

# Setup CIFAR10
num_workers = get_num_workers()
@@ -114,7 +129,8 @@ def main():
progress_bar = tqdm(
total=len(train_dataloader),
desc=f"Train epoch {epoch}",
disable=not is_master,
# Disable progress bar in non-interactive environments.
disable=not (sys.stdout.isatty() and is_master),
)

# Training loop
@@ -169,7 +185,9 @@ def main():
val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
# NOTE: This would log the same values in all workers. Only logging on master:
if is_master:
logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")
logger.info(
f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
)

print("Done!")

@@ -252,10 +270,16 @@ def make_datasets(
# Wait for the master process to finish downloading (reach the barrier below)
torch.distributed.barrier()
train_dataset = CIFAR10(
root=dataset_path, transform=transforms.ToTensor(), download=is_master, train=True
root=dataset_path,
transform=transforms.ToTensor(),
download=is_master,
train=True,
)
test_dataset = CIFAR10(
root=dataset_path, transform=transforms.ToTensor(), download=is_master, train=False
root=dataset_path,
transform=transforms.ToTensor(),
download=is_master,
train=False,
)
if is_master:
# Join the workers waiting in the barrier above. They can now load the datasets from disk.
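
The body of ``validation_loop`` is not shown in the hunks above, but the NOTE about every worker logging the same values implies the validation metrics are synchronized across ranks before logging. One common way to do that is to sum the per-rank totals with ``all_reduce``; the sketch below illustrates the idea, and the function name and signature are illustrative rather than taken from ``main.py``.

.. code:: python

    import torch
    import torch.distributed as dist

    def reduce_metrics(loss_sum: torch.Tensor, correct: torch.Tensor, n_samples: torch.Tensor):
        """Sum per-rank totals so every rank ends up with identical global metrics.

        With the NCCL backend, the tensors must live on this rank's GPU.
        """
        totals = torch.stack([loss_sum, correct, n_samples])
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)  # in-place sum over all ranks
        loss_sum, correct, n_samples = totals
        # Mean loss and accuracy over the whole validation set.
        return (loss_sum / n_samples).item(), (correct / n_samples).item()
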
13 changes: 13 additions & 0 deletions docs/examples/distributed/multi_gpu/pyproject.toml
@@ -0,0 +1,13 @@
[project]
name = "multi-gpu-example"
version = "0.1.0"
description = "Add your description here"
readme = "README.rst"
requires-python = ">=3.12"
dependencies = [
"numpy>=2.3.1",
"rich>=14.0.0",
"torch>=2.7.1",
"torchvision>=0.22.1",
"tqdm>=4.67.1",
]