@@ -41,7 +41,7 @@ python_version=3.11
4141quick_test=false
4242while getopts " :c:f:i:p:qv:" opt; do
4343 case $opt in
44- c) cuda_version =" $OPTARG "
44+ c) cuda_version_conda =" $OPTARG "
4545 ;;
4646 f) root_folder=" $OPTARG "
4747 ;;
@@ -67,6 +67,12 @@ if [ ! -d "$root_folder" ] || [ ! -f "$readme_file" ] ; then
6767 (return 0 2> /dev/null) && return 100 || exit 100
6868fi
6969
70+ # Check that the `cuda_version_conda` is a full version string like "12.8.0"
71+ if ! [[ $cuda_version_conda =~ ^[0-9]+\. [0-9]+\. [0-9]+$ ]]; then
72+ echo -e " \e[01;31mThe cuda_version_conda (-c) must be a full version string like '12.8.0'. Provided: '${cuda_version_conda} '.\e[0m" >&2
73+ (return 0 2> /dev/null) && return 100 || exit 100
74+ fi
75+
7076# Install Miniconda
7177if [ ! -x " $( command -v conda) " ]; then
7278 mkdir -p ~ /.miniconda3
@@ -82,7 +88,7 @@ if [ -n "${extra_packages}" ]; then
8288 pip_extra_url=" --extra-index-url http://localhost:8080"
8389fi
8490while IFS= read -r line; do
85- line=$( echo $line | sed -E " s/cuda_version=(.\{\{)?\s?\S+\s?(\}\})?/cuda_version=${cuda_version} .0 /g" )
91+ line=$( echo $line | sed -E " s/cuda_version=(.\{\{)?\s?\S+\s?(\}\})?/cuda_version=${cuda_version_conda} /g" )
8692 line=$( echo $line | sed -E " s/python(=)?3.[0-9]{1,}/python\1${python_version} /g" )
8793 line=$( echo $line | sed -E " s/pip install (.\{\{)?\s?\S+\s?(\}\})?/pip install cudaq==${cudaq_version} -v ${pip_extra_url// \/ / \\ / } /g" )
8894 if [ -n " $( echo $line | grep " conda activate" ) " ]; then
@@ -162,9 +168,21 @@ done
162168
163169# Run torch integrator tests.
164170# This is an optional integrator, which requires torch and torchdiffeq.
165- # Install torch separately to match the cuda version.
171+ # Install torch separately to match the cuda version installed as CUDA-Q dependency .
166172# Torch if installed as part of torchdiffeq's dependencies, may default to the latest cuda version.
167- python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu$( echo $cuda_version | cut -d ' .' -f-2 | tr -d .)
173+ ctk_version=" $( python3 -m pip list | grep nvidia-cuda-runtime | tr -s " " | cut -d " " -f 2) "
174+ torch_channel=" "
175+ # For CUDA version 13, we need to use the nightly build.
176+ # This is required because the stable build (as of v2.9.0) depends on cuBlas 13.0,
177+ # while cuquantum-cu13 depends on cuBlas 13.1.
178+ # Ref: torch changes cublas to 13.1 in this commit.
179+ # https://github.com/pytorch/pytorch/commit/544b443ea1d1a9b19e65f981168a01cb87a2d333
180+ # TODO: Update this script when stable torch builds with this fix are available.
181+ if [ " $( cut -d ' .' -f 1 <<< " $ctk_version" ) " -eq 13 ]; then
182+ torch_channel=" nightly/"
183+ fi
184+
185+ python3 -m pip install torch --index-url https://download.pytorch.org/whl/${torch_channel} cu$( cut -d ' .' -f 1 <<< " $ctk_version" ) " " $( cut -d ' .' -f 2 <<< " $ctk_version" )
168186python3 -m pip install torchdiffeq
169187python3 -m pytest -v " $root_folder /tests/dynamics/integrators"
170188if [ ! $? -eq 0 ]; then
0 commit comments