Skip to content

Commit bbeed4c

Browse files
committed
Install torch nightly from wheel URLs
1 parent c47cdae commit bbeed4c

File tree

3 files changed

+20
-17
lines changed

3 files changed

+20
-17
lines changed

src/monobase/monogen.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ class MonoGen:
2828
python={'3.12': '3.12.7'},
2929
torch=[
3030
'2.4.1',
31-
# NOTE(meatballhat): This is turned off until we can figure out how to handle
32-
# nightlies better since the torch package index only retains ~2 months of
33-
# versions:
34-
## Nightly
35-
#'2.6.0.dev20240918'
31+
# Nightly
32+
'2.6.0.dev20240918',
3633
],
3734
pip_pkgs=SEED_PKGS,
3835
),

src/monobase/torch.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,8 @@ class TorchDeps:
3030
}
3131

3232
torch_deps_dict = {
33-
# NOTE(meatballhat): This is turned off until we can figure out how to handle
34-
# nightlies better since the torch package index only retains ~2 months of
35-
# versions:
36-
## Nightly
37-
#'2.6.0.dev20240918': TorchDeps('2.5.0.dev20240918', '0.20.0.dev20240918'),
33+
# Nightly
34+
'2.6.0.dev20240918': TorchDeps('2.5.0.dev20240918', '0.20.0.dev20240918'),
3835
# Releases
3936
'2.4.1': TorchDeps('2.4.1', '0.19.1'),
4037
'2.4.0': TorchDeps('2.4.0', '0.19.0'),

src/monobase/uv.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,24 @@ def index_args(torch_version: Version, cuda_version: str) -> list[str]:
3434

3535

3636
def pip_packages(
37-
torch_version: Version, cuda_version: str, pip_pkgs: list[str]
37+
torch_version: Version, python_version: str, cuda_version: str, pip_pkgs: list[str]
3838
) -> list[str]:
3939
deps = torch_deps[torch_version]
4040
cu = cuda_suffix(cuda_version)
41-
pkgs = [
42-
f'torch=={torch_version}+{cu}',
43-
f'torchaudio=={deps.torchaudio}+{cu}',
44-
f'torchvision=={deps.torchvision}+{cu}',
45-
]
41+
if torch_version.extra:
42+
prefix = torch_index_url(torch_version, cuda_version)
43+
py = f'cp{python_version.replace('.', '')}'
44+
pkgs = [
45+
f'torch @ {prefix}/torch-{torch_version}%2B{cu}-{py}-{py}-linux_x86_64.whl',
46+
f'torchaudio @ {prefix}/torchaudio-{deps.torchaudio}%2B{cu}-{py}-{py}-linux_x86_64.whl',
47+
f'torchvision @ {prefix}/torchvision-{deps.torchvision}%2B{cu}-{py}-{py}-linux_x86_64.whl',
48+
]
49+
else:
50+
pkgs = [
51+
f'torch=={torch_version}+{cu}',
52+
f'torchaudio=={deps.torchaudio}+{cu}',
53+
f'torchvision=={deps.torchvision}+{cu}',
54+
]
4655
# Older Torch versions do not bundle CUDA or CuDNN
4756
nvidia_pkgs = []
4857
if torch_version < Version.parse('2.2.0'):
@@ -101,7 +110,7 @@ def update_venv(
101110
]
102111
cmd = ['uv', 'pip', 'compile', '--python-platform', 'x86_64-unknown-linux-gnu']
103112
cmd = cmd + emit_args + index_args(t, cuda_version) + ['-']
104-
pkgs = pip_packages(t, cuda_version, pip_pkgs)
113+
pkgs = pip_packages(t, python_version, cuda_version, pip_pkgs)
105114
env = os.environ.copy()
106115
env['VIRTUAL_ENV'] = vdir
107116
try:

0 commit comments

Comments
 (0)