|
| 1 | +<!-- NOTE: This file is auto-generated from examples/distributed/single_gpu/index.md |
| 2 | + This is done so this file can be easily viewed from the GitHub UI. |
| 3 | + DO NOT EDIT --> |
| 4 | + |
1 | 5 | ### Single GPU Job |
2 | 6 |
|
3 | 7 | **Prerequisites** |
4 | 8 | Make sure to read the following sections of the documentation before using this |
5 | 9 | example: |
6 | 10 |
|
7 | | -* [examples/frameworks/pytorch_setup ](https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/frameworks/pytorch_setup) |
| 11 | +* :doc:`/examples/frameworks/pytorch_setup/index` |
8 | 12 |
|
9 | | -The full source code for this example is available on `the mila-docs GitHub |
| 13 | +The full source code for this example is available on [the mila-docs GitHub |
10 | 14 | repository. |
11 | | -<https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/distributed/single_gpu>`_ |
| 15 | +](https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/distributed/single_gpu) |
12 | 16 |
|
13 | 17 | **job.sh** |
14 | 18 |
|
15 | | -.. code:: bash |
16 | | - |
17 | | - #!/bin/bash |
18 | | - #SBATCH --ntasks=1 |
19 | | - #SBATCH --ntasks-per-node=1 |
20 | | - #SBATCH --cpus-per-task=4 |
21 | | - #SBATCH --gpus-per-task=l40s:1 |
22 | | - #SBATCH --mem-per-gpu=16G |
23 | | - #SBATCH --time=00:15:00 |
24 | | - |
25 | | - # Exit on error |
26 | | - set -e |
27 | | - |
28 | | - # Echo time and hostname into log |
29 | | - echo "Date: $(date)" |
30 | | - echo "Hostname: $(hostname)" |
31 | | - |
32 | | - # To make your code as reproducible as possible with
33 | | - # `torch.use_deterministic_algorithms(True)`, uncomment the following block: |
34 | | - ## === Reproducibility === |
35 | | - ## Be warned that this can make your code slower. See |
36 | | - ## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations |
37 | | - ## for more details. |
38 | | - # export CUBLAS_WORKSPACE_CONFIG=:4096:8 |
39 | | - ## === Reproducibility (END) === |
40 | | - |
41 | | - # Stage dataset into $SLURM_TMPDIR |
42 | | - mkdir -p $SLURM_TMPDIR/data |
43 | | - cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/ |
44 | | - # General-purpose alternatives combining copy and unpack: |
45 | | - # unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/ |
46 | | - # tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/ |
47 | | - |
48 | | - # Execute Python script |
49 | | - # Use the `--offline` option of `uv run` on clusters without internet access on compute nodes. |
50 | | - # Using the `--locked` option can help make your experiments easier to reproduce (it forces |
51 | | - # your uv.lock file to be up to date with the dependencies declared in pyproject.toml). |
52 | | - srun uv run python main.py |
53 | | - |
54 | 19 | **pyproject.toml** |
55 | 20 |
|
56 | | -.. code:: toml |
57 | | - |
58 | | - [project] |
59 | | - name = "single-gpu-example" |
60 | | - version = "0.1.0" |
61 | | - description = "Add your description here" |
62 | | - readme = "README.rst" |
63 | | - requires-python = ">=3.11,<3.14" |
64 | | - dependencies = [ |
65 | | - "rich>=14.0.0", |
66 | | - "torch>=2.7.1", |
67 | | - "torchvision>=0.22.1", |
68 | | - "tqdm>=4.67.1", |
69 | | - ] |
70 | | - |
71 | 21 | **main.py** |
72 | 22 |
|
73 | | -.. code:: python |
74 | | - |
75 | | - """Single-GPU training example.""" |
76 | | - |
77 | | - import argparse |
78 | | - import logging |
79 | | - import os |
80 | | - import random |
81 | | - import sys |
82 | | - from pathlib import Path |
83 | | - |
84 | | - import numpy as np |
85 | | - import rich.logging |
86 | | - import torch |
87 | | - from torch import Tensor, nn |
88 | | - from torch.nn import functional as F |
89 | | - from torch.utils.data import DataLoader, random_split |
90 | | - from torchvision import transforms |
91 | | - from torchvision.datasets import CIFAR10 |
92 | | - from torchvision.models import resnet18 |
93 | | - from tqdm import tqdm |
94 | | - |
95 | | - # To make your code as reproducible as possible, uncomment the following
96 | | - # block: |
97 | | - ## === Reproducibility === |
98 | | - ## Be warned that this can make your code slower. See |
99 | | - ## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations |
100 | | - ## for more details. |
101 | | - # torch.use_deterministic_algorithms(True) |
102 | | - ## === Reproducibility (END) === |
103 | | - |
104 | | - def main(): |
105 | | - # Use an argument parser so we can pass hyperparameters from the command line. |
106 | | - parser = argparse.ArgumentParser(description=__doc__) |
107 | | - parser.add_argument("--epochs", type=int, default=10) |
108 | | - parser.add_argument("--learning-rate", type=float, default=5e-4) |
109 | | - parser.add_argument("--weight-decay", type=float, default=1e-4) |
110 | | - parser.add_argument("--batch-size", type=int, default=128) |
111 | | - parser.add_argument("--seed", type=int, default=42) |
112 | | - args = parser.parse_args() |
113 | | - |
114 | | - epochs: int = args.epochs |
115 | | - learning_rate: float = args.learning_rate |
116 | | - weight_decay: float = args.weight_decay |
117 | | - batch_size: int = args.batch_size |
118 | | - seed: int = args.seed |
119 | | - |
120 | | - # Seed the random number generators as early as possible for reproducibility |
121 | | - random.seed(seed) |
122 | | - np.random.seed(seed) |
123 | | - torch.random.manual_seed(seed) |
124 | | - torch.cuda.manual_seed_all(seed) |
125 | | - |
126 | | - # Check that the GPU is available |
127 | | - assert torch.cuda.is_available() and torch.cuda.device_count() > 0 |
128 | | - device = torch.device("cuda", 0) |
129 | | - |
130 | | - # Setup logging (optional, but much better than using print statements) |
131 | | - # Uses the `rich` package to make logs pretty. |
132 | | - logging.basicConfig( |
133 | | - level=logging.INFO, |
134 | | - format="%(message)s", |
135 | | - handlers=[ |
136 | | - rich.logging.RichHandler( |
137 | | - markup=True, |
138 | | - console=rich.console.Console( |
139 | | - # Allow wider log lines in sbatch output files than on the terminal.
140 | | - width=120 if not sys.stdout.isatty() else None |
141 | | - ), |
142 | | - ) |
143 | | - ], |
144 | | - ) |
145 | | - |
146 | | - logger = logging.getLogger(__name__) |
147 | | - |
148 | | - # Create a model and move it to the GPU. |
149 | | - model = resnet18(num_classes=10) |
150 | | - model.to(device=device) |
151 | | - |
152 | | - optimizer = torch.optim.AdamW( |
153 | | - model.parameters(), lr=learning_rate, weight_decay=weight_decay |
154 | | - ) |
155 | | - |
156 | | - # Setup CIFAR10 |
157 | | - num_workers = get_num_workers() |
158 | | - dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data" |
159 | | - train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path)) |
160 | | - train_dataloader = DataLoader( |
161 | | - train_dataset, |
162 | | - batch_size=batch_size, |
163 | | - num_workers=num_workers, |
164 | | - shuffle=True, |
165 | | - ) |
166 | | - valid_dataloader = DataLoader( |
167 | | - valid_dataset, |
168 | | - batch_size=batch_size, |
169 | | - num_workers=num_workers, |
170 | | - shuffle=False, |
171 | | - ) |
172 | | - _test_dataloader = DataLoader( # NOTE: Not used in this example. |
173 | | - test_dataset, |
174 | | - batch_size=batch_size, |
175 | | - num_workers=num_workers, |
176 | | - shuffle=False, |
177 | | - ) |
178 | | - |
179 | | - # Check out the "checkpointing and preemption" example for more info!
180 | | - logger.debug("Starting training from scratch.") |
181 | | - |
182 | | - for epoch in range(epochs): |
183 | | - logger.debug(f"Starting epoch {epoch}/{epochs}") |
184 | | - |
185 | | - # Set the model in training mode (important for e.g. BatchNorm and Dropout layers) |
186 | | - model.train() |
187 | | - |
188 | | - # NOTE: using a progress bar from tqdm because it's nicer than using `print`. |
189 | | - progress_bar = tqdm( |
190 | | - total=len(train_dataloader), |
191 | | - desc=f"Train epoch {epoch}", |
192 | | - disable=not sys.stdout.isatty(), # Disable progress bar in non-interactive environments. |
193 | | - ) |
194 | | - |
195 | | - # Training loop |
196 | | - for batch in train_dataloader: |
197 | | - # Move the batch to the GPU before we pass it to the model |
198 | | - batch = tuple(item.to(device) for item in batch) |
199 | | - x, y = batch |
200 | | - |
201 | | - # Forward pass |
202 | | - logits: Tensor = model(x) |
203 | | - |
204 | | - loss = F.cross_entropy(logits, y) |
205 | | - |
206 | | - optimizer.zero_grad() |
207 | | - loss.backward() |
208 | | - optimizer.step() |
209 | | - |
210 | | - # Calculate some metrics: |
211 | | - n_correct_predictions = logits.detach().argmax(-1).eq(y).sum() |
212 | | - n_samples = y.shape[0] |
213 | | - accuracy = n_correct_predictions / n_samples |
214 | | - |
215 | | - logger.debug(f"Accuracy: {accuracy.item():.2%}") |
216 | | - logger.debug(f"Average Loss: {loss.item()}") |
217 | | - |
218 | | - # Advance the progress bar one step and update the progress bar text. |
219 | | - progress_bar.update(1) |
220 | | - progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item()) |
221 | | - progress_bar.close() |
222 | | - |
223 | | - val_loss, val_accuracy = validation_loop(model, valid_dataloader, device) |
224 | | - logger.info( |
225 | | - f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}" |
226 | | - ) |
227 | | - |
228 | | - print("Done!") |
229 | | - |
230 | | - @torch.no_grad() |
231 | | - def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device): |
232 | | - model.eval() |
233 | | - |
234 | | - total_loss = 0.0 |
235 | | - n_samples = 0 |
236 | | - correct_predictions = 0 |
237 | | - |
238 | | - for batch in dataloader: |
239 | | - batch = tuple(item.to(device) for item in batch) |
240 | | - x, y = batch |
241 | | - |
242 | | - logits: Tensor = model(x) |
243 | | - loss = F.cross_entropy(logits, y) |
244 | | - |
245 | | - batch_n_samples = x.shape[0] |
246 | | - batch_correct_predictions = logits.argmax(-1).eq(y).sum() |
247 | | - |
248 | | - total_loss += loss.item() |
249 | | - n_samples += batch_n_samples |
250 | | - correct_predictions += batch_correct_predictions |
251 | | - |
252 | | - accuracy = correct_predictions / n_samples |
253 | | - return total_loss, accuracy |
254 | | - |
255 | | - def make_datasets( |
256 | | - dataset_path: str, |
257 | | - val_split: float = 0.1, |
258 | | - val_split_seed: int = 42, |
259 | | - ): |
260 | | - """Returns the training, validation, and test splits for CIFAR10. |
261 | | - |
262 | | - NOTE: We don't use image transforms here for simplicity. |
263 | | - Having different transformations for train and validation would complicate things a bit. |
264 | | - Later examples will show how to do the train/val/test split properly when using transforms. |
265 | | - """ |
266 | | - train_dataset = CIFAR10( |
267 | | - root=dataset_path, transform=transforms.ToTensor(), download=True, train=True |
268 | | - ) |
269 | | - test_dataset = CIFAR10( |
270 | | - root=dataset_path, transform=transforms.ToTensor(), download=True, train=False |
271 | | - ) |
272 | | - # Split the training dataset into a training and validation set. |
273 | | - n_samples = len(train_dataset) |
274 | | - n_valid = int(val_split * n_samples) |
275 | | - n_train = n_samples - n_valid |
276 | | - train_dataset, valid_dataset = random_split( |
277 | | - train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed) |
278 | | - ) |
279 | | - return train_dataset, valid_dataset, test_dataset |
280 | | - |
281 | | - def get_num_workers() -> int: |
282 | | - """Gets the optimal number of DataLoader workers to use in the current job."""
283 | | - if "SLURM_CPUS_PER_TASK" in os.environ: |
284 | | - return int(os.environ["SLURM_CPUS_PER_TASK"]) |
285 | | - if hasattr(os, "sched_getaffinity"): |
286 | | - return len(os.sched_getaffinity(0)) |
287 | | - return torch.multiprocessing.cpu_count() |
288 | | - |
289 | | - if __name__ == "__main__": |
290 | | - main() |
291 | | - |
292 | 23 | **Running this example** |
293 | 24 |
|
294 | | -.. code-block:: bash |
| 25 | +```bash |
295 | 26 |
|
296 | | - $ sbatch job.sh |
| 27 | + $ sbatch job.sh |
0 commit comments