Commit 41f93cb

Add huggingface dataset example
1 parent b932b50 commit 41f93cb

File tree

11 files changed: +1010 -0 lines changed


docs/examples/data/hf/README.rst

Lines changed: 452 additions & 0 deletions
Large diffs are not rendered by default.

docs/examples/data/hf/_index.rst

Lines changed: 52 additions & 0 deletions
Hugging Face Dataset
====================


**Prerequisites**

Make sure to read the following sections of the documentation before using this
example:

* :ref:`pytorch_setup`
* :ref:`001 - Single GPU Job`

The full source code for this example is available on `the mila-docs GitHub
repository <https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/data/hf>`_.


**job.sh**

.. literalinclude:: examples/data/hf/job.sh.diff
   :language: diff


**main.py**

.. literalinclude:: examples/data/hf/main.py.diff
   :language: diff


**prepare_data.py**

.. literalinclude:: examples/data/hf/prepare_data.py
   :language: python


**data.sh**

.. literalinclude:: examples/data/hf/data.sh
   :language: bash


**get_dataset_cache_files.py**

.. literalinclude:: examples/data/hf/get_dataset_cache_files.py
   :language: python


**Running this example**

.. code-block:: bash

   $ sbatch job.sh

docs/examples/data/hf/data.sh

Lines changed: 17 additions & 0 deletions
#!/bin/bash
set -o errexit

_SRC=$1
_DEST=$2
_WORKERS=$3

# Clone the dataset structure (not the data itself) locally so HF can find the
# cache hashes it looks for. Otherwise HF might think it needs to redo some
# preprocessing. Directories will be created and symlinks will replace the files.
bash sh_utils.sh ln_files "${_SRC}" "${_DEST}" $_WORKERS

# Copy the preprocessed dataset to the compute node's local dataset cache dir so
# it is close to the GPUs for faster training. Since HF can very easily change
# the hash used to reference a preprocessed dataset, we only copy the data for
# the current preprocessing pipeline.
python3 get_dataset_cache_files.py | bash sh_utils.sh cp_files "${_SRC}" "${_DEST}" $_WORKERS
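The `ln_files` and `cp_files` helpers live in `sh_utils.sh`, which is referenced by data.sh but not rendered on this page. As a rough, hypothetical illustration of the "clone the directory structure, symlink the files" step performed by `ln_files`, a minimal Python sketch could look like this (the name and behaviour here are assumptions, not the actual helper):

# ln_files_sketch.py (hypothetical, for illustration only)
"""Recreate the directory tree of a dataset cache and symlink every file.

This mirrors what `bash sh_utils.sh ln_files SRC DEST WORKERS` is described as
doing in data.sh; it is NOT the actual helper from the repository.
"""
import sys
from pathlib import Path


def ln_files(src: Path, dest: Path) -> None:
    for path in src.rglob("*"):
        target = dest / path.relative_to(src)
        if path.is_dir():
            # Directories are created for real ...
            target.mkdir(parents=True, exist_ok=True)
        else:
            # ... while files are represented by symlinks into the shared cache.
            target.parent.mkdir(parents=True, exist_ok=True)
            if not (target.exists() or target.is_symlink()):
                target.symlink_to(path.resolve())


if __name__ == "__main__":
    ln_files(Path(sys.argv[1]), Path(sys.argv[2]))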
docs/examples/data/hf/get_dataset_cache_files.py

Lines changed: 35 additions & 0 deletions
"""List to stdout the files of the dataset"""
from pathlib import Path
import sys

import datasets

from py_utils import (
    get_dataset_builder, get_num_workers, get_raw_datasets, get_tokenizer,
    preprocess_datasets
)


if __name__ == "__main__":
    # Redirect outputs to stderr to avoid noise in stdout
    _stdout = sys.stdout
    sys.stdout = sys.stderr

    try:
        _CACHE_DIR = sys.argv[1]
    except IndexError:
        _CACHE_DIR = datasets.config.HF_DATASETS_CACHE
    try:
        _WORKERS = int(sys.argv[2])
    except IndexError:
        _WORKERS = get_num_workers()

    cache_dir = Path(_CACHE_DIR)
    builder = get_dataset_builder(cache_dir=_CACHE_DIR)
    raw_datasets = get_raw_datasets(builder)
    tokenizer = get_tokenizer()
    for dataset in preprocess_datasets(tokenizer, raw_datasets, num_workers=_WORKERS).values():
        for cache_file in dataset.cache_files:
            cache_file = Path(cache_file["filename"]).relative_to(cache_dir)
            print(cache_file, file=_stdout)
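On the receiving end of the pipe in data.sh, `cp_files` consumes the relative paths printed above. A hypothetical Python sketch of that step (again, not the actual `sh_utils.sh` helper) could be:

# cp_files_sketch.py (hypothetical, for illustration only)
"""Read relative cache-file paths from stdin and copy them from the shared
cache (SRC) into the local cache (DEST), replacing any symlink left by the
ln_files step with a real local copy."""
import shutil
import sys
from pathlib import Path


def cp_files(src: Path, dest: Path) -> None:
    for line in sys.stdin:
        rel = line.strip()
        if not rel:
            continue
        target = dest / rel
        target.parent.mkdir(parents=True, exist_ok=True)
        # Remove an existing symlink first so the copy does not write through
        # it back into the shared cache.
        if target.is_symlink() or target.exists():
            target.unlink()
        shutil.copyfile(src / rel, target)


if __name__ == "__main__":
    cp_files(Path(sys.argv[1]), Path(sys.argv[2]))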

docs/examples/data/hf/job.sh

Lines changed: 102 additions & 0 deletions
#!/bin/bash
#SBATCH --gpus-per-task=rtx8000:1
#SBATCH --cpus-per-task=12
#SBATCH --ntasks-per-node=1
#SBATCH --mem=48G
#SBATCH --time=04:00:00
#SBATCH --tmp=1500G
set -o errexit

function wrap_cmd {
    for a in "$@"
    do
        echo -n "\"$a\" "
    done
}


# Echo time and hostname into log
echo "Date: $(date)"
echo "Hostname: $(hostname)"


# Ensure only anaconda/3 module loaded.
module --quiet purge
# This example uses Conda to manage package dependencies.
# See https://docs.mila.quebec/Userguide.html#conda for more information.
module load anaconda/3
module load cuda/11.7


# Creating the environment for the first time:
# conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \
#     pytorch-cuda=11.7 scipy -c pytorch -c nvidia
# Other conda packages:
# conda install -y -n pytorch -c conda-forge rich tqdm datasets

# Activate pre-existing environment.
conda activate pytorch


if [[ -z "$HF_DATASETS_CACHE" ]]
then
    # Store the Hugging Face datasets cache in $SCRATCH
    export HF_DATASETS_CACHE=$SCRATCH/cache/huggingface/datasets
fi
if [[ -z "$HUGGINGFACE_HUB_CACHE" ]]
then
    # Store the Hugging Face Hub cache in $SCRATCH
    export HUGGINGFACE_HUB_CACHE=$SCRATCH/cache/huggingface/hub
fi
if [[ -z "$_DATA_PREP_WORKERS" ]]
then
    _DATA_PREP_WORKERS=$SLURM_JOB_CPUS_PER_NODE
fi
if [[ -z "$_DATA_PREP_WORKERS" ]]
then
    _DATA_PREP_WORKERS=16
fi


# Preprocess the dataset and cache the result so that the heavy work is done
# only once *ever*.
# Required conda packages:
# conda install -y -c conda-forge zstandard
srun --ntasks=1 --ntasks-per-node=1 \
    time -p python3 prepare_data.py "/network/datasets/pile" $_DATA_PREP_WORKERS


# Copy the preprocessed dataset to $SLURM_TMPDIR so it is close to the GPUs for
# faster training. This should be done once per compute node.
cmd=(
    # Having 'bash' here allows the execution of a script file which might not
    # have the execution flag set.
    bash data.sh
    # The current dataset cache dir
    "$HF_DATASETS_CACHE"
    # The local dataset cache dir.
    # Use single quotes to delay the expansion so that $SLURM_TMPDIR is
    # interpreted on the local compute node rather than on the master node.
    '$SLURM_TMPDIR/data'
    $_DATA_PREP_WORKERS
)
# 'time' gives an objective measure for the copy of the dataset. This can be
# used to compare the timing of multiple code versions and make sure any
# slowdown doesn't come from the code itself.
srun --ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1 \
    time -p bash -c "$(wrap_cmd "${cmd[@]}")"


# Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
unset CUDA_VISIBLE_DEVICES

# Execute Python script
env_var=(
    # Use the local copy of the preprocessed dataset
    HF_DATASETS_CACHE='"$SLURM_TMPDIR/data"'
)
cmd=(
    python3
    main.py
)
srun bash -c "$(echo "${env_var[@]}") $(wrap_cmd "${cmd[@]}")"
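A quick way to sanity-check that the `HF_DATASETS_CACHE` value exported above (and the `$SLURM_TMPDIR/data` override used for the training step) is actually picked up by the `datasets` library is a snippet along these lines; it is not part of the commit, just a minimal check:

# check_cache_sketch.py (not part of the commit; a minimal sanity check)
import os

import datasets

# `datasets` resolves its cache directory from HF_DATASETS_CACHE at import
# time, so this should print the path exported in job.sh (or the
# $SLURM_TMPDIR/data override during the training step).
print("HF_DATASETS_CACHE env var:", os.environ.get("HF_DATASETS_CACHE"))
print("datasets cache in use:    ", datasets.config.HF_DATASETS_CACHE)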

docs/examples/data/hf/main.py

Lines changed: 114 additions & 0 deletions
"""HuggingFace training example."""
import logging

import rich.logging
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from py_utils import (
    get_dataset_builder, get_num_workers, get_raw_datasets, get_tokenizer,
    preprocess_datasets
)


def main():
    training_epochs = 1
    batch_size = 256

    # Check that the GPU is available
    assert torch.cuda.is_available() and torch.cuda.device_count() > 0
    device = torch.device("cuda", 0)

    # Setup logging (optional, but much better than using print statements)
    logging.basicConfig(
        level=logging.INFO,
        handlers=[rich.logging.RichHandler(markup=True)],  # Very pretty, uses the `rich` package.
    )

    logger = logging.getLogger(__name__)

    # Setup the dataset
    num_workers = get_num_workers()
    train_dataset, valid_dataset, test_dataset = make_datasets(num_workers)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    test_dataloader = DataLoader(  # NOTE: Not used in this example.
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    # Check out the "checkpointing and preemption" example for more info!
    logger.debug("Starting training from scratch.")

    for epoch in range(training_epochs):
        logger.debug(f"Starting epoch {epoch}/{training_epochs}")

        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
        progress_bar = tqdm(
            total=len(train_dataloader),
            desc=f"Train epoch {epoch}",
        )

        # Training loop
        for batch in train_dataloader:
            # Move the batch to the GPU before we pass it to the model
            batch = {k: item.to(device) for k, item in batch.items()}

            # [Training of the model goes here]

            # Advance the progress bar one step.
            progress_bar.update(1)
        progress_bar.close()

        val_loss, val_accuracy = validation_loop(None, valid_dataloader, device)
        logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")

    print("Done!")


@torch.no_grad()
def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
    total_loss = 0.0
    n_samples = 0
    correct_predictions = 0

    for batch in dataloader:
        batch = {k: item.to(device) for k, item in batch.items()}

        batch_n_samples = batch["input_ids"].data.shape[0]

        n_samples += batch_n_samples
        # [Evaluation of the model on this batch goes here]

    accuracy = correct_predictions / n_samples
    return total_loss, accuracy


def make_datasets(num_workers: int = None):
    """Returns the training, validation, and test splits for the prepared dataset."""
    builder = get_dataset_builder()
    raw_datasets = get_raw_datasets(builder)
    tokenizer = get_tokenizer()
    preprocessed_datasets = preprocess_datasets(tokenizer, raw_datasets, num_workers=num_workers)
    return (
        preprocessed_datasets["train"], preprocessed_datasets["validation"],
        preprocessed_datasets["test"]
    )


if __name__ == "__main__":
    main()
docs/examples/data/hf/prepare_data.py

Lines changed: 30 additions & 0 deletions
"""Preprocess the dataset.

In this example, HuggingFace is used and the resulting dataset will be stored in
$HF_DATASETS_CACHE. It is preferable to set the datasets cache to a location in
$SCRATCH."""

from py_utils import (
    get_config, get_dataset_builder, get_num_workers, get_raw_datasets,
    get_tokenizer, preprocess_datasets
)


if __name__ == "__main__":
    import sys
    import time

    _LOCAL_DS = sys.argv[1]
    try:
        _WORKERS = int(sys.argv[2])
    except IndexError:
        _WORKERS = get_num_workers()

    t = -time.time()
    _ = get_config()
    builder = get_dataset_builder(local_dataset=_LOCAL_DS, num_workers=_WORKERS)
    raw_datasets = get_raw_datasets(builder)
    tokenizer = get_tokenizer()
    _ = preprocess_datasets(tokenizer, raw_datasets, num_workers=_WORKERS)
    t += time.time()

    print(f"Prepared data in {t/60:.2f}m")
