Skip to content

Commit c88b9dc

Browse files
committed
Add big dataset examples
1 parent 47e7860 commit c88b9dc

File tree

22 files changed

+974
-3
lines changed

22 files changed

+974
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
_build
33
.idea
44
**/__pycache__
5+
/docs/examples/**/*.diff

docs/Minimal_examples.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
.. ***************************
1+
.. ****************
22
.. Minimal Examples
3-
.. ***************************
3+
.. ****************
44
55
66
.. include:: examples/frameworks/README.rst
7+
.. include:: examples/distributed/README.rst
8+
.. include:: examples/data/README.rst

docs/conf.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from __future__ import division, print_function, unicode_literals
44

55
from datetime import datetime
6-
6+
import subprocess
7+
from pathlib import Path
78
import sphinx_theme
89

910
extensions = [
@@ -90,5 +91,18 @@
9091
# Include CNAME file so GitHub Pages can set Custom Domain name
9192
html_extra_path = ['CNAME']
9293

94+
95+
# Generate the diffs that are shown in the examples.
diff_script = Path(__file__).parent / "examples/generate_diffs.sh"
try:
    # `check=True` raises CalledProcessError on a non-zero exit code;
    # `capture_output=True` collects stdout/stderr for the error message below.
    subprocess.run(str(diff_script), shell=True, capture_output=True, check=True)
except subprocess.CalledProcessError as err:
    # Surface the script's own output and chain the original exception so the
    # full context is preserved in the traceback.
    raise RuntimeError(
        "Could not build the diff files for the examples:\n"
        + str(err.output, encoding="utf-8")
        + str(err.stderr, encoding="utf-8")
    ) from err
105+
106+
93107
def setup(app):
94108
app.add_css_file('custom.css')

docs/examples/data/README.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
*****************************
2+
Data Handling during Training
3+
*****************************
4+
5+
6+
.. include:: examples/data/torchvision/README.rst
7+
.. include:: examples/data/hf/README.rst

docs/examples/data/hf/README.rst

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
Hugging Face Dataset
2+
====================
3+
4+
5+
**Prerequisites**
6+
7+
Make sure to read the following sections of the documentation before using this example:
8+
9+
* :ref:`pytorch_setup`
10+
* :ref:`001 - Single GPU Job`
11+
12+
The full source code for this example is available on `the mila-docs GitHub repository. <https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/data/hf>`_
13+
14+
15+
**job.sh**
16+
17+
.. literalinclude:: examples/data/hf/job.sh.diff
18+
:language: diff
19+
20+
21+
**main.py**
22+
23+
.. literalinclude:: examples/data/hf/main.py.diff
24+
:language: diff
25+
26+
27+
**prepare_data.py**
28+
29+
.. literalinclude:: examples/data/hf/prepare_data.py
30+
:language: python
31+
32+
33+
**get_dataset_cache_dir.py**
34+
35+
.. literalinclude:: examples/data/hf/get_dataset_cache_dir.py
36+
:language: python
37+
38+
39+
**cp_data.sh**
40+
41+
.. literalinclude:: examples/data/hf/cp_data.sh
42+
:language: bash
43+
44+
45+
**Running this example**
46+
47+
.. code-block:: bash
48+
49+
$ sbatch job.sh

docs/examples/data/hf/cp_data.sh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
# Copy a dataset tree from _SRC to _DEST using _WORKERS parallel `cp`
# processes. Symlinks in the source are resolved so the destination holds
# real files.
set -o errexit

_SRC=$1
_DEST=$2
_WORKERS=$3

# Copy the dataset.
# The source/destination pairs are NUL-delimited so that filenames containing
# spaces survive the trip through xargs (a newline-separated stream would be
# word-split and the '-n2' pairing would break). `IFS= read -r` keeps leading
# whitespace and backslashes in filenames intact.
(cd "${_SRC}" && find -L * -type f) | while IFS= read -r f
do
    mkdir --parents "${_DEST}/$(dirname "$f")"
    # print source first so it is matched to cp's '-T' argument
    printf '%s\0' "$(readlink --canonicalize "${_SRC}/$f")"
    # print the destination last so cp understands it's the output file
    printf '%s\0' "${_DEST}/$f"
done | xargs --null -n2 -P"${_WORKERS}" cp --update -T
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Print the dataset's cache directory to stdout."""
import sys

import datasets


# Redirect outputs to stderr to avoid noise in stdout: the `datasets` library
# may print progress messages, and callers capture stdout to read the path.
_stdout = sys.stdout
sys.stdout = sys.stderr

# Optional first CLI argument selects the datasets cache directory; when it is
# absent, the library's default cache location is used.
try:
    _CACHE_DIR = sys.argv[1]
except IndexError:
    _CACHE_DIR = None

# NOTE(review): the builder arguments ("the_pile", subsets=["all"],
# version="0.0.0") presumably must match those used by the training script so
# that both resolve to the same cache directory — keep them in sync.
builder = datasets.load_dataset_builder("the_pile", cache_dir=_CACHE_DIR, subsets=["all"], version="0.0.0")
print(builder.cache_dir, file=_stdout)

docs/examples/data/hf/job.sh

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/bin/bash
#SBATCH --gpus-per-task=rtx8000:1
#SBATCH --cpus-per-task=4
#SBATCH --ntasks-per-node=1
#SBATCH --mem=24G
#SBATCH --time=02:00:00
#SBATCH --tmp=1500G
set -o errexit


# Echo time and hostname into log
echo "Date: $(date)"
echo "Hostname: $(hostname)"


# Ensure only anaconda/3 module loaded.
module purge
# This example uses Conda to manage package dependencies.
# See https://docs.mila.quebec/Userguide.html#conda for more information.
module load anaconda/3


# Creating the environment for the first time:
# conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \
#     pytorch-cuda=11.6 scipy -c pytorch -c nvidia
# Other conda packages:
# conda install -y -n pytorch -c conda-forge rich tqdm datasets

# Activate pre-existing environment.
conda activate pytorch


# Prepare data for training
mkdir -p "$SLURM_TMPDIR/data"

if [[ -z "${HF_DATASETS_CACHE}" ]]
then
    # Store the huggingface datasets cache in $SCRATCH
    export HF_DATASETS_CACHE=$SCRATCH/cache/huggingface/datasets
fi
# Default the number of data-preparation workers to this job's CPU count...
if [[ -z "${_DATA_PREP_WORKERS}" ]]
then
    _DATA_PREP_WORKERS=${SLURM_JOB_CPUS_PER_NODE}
fi
# ...and fall back to 16 when SLURM does not report a CPU count either.
if [[ -z "${_DATA_PREP_WORKERS}" ]]
then
    _DATA_PREP_WORKERS=16
fi

# Preprocess the dataset and cache the result such that the heavy work is done
# only once *ever*
# Required conda packages:
# conda install -y -c conda-forge zstandard
srun --ntasks=1 --ntasks-per-node=1 \
    time -p python3 prepare_data.py "/network/datasets/pile" ${_DATA_PREP_WORKERS}

# Copy the preprocessed dataset to $SLURM_TMPDIR so it is close to the GPUs for
# faster training
# Get the current dataset cache
_DATASET_CACHE_DIR=$(python3 get_dataset_cache_dir.py)
# Get the local dataset cache
_LOCAL_DATASET_CACHE_DIR=$(python3 get_dataset_cache_dir.py "$SLURM_TMPDIR/data")
srun --ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1 \
    time -p bash cp_data.sh "${_DATASET_CACHE_DIR}" "${_LOCAL_DATASET_CACHE_DIR}" ${_DATA_PREP_WORKERS}

# Use the local copy of the preprocessed dataset
export HF_DATASETS_CACHE="$SLURM_TMPDIR/data"


# Execute Python script
python main.py

docs/examples/data/hf/main.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Hugging Face "the_pile" dataset training example."""
2+
import logging
3+
import os
4+
5+
import datasets
6+
import rich.logging
7+
import torch
8+
from torch import Tensor, nn
9+
from torch.nn import functional as F
10+
from torch.utils.data import DataLoader
11+
from torchvision.models import resnet18
12+
from tqdm import tqdm
13+
14+
15+
def main():
    """Run the (placeholder) training loop over the locally-cached dataset."""
    training_epochs = 1
    learning_rate = 5e-4
    weight_decay = 1e-4
    batch_size = 256

    # Check that the GPU is available
    assert torch.cuda.is_available() and torch.cuda.device_count() > 0
    device = torch.device("cuda", 0)

    # Setup logging (optional, but much better than using print statements)
    logging.basicConfig(
        level=logging.INFO,
        handlers=[rich.logging.RichHandler(markup=True)],  # Very pretty, uses the `rich` package.
    )

    logger = logging.getLogger(__name__)

    # Create a model and move it to the GPU.
    # NOTE(review): resnet18 is an image model while the dataset below is
    # "the_pile" (text) — this example demonstrates the data-handling
    # pipeline only; the actual training step is a placeholder.
    model = resnet18()
    model.to(device=device)

    # NOTE(review): the optimizer is created but never stepped, because the
    # training step is omitted in this example.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Set up the dataset and its dataloaders.
    num_workers = get_num_workers()
    dataset_path = "the_pile"
    train_dataset, valid_dataset, test_dataset = make_datasets(dataset_path)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    test_dataloader = DataLoader(  # NOTE: Not used in this example.
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    # Checkout the "checkpointing and preemption" example for more info!
    logger.debug("Starting training from scratch.")

    for epoch in range(training_epochs):
        logger.debug(f"Starting epoch {epoch}/{training_epochs}")

        # Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers)
        model.train()

        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
        progress_bar = tqdm(
            total=len(train_dataloader),
            desc=f"Train epoch {epoch}",
        )

        # Training loop
        for batch in train_dataloader:
            # Move the batch to the GPU before we pass it to the model.
            # NOTE(review): HF datasets with `.with_format("torch")` usually
            # yield dict batches; iterating a dict yields its *keys*, so this
            # line assumes tuple-of-tensors batches — confirm against the
            # dataset schema.
            batch = tuple(item.to(device) for item in batch)

            # [Training of the model goes here]

            # Advance the progress bar by one step.
            progress_bar.update(1)
        progress_bar.close()

        val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
        logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")

    print("Done!")
92+
93+
94+
@torch.no_grad()
def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
    """Evaluate `model` over `dataloader` on `device`.

    Returns a tuple `(total_loss, accuracy)` where `total_loss` is the sum of
    the per-batch mean cross-entropy losses and `accuracy` is the fraction of
    correctly classified samples.
    """
    model.eval()

    loss_sum = 0.0
    samples_seen = 0
    hits = 0

    for raw_batch in dataloader:
        x, y = (item.to(device) for item in raw_batch)

        logits: Tensor = model(x)

        # Accumulate this batch's mean loss and its number of correct
        # top-1 predictions.
        loss_sum += F.cross_entropy(logits, y).item()
        hits += logits.argmax(-1).eq(y).sum()
        samples_seen += x.shape[0]

    return loss_sum, hits / samples_seen
118+
119+
120+
def make_datasets(dataset_path: str, subsets: tuple = ("all",), version: str = "0.0.0"):
    """Returns the training, validation, and test splits of a Hugging Face dataset.

    Presumably the dataset has already been downloaded and pre-processed into
    the local cache (see the accompanying job script) — `builder.as_dataset`
    reads from that cache.

    Args:
        dataset_path: Name or path of the Hugging Face dataset builder
            (e.g. "the_pile").
        subsets: Builder-specific subsets to load. Defaults to ``("all",)``,
            matching the previously hard-coded value.
        version: Dataset version to load. Defaults to ``"0.0.0"``.

    NOTE: We don't use transforms here for simplicity.
    Having different transformations for train and validation would complicate things a bit.
    Later examples will show how to do the train/val/test split properly when using transforms.
    """
    builder = datasets.load_dataset_builder(dataset_path, subsets=list(subsets), version=version)
    train_dataset = builder.as_dataset(split="train").with_format("torch")
    valid_dataset = builder.as_dataset(split="validation").with_format("torch")
    test_dataset = builder.as_dataset(split="test").with_format("torch")
    return train_dataset, valid_dataset, test_dataset
132+
133+
134+
def get_num_workers() -> int:
    """Gets the optimal number of DataLoader workers to use in the current job."""
    # Inside a SLURM job, trust the scheduler's CPU allocation.
    slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK")
    if slurm_cpus is not None:
        return int(slurm_cpus)
    # Otherwise count the CPUs this process may actually run on, when the
    # platform exposes that information.
    affinity = getattr(os, "sched_getaffinity", None)
    if affinity is not None:
        return len(affinity(0))
    # Last resort: the machine's total CPU count.
    return torch.multiprocessing.cpu_count()
141+
142+
143+
if __name__ == "__main__":
144+
main()

0 commit comments

Comments
 (0)