1919# [EXPORTED VARIABLES]
2020# - CCEX_PROJECT_PATH
2121
22+ SCRIPTS_DIR=" $( cd " $( dirname " ${BASH_SOURCE[0]} " ) " && pwd) "
23+
2224show_help () {
2325# `cat << EOF` This means that cat should stop reading when EOF is detected
2426cat << EOF
@@ -28,11 +30,16 @@ Usage: ./ccex install [--dist|--torch_ver|--help|-h]
2830--torch_ver [2.5|2.6|nightly]
2931 Specify torch version to be installed.
3032 (default: 2.6)
33+ --cuda_ver [11.8|12.1]
34+ Specify the target CUDA version. This overrides automatic
35+ detection.
36+ --cpu_only Forces installation of the CPU-only version of PyTorch.
37+ Disables both CUDA version detection and use of --cuda-version.
3138-h | --help Show help message and exit
3239EOF
3340}
3441
35- options=$( getopt -o h --long dist,torch_ver:,help -- " $@ " )
42+ options=$( getopt -o h --long dist,torch_ver:,cuda_ver:,cpu_only, help -- " $@ " )
3643
3744[ $? -eq 0 ] || {
3845 echo " Incorrect options provided"
@@ -41,6 +48,9 @@ options=$(getopt -o h --long dist,torch_ver:,help -- "$@")
4148
4249_DIST=0
4350_TORCH_VER=" 2.6"
51+ _USER_SPECIFIED_CUDA=" "
52+ _CPU_ONLY=false
53+ _NIGHTLY=" "
4454
4555eval set -- " $options "
4656
@@ -53,6 +63,14 @@ while true; do
5363 _TORCH_VER=" $2 "
5464 shift
5565 ;;
66+ --cuda_ver)
67+ _USER_SPECIFIED_CUDA=" $2 "
68+ shift
69+ ;;
70+ --cpu_only)
71+ _CPU_ONLY=true
72+ shift
73+ ;;
5674 -h|--help)
5775 show_help
5876 exit 0;
@@ -65,22 +83,69 @@ while true; do
6583 shift
6684done
6785
86+ # Check torch version
6887if [ " $_TORCH_VER " != " 2.5" ] && [ " $_TORCH_VER " != " 2.6" ] && [ " $_TORCH_VER " != " nightly" ]; then
6988 echo " Invalid torch version '$_TORCH_VER '"
7089 echo " (Use --help to see available torch versions)"
7190 exit 1
7291fi
92+ # Check if torch version is nightly
93+ if [[ " $_TORCH_VER " == " nightly" ]]; then
94+ _NIGHTLY=true
95+ fi
96+ # Check for conflicting options
97+ if [[ -n " $_USER_SPECIFIED_CUDA " && " $_CPU_ONLY " = true ]]; then
98+ echo " [ERROR] Cannot use --cpu_only and --cuda_ver together"
99+ exit 1
100+ fi
101+
102+ # set index-url
103+ get_index_url_for_cuda_version () {
104+ local version=" $1 "
105+ local is_nightly=" $2 "
106+
107+ case " $version " in
108+ 11.8) echo " https://download.pytorch.org/whl${is_nightly: +/ nightly} /cu118" ;;
109+ 12.1) echo " https://download.pytorch.org/whl${is_nightly: +/ nightly} /cu121" ;;
110+ * ) echo " " ;;
111+ esac
112+ }
113+ INDEX_URL=" https://download.pytorch.org/whl${_NIGHTLY: +/ nightly} /cpu"
114+ if [[ " $_CPU_ONLY " = true ]]; then
115+ echo " [INFO] Installing CPU-only version of PyTorch"
116+ else
117+ if [[ -n " $_USER_SPECIFIED_CUDA " ]]; then
118+ echo " [INFO] Using user-specified CUDA version: $_USER_SPECIFIED_CUDA "
119+ INDEX_URL=$( get_index_url_for_cuda_version " $_USER_SPECIFIED_CUDA " " $_NIGHTLY " )
120+ else
121+ if command -v nvcc & > /dev/null; then
122+ DETECTED_CUDA=$( nvcc --version | grep -oP " release \K[0-9]+\.[0-9]+" )
123+ elif command -v nvidia-smi & > /dev/null; then
124+ DETECTED_CUDA=$( nvcc --version | grep -oP " CUDA Version: \K[0-9]+\.[0-9]+" )
125+ else
126+ DETECTED_CUDA=" "
127+ fi
128+ fi
129+ if [[ -n " $DETECTED_CUDA " ]]; then
130+ echo " [INFO] Detected CUDA version: $DETECTED_CUDA "
131+ INDEX_URL=$( get_index_url_for_cuda_version " $DETECTED_CUDA " " $_NIGHTLY " )
132+ fi
133+ fi
73134
74135SCRIPTS_DIR=" ${CCEX_PROJECT_PATH} /infra/scripts"
75136
76137if [ " $_TORCH_VER " == " nightly" ]; then
77138 echo " Install package dependencies from torch nightly version"
139+ REQ_FILE=" ${SCRIPTS_DIR} /../dependency/torch_dev.txt"
140+ python3 -m pip install -r ${REQ_FILE} --index-url ${INDEX_URL}
78141 python3 -m pip install -r " ${SCRIPTS_DIR} /install_requirements_dev.txt"
79142elif [ " $_TORCH_VER " == " 2.6" ]; then
80143 echo " Install package dependencies from torch stable version"
144+ python3 -m pip install torch==2.6.0 --index-url ${INDEX_URL}
81145 python3 -m pip install -r " ${SCRIPTS_DIR} /install_requirements_2_6.txt"
82146elif [ " $_TORCH_VER " == " 2.5" ]; then
83147 echo " Install package dependencies from torch stable version"
148+ python3 -m pip install torch==2.5.0 --index-url ${INDEX_URL}
84149 python3 -m pip install -r " ${SCRIPTS_DIR} /install_requirements_2_5.txt"
85150else
86151 echo " Assertion: Cannot reach here"
0 commit comments