-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlaunch_training.sh
More file actions
42 lines (35 loc) · 1.78 KB
/
launch_training.sh
File metadata and controls
42 lines (35 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/usr/bin/env bash
# =============================================================================
# Launch script for single-node distributed training (written for 8xA100 40GB)
#
# Usage:  ./launch_training.sh [training.py args...]
#
# Environment:
#   NGPUS  number of worker processes / GPUs passed to torchrun (default: 8)
#
# Steps: install requirements.txt, download the pre-tokenized dataset via
# data_pipeline.py, then run training.py under torchrun. All command-line
# arguments given to this script are forwarded to training.py.
# =============================================================================
set -euo pipefail

NGPUS=${NGPUS:-8}
# Validate up front so a bad GPU count fails immediately instead of after the
# (slow) dependency install and dataset download.
if [[ ! "$NGPUS" =~ ^[1-9][0-9]*$ ]]; then
  printf 'error: NGPUS must be a positive integer, got %q\n' "$NGPUS" >&2
  exit 1
fi

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# Run from the script's own directory so the relative paths used below
# (requirements.txt, data_pipeline.py, training.py) resolve no matter where
# the script was invoked from.
cd "$SCRIPT_DIR"

echo "=============================================="
echo " GPT Training Launch Script"
echo "=============================================="
# -----------------------------------------------------------------------------
# Step 1: Install Python dependencies
# -----------------------------------------------------------------------------
echo ""
echo "[1/3] Installing Python dependencies..."
# Invoke pip through the same `python` that runs data_pipeline.py and is on
# PATH for the later steps. A bare `pip` may belong to a different interpreter,
# which would install the requirements into an environment nobody imports from.
python -m pip install -r requirements.txt --quiet
# -----------------------------------------------------------------------------
# Step 2: Fetch the pre-tokenized dataset from the HuggingFace Hub by running
# the project's data_pipeline.py with its "download" subcommand.
# -----------------------------------------------------------------------------
printf '\n%s\n%s\n' \
  "[2/3] Downloading pre-tokenized dataset from HuggingFace Hub..." \
  " (This will cache to ~/.cache/huggingface/datasets/)"
python data_pipeline.py download
# -----------------------------------------------------------------------------
# Step 3: Launch distributed training
# -----------------------------------------------------------------------------
echo ""
echo "[3/3] Launching training on $NGPUS GPUs..."
echo "=============================================="

# Runtime tuning knobs. Each is only a default: a value already present in the
# caller's environment wins, so e.g. `NCCL_DEBUG=INFO ./launch_training.sh`
# enables verbose NCCL logging without editing this file.
# PyTorch CUDA allocator config (optional but helps on long runs).
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
# Better error recovery for NCCL collectives.
export TORCH_NCCL_ASYNC_ERROR_HANDLING="${TORCH_NCCL_ASYNC_ERROR_HANDLING:-1}"
# Log warnings only by default (set NCCL_DEBUG=INFO for debugging).
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"

# --standalone: single-node rendezvous; one worker process per GPU.
# `exec` replaces this shell with torchrun, so signals sent to the script
# (Ctrl-C, a scheduler's SIGTERM) reach torchrun directly instead of killing
# only the wrapper and orphaning the workers. It must remain the last command.
# All arguments given to this script are passed through to training.py.
exec torchrun --standalone --nproc_per_node="$NGPUS" training.py "$@"