@@ -34,15 +34,24 @@ def index_args(torch_version: Version, cuda_version: str) -> list[str]:
3434
3535
3636def pip_packages (
37- torch_version : Version , cuda_version : str , pip_pkgs : list [str ]
37+ torch_version : Version , python_version : str , cuda_version : str , pip_pkgs : list [str ]
3838) -> list [str ]:
3939 deps = torch_deps [torch_version ]
4040 cu = cuda_suffix (cuda_version )
41- pkgs = [
42- f'torch=={ torch_version } +{ cu } ' ,
43- f'torchaudio=={ deps .torchaudio } +{ cu } ' ,
44- f'torchvision=={ deps .torchvision } +{ cu } ' ,
45- ]
41+ if torch_version .extra :
42+ prefix = torch_index_url (torch_version , cuda_version )
43+ py = f'cp{ python_version .replace ('.' , '' )} '
44+ pkgs = [
45+ f'torch @ { prefix } /torch-{ torch_version } %2B{ cu } -{ py } -{ py } -linux_x86_64.whl' ,
46+ f'torchaudio @ { prefix } /torchaudio-{ deps .torchaudio } %2B{ cu } -{ py } -{ py } -linux_x86_64.whl' ,
47+ f'torchvision @ { prefix } /torchvision-{ deps .torchvision } %2B{ cu } -{ py } -{ py } -linux_x86_64.whl' ,
48+ ]
49+ else :
50+ pkgs = [
51+ f'torch=={ torch_version } +{ cu } ' ,
52+ f'torchaudio=={ deps .torchaudio } +{ cu } ' ,
53+ f'torchvision=={ deps .torchvision } +{ cu } ' ,
54+ ]
4655 # Older Torch versions do not bundle CUDA or CuDNN
4756 nvidia_pkgs = []
4857 if torch_version < Version .parse ('2.2.0' ):
@@ -101,7 +110,7 @@ def update_venv(
101110 ]
102111 cmd = ['uv' , 'pip' , 'compile' , '--python-platform' , 'x86_64-unknown-linux-gnu' ]
103112 cmd = cmd + emit_args + index_args (t , cuda_version ) + ['-' ]
104- pkgs = pip_packages (t , cuda_version , pip_pkgs )
113+ pkgs = pip_packages (t , python_version , cuda_version , pip_pkgs )
105114 env = os .environ .copy ()
106115 env ['VIRTUAL_ENV' ] = vdir
107116 try :
0 commit comments