Skip to content

Commit 8a04524

Browse files
authored
add the ability to select PyTorch wheel channel (#27)
1 parent a17068d commit 8a04524

File tree

5 files changed

+114
-9
lines changed

5 files changed

+114
-9
lines changed

light_the_torch/_pip/find.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
def find_links(
3131
pip_install_args: List[str],
3232
computation_backend: Optional[Union[str, ComputationBackend]] = None,
33+
channel: str = "stable",
3334
platform: Optional[str] = None,
3435
python_version: Optional[str] = None,
3536
verbose: bool = False,
@@ -43,6 +44,8 @@ def find_links(
4344
computation_backend: Computation backend, for example ``"cpu"`` or ``"cu102"``.
4445
Defaults to the available hardware of the running system preferring CUDA
4546
over CPU.
47+
channel: Channel of the PyTorch wheels. Can be one of ``"stable"`` (default),
48+
``"test"``, and ``"nightly"``.
4649
platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
4750
to the platform of the running system.
4851
python_version: Python version, for example ``"3"`` or ``"3.7"``. Defaults to
@@ -57,9 +60,17 @@ def find_links(
5760
elif isinstance(computation_backend, str):
5861
computation_backend = ComputationBackend.from_str(computation_backend)
5962

63+
if channel not in ("stable", "test", "nightly"):
64+
raise ValueError(
65+
f"channel can be one of 'stable', 'test', or 'nightly', "
66+
f"but got {channel} instead."
67+
)
68+
6069
dists = extract_dists(pip_install_args)
6170

62-
cmd = StopAfterPytorchLinksFoundCommand(computation_backend=computation_backend)
71+
cmd = StopAfterPytorchLinksFoundCommand(
72+
computation_backend=computation_backend, channel=channel
73+
)
6374
pip_install_args = adjust_pip_install_args(dists, platform, python_version)
6475
options, args = cmd.parser.parse_args(pip_install_args)
6576
try:
@@ -222,32 +233,55 @@ class PytorchLinkCollector(LinkCollector):
222233
def __init__(
223234
self,
224235
*args: Any,
225-
url: str = "https://download.pytorch.org/whl/torch_stable.html",
236+
computation_backend: ComputationBackend,
237+
channel: str = "stable",
226238
**kwargs: Any,
227239
) -> None:
228240
super().__init__(*args, **kwargs)
241+
if channel == "stable":
242+
url = "https://download.pytorch.org/whl/torch_stable.html"
243+
else:
244+
url = (
245+
f"https://download.pytorch.org/whl/"
246+
f"{channel}/{computation_backend}/torch_{channel}.html"
247+
)
229248
self.search_scope = SearchScope.create(find_links=[url], index_urls=[])
230249

231250
@classmethod
232251
def from_link_collector(
233-
cls, link_collector: LinkCollector
252+
cls,
253+
link_collector: LinkCollector,
254+
computation_backend: ComputationBackend,
255+
channel: str = "stable",
234256
) -> "PytorchLinkCollector":
235-
return new_from_similar(cls, link_collector, ("session", "search_scope",),)
257+
return new_from_similar(
258+
cls,
259+
link_collector,
260+
("session", "search_scope",),
261+
channel=channel,
262+
computation_backend=computation_backend,
263+
)
236264

237265

238266
class PytorchPackageFinder(PackageFinder):
239267
_candidate_prefs: PytorchCandidatePreferences
240268
_link_collector: PytorchLinkCollector
241269

242270
def __init__(
243-
self, *args: Any, computation_backend: ComputationBackend, **kwargs: Any,
271+
self,
272+
*args: Any,
273+
computation_backend: ComputationBackend,
274+
channel: str = "stable",
275+
**kwargs: Any,
244276
) -> None:
245277
super().__init__(*args, **kwargs)
246278
self._candidate_prefs = PytorchCandidatePreferences.from_candidate_preferences(
247279
self._candidate_prefs, computation_backend=computation_backend
248280
)
249281
self._link_collector = PytorchLinkCollector.from_link_collector(
250-
self._link_collector
282+
self._link_collector,
283+
channel=channel,
284+
computation_backend=computation_backend,
251285
)
252286

253287
def make_candidate_evaluator(
@@ -265,7 +299,10 @@ def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator
265299

266300
@classmethod
267301
def from_package_finder(
268-
cls, package_finder: PackageFinder, computation_backend: ComputationBackend,
302+
cls,
303+
package_finder: PackageFinder,
304+
computation_backend: ComputationBackend,
305+
channel: str = "stable",
269306
) -> "PytorchPackageFinder":
270307
return new_from_similar(
271308
cls,
@@ -279,6 +316,7 @@ def from_package_finder(
279316
"ignore_requires_python",
280317
),
281318
computation_backend=computation_backend,
319+
channel=channel,
282320
)
283321

284322

@@ -298,14 +336,22 @@ def resolve(
298336

299337

300338
class StopAfterPytorchLinksFoundCommand(PatchedInstallCommand):
301-
def __init__(self, computation_backend: ComputationBackend, **kwargs: Any) -> None:
339+
def __init__(
340+
self,
341+
computation_backend: ComputationBackend,
342+
channel: str = "stable",
343+
**kwargs: Any,
344+
) -> None:
302345
super().__init__(**kwargs)
303346
self.computation_backend = computation_backend
347+
self.channel = channel
304348

305349
def _build_package_finder(self, *args: Any, **kwargs: Any) -> PytorchPackageFinder:
306350
package_finder = super()._build_package_finder(*args, **kwargs)
307351
return PytorchPackageFinder.from_package_finder(
308-
package_finder, computation_backend=self.computation_backend
352+
package_finder,
353+
computation_backend=self.computation_backend,
354+
channel=self.channel,
309355
)
310356

311357
def make_resolver(

light_the_torch/cli/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ def add_ltt_install_parser(subparsers: SubParsers) -> None:
111111
action="store_true",
112112
help="install only PyTorch distributions",
113113
)
114+
parser.add_argument(
115+
"--channel",
116+
type=str,
117+
default="stable",
118+
help=(
119+
"Channel of the PyTorch wheels. "
120+
"Can be one of 'stable' (default), 'test', or 'nightly'"
121+
),
122+
)
114123
parser.add_argument(
115124
"--install-cmd",
116125
type=str,
@@ -143,6 +152,15 @@ def add_ltt_find_parser(subparsers: SubParsers) -> None:
143152
"preferring CUDA over CPU."
144153
),
145154
)
155+
parser.add_argument(
156+
"--channel",
157+
type=str,
158+
default="stable",
159+
help=(
160+
"Channel of the PyTorch wheels. "
161+
"Can be one of 'stable' (default), 'test', or 'nightly'"
162+
),
163+
)
146164
add_pip_install_arguments(parser, "platform", "python_version")
147165
LTTParser.add_common_arguments(parser)
148166

light_the_torch/cli/commands.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def _run(self, pip_install_args: List[str]) -> None:
5454
class FindCommand(Command):
5555
def __init__(self, args: argparse.Namespace) -> None:
5656
self.computation_backend = args.computation_backend
57+
self.channel = args.channel
5758
self.platform = args.platform
5859
self.python_version = args.python_version
5960
self.verbose = args.verbose
@@ -62,6 +63,7 @@ def _run(self, pip_install_args: List[str]) -> None:
6263
links = ltt.find_links(
6364
pip_install_args,
6465
computation_backend=self.computation_backend,
66+
channel=self.channel,
6567
platform=self.platform,
6668
python_version=self.python_version,
6769
verbose=self.verbose,
@@ -73,6 +75,7 @@ class InstallCommand(Command):
7375
def __init__(self, args: argparse.Namespace) -> None:
7476
self.force_cpu = args.force_cpu
7577
self.pytorch_only = args.pytorch_only
78+
self.channel = args.channel
7679

7780
install_cmd = args.install_cmd
7881
if "{packages}" not in install_cmd:
@@ -85,6 +88,7 @@ def _run(self, pip_install_args: List[str]) -> None:
8588
links = ltt.find_links(
8689
pip_install_args,
8790
computation_backend=CPUBackend() if self.force_cpu else None,
91+
channel=self.channel,
8892
verbose=self.verbose,
8993
)
9094
cmd = self.install_cmd.format(packages=" ".join(links))

tests/cli/test_install.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,24 @@ def test_ltt_install_pytorch_only(
9292
collect_packages.assert_not_called()
9393

9494

95+
def test_ltt_install_channel(
96+
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
97+
):
98+
channel = "channel"
99+
100+
patch_install_argv(f"--channel={channel}")
101+
find_links = patch_find_links()
102+
patch_subprocess_call()
103+
patch_collect_packages()
104+
105+
with exits():
106+
cli.main()
107+
108+
_, kwargs = find_links.call_args
109+
assert "channel" in kwargs
110+
assert kwargs["channel"] == channel
111+
112+
95113
def test_ltt_install_install_cmd(
96114
patch_install_argv, patch_find_links, patch_subprocess_call,
97115
):

tests/unit/pip/test_find.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def computation_backends():
3333
return ("cpu", "cu92", "cu101", "cu102")
3434

3535

36+
@pytest.fixture
37+
def channels():
38+
return ("stable", "test", "nightly")
39+
40+
3641
@pytest.fixture
3742
def platforms():
3843
return ("linux_x86_64", "macosx_10_9_x86_64", "win_amd64")
@@ -119,6 +124,11 @@ def test_find_links_computation_backend_str(
119124
)
120125

121126

127+
def test_find_links_unknown_channel():
128+
with pytest.raises(ValueError):
129+
ltt.find_links([], channel="channel")
130+
131+
122132
def test_find_links_platform(subtests, patch_extract_dists, patch_run, platforms):
123133
patch_extract_dists()
124134
run = patch_run()
@@ -192,3 +202,12 @@ def test_find_links_torchvision_smoke(subtests, wheel_properties):
192202
for properties in wheel_properties:
193203
with subtests.test(**properties):
194204
assert ltt.find_links([dist], **properties)
205+
206+
207+
@pytest.mark.slow
208+
def test_find_links_torch_channel_smoke(subtests, channels):
209+
dist = "torch"
210+
211+
for channel in channels:
212+
with subtests.test(channel=channel):
213+
assert ltt.find_links([dist], computation_backend="cpu", channel=channel)

0 commit comments

Comments
 (0)