diff --git a/docker/torch_install_helper.py b/docker/torch_install_helper.py index f1255f77..b4ce2aa9 100755 --- a/docker/torch_install_helper.py +++ b/docker/torch_install_helper.py @@ -102,10 +102,10 @@ def install_pytorch_jetson() -> None: def install_nvblox_torch() -> None: cuda_version = get_cuda_version() - if cuda_version not in JETSON_CUDA_VERSION_TO_TORCH: - print( - 'warning: Unsupported CUDA version: {cuda_version}. Skipping nvblox torch installation.' - ) + supported_cuda_versions = list(JETSON_CUDA_VERSION_TO_TORCH.keys() + | X86_CUDA_VERSION_TO_PYTORCH_PIP_URL.keys()) + if cuda_version not in supported_cuda_versions: + print(f'warning: Unsupported CUDA version: {cuda_version}. Skipping nvblox torch install.') return script = """ @@ -113,8 +113,12 @@ def install_nvblox_torch() -> None: umask 000 . /opt/venv/bin/activate python3 -m pip install --ignore-installed --upgrade pip --no-cache-dir - pip install /nvblox/nvblox_torch/ """ + # Need to force the torch version for cuda 11 to prevent upgrade. + if cuda_version == '11': + script += 'pip install /nvblox/nvblox_torch/ "torch==2.7.1"' + else: + script += 'pip install /nvblox/nvblox_torch/' subprocess.run(script, shell=True, check=True)