Skip to content

Commit 5020a8c

Browse files
committed
[infra] Support CUDA-aware torch installation
This commit supports CUDA-aware torch installation. TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 962669a commit 5020a8c

9 files changed

Lines changed: 132 additions & 16 deletions

infra/dependency/torch_dev.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
--extra-index-url "https://download.pytorch.org/whl/nightly/cpu"
21
torch==2.8.0.dev20250501+cpu
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
--extra-index-url "https://download.pytorch.org/whl/nightly/cpu"
21
torchvision==0.22.0.dev20250501+cpu

infra/scripts/install.sh

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
# [EXPORTED VARIABLES]
2020
# - CCEX_PROJECT_PATH
2121

22+
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
23+
2224
show_help() {
2325
# `cat << EOF` This means that cat should stop reading when EOF is detected
2426
cat << EOF
@@ -28,11 +30,16 @@ Usage: ./ccex install [--dist|--torch_ver|--help|-h]
2830
--torch_ver [2.5|2.6|nightly]
2931
Specify torch version to be installed.
3032
(default: 2.6)
33+
--cuda_ver [11.8|12.1]
34+
Specify the target CUDA version. This overrides automatic
35+
detection.
36+
--cpu_only Forces installation of the CPU-only version of PyTorch.
37+
Disables both CUDA version detection and use of --cuda-version.
3138
-h | --help Show help message and exit
3239
EOF
3340
}
3441

35-
options=$(getopt -o h --long dist,torch_ver:,help -- "$@")
42+
options=$(getopt -o h --long dist,torch_ver:,cuda_ver:,cpu_only,help -- "$@")
3643

3744
[ $? -eq 0 ] || {
3845
echo "Incorrect options provided"
@@ -41,6 +48,9 @@ options=$(getopt -o h --long dist,torch_ver:,help -- "$@")
4148

4249
_DIST=0
4350
_TORCH_VER="2.6"
51+
_USER_SPECIFIED_CUDA=""
52+
_CPU_ONLY=false
53+
_NIGHTLY=""
4454

4555
eval set -- "$options"
4656

@@ -53,6 +63,14 @@ while true; do
5363
_TORCH_VER="$2"
5464
shift
5565
;;
66+
--cuda_ver)
67+
_USER_SPECIFIED_CUDA="$2"
68+
shift
69+
;;
70+
--cpu_only)
71+
_CPU_ONLY=true
72+
shift
73+
;;
5674
-h|--help)
5775
show_help
5876
exit 0;
@@ -65,22 +83,69 @@ while true; do
6583
shift
6684
done
6785

86+
# Check torch version
6887
if [ "$_TORCH_VER" != "2.5" ] && [ "$_TORCH_VER" != "2.6" ] && [ "$_TORCH_VER" != "nightly" ]; then
6988
echo "Invalid torch version '$_TORCH_VER'"
7089
echo "(Use --help to see available torch versions)"
7190
exit 1
7291
fi
92+
# Check if torch version is nightly
93+
if [[ "$_TORCH_VER" == "nightly" ]]; then
94+
_NIGHTLY=true
95+
fi
96+
# Check for conflicting options
97+
if [[ -n "$_USER_SPECIFIED_CUDA" && "$_CPU_ONLY" = true ]]; then
98+
echo "[ERROR] Cannot use --cpu_only and --cuda_ver together"
99+
exit 1
100+
fi
101+
102+
# set index-url
103+
get_index_url_for_cuda_version() {
104+
local version="$1"
105+
local is_nightly="$2"
106+
107+
case "$version" in
108+
11.8) echo "https://download.pytorch.org/whl${is_nightly:+/nightly}/cu118" ;;
109+
12.1) echo "https://download.pytorch.org/whl${is_nightly:+/nightly}/cu121" ;;
110+
*) echo "" ;;
111+
esac
112+
}
113+
INDEX_URL="https://download.pytorch.org/whl${_NIGHTLY:+/nightly}/cpu"
114+
if [[ "$_CPU_ONLY" = true ]]; then
115+
echo "[INFO] Installing CPU-only version of PyTorch"
116+
else
117+
if [[ -n "$_USER_SPECIFIED_CUDA" ]]; then
118+
echo "[INFO] Using user-specified CUDA version: $_USER_SPECIFIED_CUDA"
119+
INDEX_URL=$(get_index_url_for_cuda_version "$_USER_SPECIFIED_CUDA" "$_NIGHTLY")
120+
else
121+
if command -v nvcc &> /dev/null; then
122+
DETECTED_CUDA=$(nvcc --version | grep -oP "release \K[0-9]+\.[0-9]+")
123+
elif command -v nvidia-smi &> /dev/null; then
124+
DETECTED_CUDA=$(nvcc --version | grep -oP "CUDA Version: \K[0-9]+\.[0-9]+")
125+
else
126+
DETECTED_CUDA=""
127+
fi
128+
fi
129+
if [[ -n "$DETECTED_CUDA" ]]; then
130+
echo "[INFO] Detected CUDA version: $DETECTED_CUDA"
131+
INDEX_URL=$(get_index_url_for_cuda_version "$DETECTED_CUDA" "$_NIGHTLY")
132+
fi
133+
fi
73134

74135
SCRIPTS_DIR="${CCEX_PROJECT_PATH}/infra/scripts"
75136

76137
if [ "$_TORCH_VER" == "nightly" ]; then
77138
echo "Install package dependencies from torch nightly version"
139+
REQ_FILE="${SCRIPTS_DIR}/../dependency/torch_dev.txt"
140+
python3 -m pip install -r ${REQ_FILE} --index-url ${INDEX_URL}
78141
python3 -m pip install -r "${SCRIPTS_DIR}/install_requirements_dev.txt"
79142
elif [ "$_TORCH_VER" == "2.6" ]; then
80143
echo "Install package dependencies from torch stable version"
144+
python3 -m pip install torch==2.6.0 --index-url ${INDEX_URL}
81145
python3 -m pip install -r "${SCRIPTS_DIR}/install_requirements_2_6.txt"
82146
elif [ "$_TORCH_VER" == "2.5" ]; then
83147
echo "Install package dependencies from torch stable version"
148+
python3 -m pip install torch==2.5.0 --index-url ${INDEX_URL}
84149
python3 -m pip install -r "${SCRIPTS_DIR}/install_requirements_2_5.txt"
85150
else
86151
echo "Assertion: Cannot reach here"

infra/scripts/install_requirements_2_5.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ wheel==0.43.0
33

44
circle_schema
55

6-
--extra-index-url https://download.pytorch.org/whl/cpu
7-
torch==2.5.0
86
cffi==1.17.1
97
packaging==25.0
108
pyyaml==6.0.2

infra/scripts/install_requirements_2_6.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ wheel==0.43.0
33

44
circle_schema
55

6-
--extra-index-url https://download.pytorch.org/whl/cpu
7-
torch==2.6.0
86
cffi==1.17.1
97
packaging==25.0
108
pyyaml==6.0.2

infra/scripts/install_requirements_dev.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ wheel==0.43.0
33

44
circle_schema
55

6-
-r ../../infra/dependency/torch_dev.txt
7-
86
cffi==1.17.1
97
packaging==25.0
108
pyyaml==6.0.2

infra/scripts/test_configure.sh

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ Usage: ./ccex configure test [--torch_ver|--help|-h]
2626
--torch_ver [2.5|2.6|nightly]
2727
Specify torch version to install family test packages.
2828
(default: 2.6)
29+
--cuda_ver [11.8|12.1]
30+
Specify the target CUDA version. This overrides automatic
31+
detection.
32+
--cpu_only Forces installation of the CPU-only version of PyTorch.
33+
Disables both CUDA version detection and use of --cuda-version.
2934
-h | --help Show help message and exit
3035
EOF
3136
}
@@ -35,14 +40,17 @@ TEST_DIR="${CCEX_PROJECT_PATH}/test"
3540

3641
pushd ${CCEX_PROJECT_PATH} > /dev/null
3742

38-
options=$(getopt -o h --long torch_ver:,help -- "$@")
43+
options=$(getopt -o h --long torch_ver:,cuda_ver:,cpu_only,help -- "$@")
3944

4045
[ $? -eq 0 ] || {
4146
echo "Incorrect options provided"
4247
exit 1
4348
}
4449

4550
_TORCH_VER="2.6"
51+
_USER_SPECIFIED_CUDA=""
52+
_CPU_ONLY=false
53+
_NIGHTLY=""
4654

4755
eval set -- "$options"
4856

@@ -52,6 +60,14 @@ while true; do
5260
_TORCH_VER="$2"
5361
shift
5462
;;
63+
--cuda_ver)
64+
_USER_SPECIFIED_CUDA="$2"
65+
shift
66+
;;
67+
--cpu_only)
68+
_CPU_ONLY=true
69+
shift
70+
;;
5571
-h|--help)
5672
show_help
5773
exit 0;
@@ -64,21 +80,68 @@ while true; do
6480
shift
6581
done
6682

83+
# Check torch version
6784
if [ "$_TORCH_VER" != "2.5" -a "$_TORCH_VER" != "2.6" -a "$_TORCH_VER" != "nightly" ]; then
6885
echo "Invalid '$_TORCH_VER'"
6986
echo "(Use --help to see available options)"
7087
exit 1
7188
fi
89+
# Check if torch version is nightly
90+
if [[ "$_TORCH_VER" == "nightly" ]]; then
91+
_NIGHTLY=true
92+
fi
93+
# Check for conflicting options
94+
if [[ -n "$_USER_SPECIFIED_CUDA" && "$_CPU_ONLY" = true ]]; then
95+
echo "[ERROR] Cannot use --cpu_only and --cuda_ver together"
96+
exit 1
97+
fi
98+
99+
# set index-url
100+
get_index_url_for_cuda_version() {
101+
local version="$1"
102+
local is_nightly="$2"
103+
104+
case "$version" in
105+
11.8) echo "https://download.pytorch.org/whl${is_nightly:+/nightly}/cu118" ;;
106+
12.1) echo "https://download.pytorch.org/whl${is_nightly:+/nightly}/cu121" ;;
107+
*) echo "" ;;
108+
esac
109+
}
110+
INDEX_URL="https://download.pytorch.org/whl${_NIGHTLY:+/nightly}/cpu"
111+
if [[ "$_CPU_ONLY" = true ]]; then
112+
echo "[INFO] Installing CPU-only version of PyTorch"
113+
else
114+
if [[ -n "$_USER_SPECIFIED_CUDA" ]]; then
115+
echo "[INFO] Using user-specified CUDA version: $_USER_SPECIFIED_CUDA"
116+
INDEX_URL=$(get_index_url_for_cuda_version "$_USER_SPECIFIED_CUDA" "$_NIGHTLY")
117+
else
118+
if command -v nvcc &> /dev/null; then
119+
DETECTED_CUDA=$(nvcc --version | grep -oP "release \K[0-9]+\.[0-9]+")
120+
elif command -v nvidia-smi &> /dev/null; then
121+
DETECTED_CUDA=$(nvcc --version | grep -oP "CUDA Version: \K[0-9]+\.[0-9]+")
122+
else
123+
DETECTED_CUDA=""
124+
fi
125+
fi
126+
if [[ -n "$DETECTED_CUDA" ]]; then
127+
echo "[INFO] Detected CUDA version: $DETECTED_CUDA"
128+
INDEX_URL=$(get_index_url_for_cuda_version "$DETECTED_CUDA" "$_NIGHTLY")
129+
fi
130+
fi
72131

73132
if [ "$_TORCH_VER" == "nightly" ]; then
74133
echo "Install test package dependencies from nightly version"
134+
REQ_FILE="${SCRIPTS_DIR}/../dependency/torchvision_dev.txt"
135+
python3 -m pip install -r ${REQ_FILE} --index-url ${INDEX_URL}
75136
python3 -m pip install -r "${TEST_DIR}/requirements_dev.txt"
76137
elif [ "$_TORCH_VER" == "2.6" ]; then
77138
echo "Install test package dependencies from stable version"
78-
python3 -m pip install -r "${TEST_DIR}/requirements_2_6.txt"
139+
python3 -m pip install torchvision==0.21.0 --index-url ${INDEX_URL}
140+
# python3 -m pip install -r "${TEST_DIR}/requirements_2_6.txt"
79141
elif [ "$_TORCH_VER" == "2.5" ]; then
80142
echo "Install test package dependencies from stable version"
81-
python3 -m pip install -r "${TEST_DIR}/requirements_2_5.txt"
143+
python3 -m pip install torchvision==0.20.0 --index-url ${INDEX_URL}
144+
# python3 -m pip install -r "${TEST_DIR}/requirements_2_5.txt"
82145
fi
83146

84147
popd > /dev/null

test/requirements_2_5.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
--extra-index-url https://download.pytorch.org/whl/cpu
2-
torchvision==0.20.0

test/requirements_2_6.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
--extra-index-url https://download.pytorch.org/whl/cpu
2-
torchvision==0.21.0

0 commit comments

Comments
 (0)