-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun-dgx-h100.sh
More file actions
executable file
·65 lines (55 loc) · 1.8 KB
/
Copy pathrun-dgx-h100.sh
File metadata and controls
executable file
·65 lines (55 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/bin/bash
# Launch file to run gpt3-175b: starts one training process per GPU, pinned to
# a fixed set of CPUs and a NUMA memory node chosen from the process rank.
#
# Note: this file is configured to be launched via slurm tasks.
# Note: use your appropriate launcher environment var to determine the rank of
#       the GPU (replace SLURM_PROCID below if you are not launching via slurm).
#
# Required environment:
#   SLURM_PROCID - global rank of this task (non-negative integer).
# Output:
#   stdout and stderr of the training process are written to ./<rank>.out
# Exit status:
#   1 if the rank or the binding tables are invalid, otherwise the exit status
#   of the launched numactl command.
set -euo pipefail

# Fail fast when the rank is missing. Without this check an unset/empty rank
# evaluates to 0 in arithmetic context, so every such process would silently
# bind to GPU 0 and write its log to ".out".
: "${SLURM_PROCID:?SLURM_PROCID is not set; launch via slurm or set the rank variable for your launcher}"
RANK=${SLURM_PROCID}
if [[ ! ${RANK} =~ ^[0-9]+$ ]]; then
  printf 'error: rank must be a non-negative integer, got "%s"\n' "${RANK}" >&2
  exit 1
fi

readonly GPUS_PER_NODE=8

# xla by default allocates many threads, which can conflict with
# the threads used by the plugin client and makes backtraces
# harder to debug and interpret.
export PJRT_NPROC=16
export OMP_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export TF_NUM_INTEROP_THREADS=1
export TF_NUM_INTRAOP_THREADS=1

# prefix launch command with nsys_cmd to collect nsight traces
# nsys_cmd="nsys profile --trace=nvtx -o nsys_${RANK} --force-overwrite=true"

# limit the visibility of devices to only the one that the process will be
# running (1 process/device). "10#" forces base 10 so a zero-padded rank is
# not parsed as octal.
intra_node_rank=$(( 10#${RANK} % GPUS_PER_NODE ))
export CUDA_VISIBLE_DEVICES=${intra_node_rank}

# Per-local-rank binding tables, indexed by intra_node_rank:
#   cpus[i] - CPU id list handed to `numactl --physcpubind`
#   mems[i] - memory node handed to `numactl --membind`
# NOTE(review): each entry pairs two CPU id ranges; presumably these match the
# target node's core/NUMA topology - confirm before running on other hardware.
cpus=(
  0-13,112-125
  14-27,126-139
  28-41,140-153
  42-55,154-167
  56-69,168-181
  70-83,182-195
  84-97,196-209
  98-111,210-223
)
mems=(0 0 0 0 1 1 1 1)

# Both tables must have exactly one entry per local rank; otherwise the lookup
# below could yield an empty binding and numactl would be invoked incorrectly.
if (( ${#cpus[@]} != GPUS_PER_NODE || ${#mems[@]} != GPUS_PER_NODE )); then
  printf 'error: cpus (%d) and mems (%d) must each have %d entries\n' \
    "${#cpus[@]}" "${#mems[@]}" "${GPUS_PER_NODE}" >&2
  exit 1
fi

cpu_binding="${cpus[intra_node_rank]}"
mem_binding="${mems[intra_node_rank]}"

# Arguments passed verbatim to /opt/maxtext/run.py. Kept in an array so each
# word reaches the program exactly as written, in the original order.
train_args=(
  --fbmem 77
  --cpus 1
  --gpus 1
  --nodes 64
  --dp 1
  --fsdp 1
  --pp 8
  --tp 8
  --interleave 12
  --model-name gpt3-175b
  --num-layers 96
  --sequence-length=2048
  --batch-size 128
  --microbatch-size 4
  --remat minimal
  --attention=cudnn_flash_te
  --xla-rs-threshold=51200
  --use-nccl-comm-split
  --replicate-small-params
  --no-hoist-loop-convert
  --network ucx
  --schedule prefetch-wavefront
  --autoshard
  --num-steps 10
  --debug info
  --backend multimesh
  # NOTE(review): single-dash flag kept exactly as in the original; confirm
  # that run.py accepts "-logfile" (every other flag uses a double dash).
  -logfile jax_%.log
)

# run training; stdout and stderr both go to <rank>.out
numactl --physcpubind "${cpu_binding}" \
  --membind "${mem_binding}" \
  python /opt/maxtext/run.py "${train_args[@]}" \
  > "${RANK}.out" 2>&1