Skip to content

Commit dfbd56e

Browse files
authored
Adding support for Torch LTS (#39)
1 parent 030771d commit dfbd56e

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

light_the_torch/_pip/find.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def find_links(
6161
``"cpu"`` or ``"cu102"``. Defaults to the available hardware of the running
6262
system.
6363
channel: Channel of the PyTorch wheels. Can be one of ``"stable"`` (default),
64-
``"test"``, and ``"nightly"``.
64+
``"lts"``, ``"test"``, and ``"nightly"``.
6565
platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
6666
to the platform of the running system.
6767
python_version: Python version, for example ``"3"`` or ``"3.7"``. Defaults to
@@ -78,9 +78,9 @@ def find_links(
7878
else:
7979
computation_backends = set(computation_backends)
8080

81-
if channel not in ("stable", "test", "nightly"):
81+
if channel not in ("stable", "lts", "test", "nightly"):
8282
raise ValueError(
83-
f"channel can be one of 'stable', 'test', or 'nightly', "
83+
f"channel can be one of 'stable', 'lts', 'test', or 'nightly', "
8484
f"but got {channel} instead."
8585
)
8686

@@ -276,6 +276,8 @@ def __init__(
276276
super().__init__(*args, **kwargs)
277277
if channel == "stable":
278278
urls = ["https://download.pytorch.org/whl/torch_stable.html"]
279+
elif channel == "lts":
280+
urls = ["https://download.pytorch.org/whl/lts/1.8/torch_lts.html"]
279281
else:
280282
urls = [
281283
f"https://download.pytorch.org/whl/"

light_the_torch/cli/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def add_ltt_install_parser(subparsers: SubParsers) -> None:
117117
default="stable",
118118
help=(
119119
"Channel of the PyTorch wheels. "
120-
"Can be one of 'stable' (default), 'test', or 'nightly'"
120+
"Can be one of 'stable' (default), 'lts', 'test', or 'nightly'"
121121
),
122122
)
123123
parser.add_argument(

tests/unit/_pip/test_find.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def patch_run_():
3030
return patch_run_
3131

3232

33-
CHANNELS = ("stable", "test", "nightly")
33+
CHANNELS = ("stable", "lts", "test", "nightly")
3434
PLATFORMS = ("linux_x86_64", "macosx_10_9_x86_64", "win_amd64")
3535
PLATFORM_MAP = dict(zip(PLATFORMS, ("Linux", "Darwin", "Windows")))
3636

0 commit comments

Comments
 (0)