@@ -69,7 +69,7 @@ repository.
 +# Copy the dataset to $SLURM_TMPDIR so it is close to the GPUs for
 +# faster training
 +srun --ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1 \
-+    time -p bash data.sh "/network/datasets/inat" ${_DATA_PREP_WORKERS}
++    time -p python3 data.py "/network/datasets/inat" ${_DATA_PREP_WORKERS}


 # Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
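Note on the ``srun`` fan-out above: ``--ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1`` runs the data-preparation script exactly once on each node, so every node stages its own copy of the dataset into its node-local ``$SLURM_TMPDIR``. A minimal sketch of how each task could confirm where it is staging (the hostname printout is an illustrative assumption, not part of this commit):

.. code:: python

   import os
   import socket

   # Each srun task resolves $SLURM_TMPDIR on its own node, so every node
   # stages an independent local copy of the dataset for its GPUs.
   print(f"{socket.gethostname()}: staging into {os.environ['SLURM_TMPDIR']}")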
@@ -293,67 +293,64 @@ repository.
        main()


-**data.sh**
+**data.py**

-.. code:: bash
+.. code:: python

-   #!/bin/bash
-   set -o errexit
+   """Make sure the data is available"""
+   import os
+   import shutil
+   import sys
+   import time
+   from multiprocessing import Pool
+   from pathlib import Path

-   function ln_files {
-       # Clone the dataset structure of `src` to `dest` with symlinks, using
-       # `workers` worker processes (defaults to 4)
-       local src=$1
-       local dest=$2
-       local workers=${3:-4}
+   from torchvision.datasets import INaturalist

-       (cd "${src}" && find -L * -type f) | while read f
-       do
-           mkdir --parents "${dest}/$(dirname "$f")"
-           # echo the source first so it is matched to ln's '-T' argument
-           readlink --canonicalize "${src}/$f"
-           # echo the output last so ln understands it is the output file
-           echo "${dest}/$f"
-       done | xargs -n2 -P${workers} ln --symbolic --force -T
-   }

-   _SRC=$1
-   _WORKERS=$2
-   # Referencing $SLURM_TMPDIR here instead of in job.sh makes sure that the
-   # environment variable will only be resolved on the worker node (i.e. not
-   # referencing the $SLURM_TMPDIR of the master node)
-   _DEST=$SLURM_TMPDIR/data
+   def link_file(src: str, dest: str):
+       Path(dest).symlink_to(src)  # the symlink lives at `dest` and points to `src`

-   ln_files "${_SRC}" "${_DEST}" ${_WORKERS}

-   # Reorganise the files if needed
-   (
-       cd "${_DEST}"
-       # Torchvision expects these names
-       mv train.tar.gz 2021_train.tgz
-       mv val.tar.gz 2021_valid.tgz
-   )
+   def link_files(src: str, dest: str, workers=4):
+       src = Path(src)
+       dest = Path(dest)
+       os.makedirs(dest, exist_ok=True)
+       with Pool(processes=workers) as pool:
+           for path, dnames, fnames in os.walk(str(src)):
+               rel_path = Path(path).relative_to(src)
+               fnames = map(lambda _f: rel_path / _f, fnames)
+               dnames = map(lambda _d: rel_path / _d, dnames)
+               for d in dnames:
+                   os.makedirs(str(dest / d), exist_ok=True)
+               pool.starmap(
+                   link_file,
+                   [(src / _f, dest / _f) for _f in fnames]
+               )

-   # Extract and prepare the data
-   python3 data.py "${_DEST}"

+   if __name__ == "__main__":
+       src = Path(sys.argv[1])
+       workers = int(sys.argv[2])
+       # Referencing $SLURM_TMPDIR here instead of in job.sh makes sure that
+       # the environment variable will only be resolved on the worker node
+       # (i.e. not referencing the $SLURM_TMPDIR of the master node)
+       dest = Path(os.environ["SLURM_TMPDIR"]) / "dest"

-**data.py**
+       start_time = time.time()

-.. code:: python
+       link_files(src, dest, workers)

-   """Make sure the data is available"""
-   import sys
-   import time
+       # Torchvision expects these names
+       shutil.move(dest / "train.tar.gz", dest / "2021_train.tgz")
+       shutil.move(dest / "val.tar.gz", dest / "2021_valid.tgz")

-   from torchvision.datasets import INaturalist
+       INaturalist(root=dest, version="2021_train", download=True)
+       INaturalist(root=dest, version="2021_valid", download=True)

+       seconds_spent = time.time() - start_time

-   start_time = time.time()
-   INaturalist(root=sys.argv[1], version="2021_train", download=True)
-   INaturalist(root=sys.argv[1], version="2021_valid", download=True)
-   seconds_spent = time.time() - start_time
-   print(f"Prepared data in {seconds_spent/60:.2f}m")
+       print(f"Prepared data in {seconds_spent/60:.2f}m")


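For completeness, a hedged usage sketch of the training side (the ``$SLURM_TMPDIR/dest`` root matches the script above; opening the dataset with the default ``download=False`` is an assumption, not shown in this commit). Once data.py has linked and extracted the archives, the dataset opens straight from the node-local copy:

.. code:: python

   import os
   from pathlib import Path

   from torchvision.datasets import INaturalist

   # data.py already staged and extracted the archives under $SLURM_TMPDIR,
   # so no further download or network-filesystem access is needed here.
   root = Path(os.environ["SLURM_TMPDIR"]) / "dest"
   train_set = INaturalist(root=root, version="2021_train")
   print(f"{len(train_set)} training samples staged locally")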
 **Running this example**