Skip to content

Commit 73e8fc8

Browse files
authored
Support multiple CUDA versions (#33)
* Support multiple CUDA versions * cleanup * exclude specific windows combinations from tests * fix cpu backend ordering * prioritize backend over version * fix docstring
1 parent eff3ea9 commit 73e8fc8

File tree

11 files changed

+524
-415
lines changed

11 files changed

+524
-415
lines changed

light_the_torch/_pip/find.py

Lines changed: 74 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
import re
2-
from typing import Any, Iterable, List, NoReturn, Optional, Text, Tuple, Union, cast
2+
from typing import (
3+
Any,
4+
Collection,
5+
Iterable,
6+
List,
7+
NoReturn,
8+
Optional,
9+
Set,
10+
Text,
11+
Tuple,
12+
Union,
13+
cast,
14+
)
315

416
from pip._internal.index.collector import LinkCollector
517
from pip._internal.index.package_finder import (
@@ -13,8 +25,10 @@
1325
from pip._internal.models.search_scope import SearchScope
1426
from pip._internal.req.req_install import InstallRequirement
1527
from pip._internal.req.req_set import RequirementSet
28+
from pip._vendor.packaging.version import Version
29+
30+
import light_the_torch.computation_backend as cb
1631

17-
from ..computation_backend import ComputationBackend, detect_computation_backend
1832
from .common import (
1933
InternalLTTError,
2034
PatchedInstallCommand,
@@ -29,7 +43,9 @@
2943

3044
def find_links(
3145
pip_install_args: List[str],
32-
computation_backend: Optional[Union[str, ComputationBackend]] = None,
46+
computation_backends: Optional[
47+
Union[cb.ComputationBackend, Collection[cb.ComputationBackend]]
48+
] = None,
3349
channel: str = "stable",
3450
platform: Optional[str] = None,
3551
python_version: Optional[str] = None,
@@ -41,9 +57,9 @@ def find_links(
4157
Args:
4258
pip_install_args: Arguments passed to ``pip install`` that will be searched for
4359
required PyTorch distributions
44-
computation_backend: Computation backend, for example ``"cpu"`` or ``"cu102"``.
45-
Defaults to the available hardware of the running system preferring CUDA
46-
over CPU.
60+
computation_backends: Collection of supported computation backends, for example
61+
``"cpu"`` or ``"cu102"``. Defaults to the available hardware of the running
62+
system.
4763
channel: Channel of the PyTorch wheels. Can be one of ``"stable"`` (default),
4864
``"test"``, and ``"nightly"``.
4965
platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
@@ -55,10 +71,12 @@ def find_links(
5571
Returns:
5672
Wheel links with given properties for all required PyTorch distributions.
5773
"""
58-
if computation_backend is None:
59-
computation_backend = detect_computation_backend()
60-
elif isinstance(computation_backend, str):
61-
computation_backend = ComputationBackend.from_str(computation_backend)
74+
if computation_backends is None:
75+
computation_backends = cb.detect_compatible_computation_backends()
76+
elif isinstance(computation_backends, cb.ComputationBackend):
77+
computation_backends = {computation_backends}
78+
else:
79+
computation_backends = set(computation_backends)
6280

6381
if channel not in ("stable", "test", "nightly"):
6482
raise ValueError(
@@ -69,7 +87,7 @@ def find_links(
6987
dists = extract_dists(pip_install_args)
7088

7189
cmd = StopAfterPytorchLinksFoundCommand(
72-
computation_backend=computation_backend, channel=channel
90+
computation_backends=computation_backends, channel=channel
7391
)
7492
pip_install_args = adjust_pip_install_args(dists, platform, python_version)
7593
options, args = cmd.parser.parse_args(pip_install_args)
@@ -172,37 +190,43 @@ def extract_computation_backend_from_link(self, link: Link) -> Optional[str]:
172190

173191
class PytorchCandidatePreferences(CandidatePreferences):
174192
def __init__(
175-
self, *args: Any, computation_backend: ComputationBackend, **kwargs: Any,
193+
self,
194+
*args: Any,
195+
computation_backends: Set[cb.ComputationBackend],
196+
**kwargs: Any,
176197
) -> None:
177198
super().__init__(*args, **kwargs)
178-
self.computation_backend = computation_backend
199+
self.computation_backends = computation_backends
179200

180201
@classmethod
181202
def from_candidate_preferences(
182203
cls,
183204
candidate_preferences: CandidatePreferences,
184-
computation_backend: ComputationBackend,
205+
computation_backends: Set[cb.ComputationBackend],
185206
) -> "PytorchCandidatePreferences":
186207
return new_from_similar(
187208
cls,
188209
candidate_preferences,
189210
("prefer_binary", "allow_all_prereleases",),
190-
computation_backend=computation_backend,
211+
computation_backends=computation_backends,
191212
)
192213

193214

194215
class PytorchCandidateEvaluator(CandidateEvaluator):
195216
def __init__(
196-
self, *args: Any, computation_backend: ComputationBackend, **kwargs: Any,
217+
self,
218+
*args: Any,
219+
computation_backends: Set[cb.ComputationBackend],
220+
**kwargs: Any,
197221
) -> None:
198222
super().__init__(*args, **kwargs)
199-
self.computation_backend = computation_backend
223+
self.computation_backends = {cb.AnyBackend(), *computation_backends}
200224

201225
@classmethod
202226
def from_candidate_evaluator(
203227
cls,
204228
candidate_evaluator: CandidateEvaluator,
205-
computation_backend: ComputationBackend,
229+
computation_backends: Set[cb.ComputationBackend],
206230
) -> "PytorchCandidateEvaluator":
207231
return new_from_similar(
208232
cls,
@@ -215,51 +239,62 @@ def from_candidate_evaluator(
215239
"allow_all_prereleases",
216240
"hashes",
217241
),
218-
computation_backend=computation_backend,
242+
computation_backends=computation_backends,
243+
)
244+
245+
def _sort_key(
246+
self, candidate: InstallationCandidate
247+
) -> Tuple[cb.ComputationBackend, Version]:
248+
version = Version(
249+
f"{candidate.version.major}"
250+
f".{candidate.version.minor}"
251+
f".{candidate.version.micro}"
219252
)
253+
computation_backend = cb.ComputationBackend.from_str(candidate.version.local)
254+
return computation_backend, version
220255

221256
def get_applicable_candidates(
222257
self, candidates: List[InstallationCandidate]
223258
) -> List[InstallationCandidate]:
224259
return [
225260
candidate
226261
for candidate in super().get_applicable_candidates(candidates)
227-
if candidate.version.local == "any"
228-
or candidate.version.local == self.computation_backend
262+
if candidate.version.local in self.computation_backends
229263
]
230264

231265

232266
class PytorchLinkCollector(LinkCollector):
233267
def __init__(
234268
self,
235269
*args: Any,
236-
computation_backend: ComputationBackend,
270+
computation_backends: Set[cb.ComputationBackend],
237271
channel: str = "stable",
238272
**kwargs: Any,
239273
) -> None:
240274
super().__init__(*args, **kwargs)
241275
if channel == "stable":
242-
url = "https://download.pytorch.org/whl/torch_stable.html"
276+
urls = ["https://download.pytorch.org/whl/torch_stable.html"]
243277
else:
244-
url = (
278+
urls = [
245279
f"https://download.pytorch.org/whl/"
246-
f"{channel}/{computation_backend}/torch_{channel}.html"
247-
)
248-
self.search_scope = SearchScope.create(find_links=[url], index_urls=[])
280+
f"{channel}/{backend}/torch_{channel}.html"
281+
for backend in sorted(computation_backends, key=str)
282+
]
283+
self.search_scope = SearchScope.create(find_links=urls, index_urls=[])
249284

250285
@classmethod
251286
def from_link_collector(
252287
cls,
253288
link_collector: LinkCollector,
254-
computation_backend: ComputationBackend,
289+
computation_backends: Set[cb.ComputationBackend],
255290
channel: str = "stable",
256291
) -> "PytorchLinkCollector":
257292
return new_from_similar(
258293
cls,
259294
link_collector,
260-
("session", "search_scope",),
295+
("session", "search_scope"),
261296
channel=channel,
262-
computation_backend=computation_backend,
297+
computation_backends=computation_backends,
263298
)
264299

265300

@@ -270,18 +305,18 @@ class PytorchPackageFinder(PackageFinder):
270305
def __init__(
271306
self,
272307
*args: Any,
273-
computation_backend: ComputationBackend,
308+
computation_backends: Set[cb.ComputationBackend],
274309
channel: str = "stable",
275310
**kwargs: Any,
276311
) -> None:
277312
super().__init__(*args, **kwargs)
278313
self._candidate_prefs = PytorchCandidatePreferences.from_candidate_preferences(
279-
self._candidate_prefs, computation_backend=computation_backend
314+
self._candidate_prefs, computation_backends=computation_backends
280315
)
281316
self._link_collector = PytorchLinkCollector.from_link_collector(
282317
self._link_collector,
283318
channel=channel,
284-
computation_backend=computation_backend,
319+
computation_backends=computation_backends,
285320
)
286321

287322
def make_candidate_evaluator(
@@ -290,7 +325,7 @@ def make_candidate_evaluator(
290325
candidate_evaluator = super().make_candidate_evaluator(*args, **kwargs)
291326
return PytorchCandidateEvaluator.from_candidate_evaluator(
292327
candidate_evaluator,
293-
computation_backend=self._candidate_prefs.computation_backend,
328+
computation_backends=self._candidate_prefs.computation_backends,
294329
)
295330

296331
def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator:
@@ -301,7 +336,7 @@ def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator
301336
def from_package_finder(
302337
cls,
303338
package_finder: PackageFinder,
304-
computation_backend: ComputationBackend,
339+
computation_backends: Set[cb.ComputationBackend],
305340
channel: str = "stable",
306341
) -> "PytorchPackageFinder":
307342
return new_from_similar(
@@ -315,7 +350,7 @@ def from_package_finder(
315350
"candidate_prefs",
316351
"ignore_requires_python",
317352
),
318-
computation_backend=computation_backend,
353+
computation_backends=computation_backends,
319354
channel=channel,
320355
)
321356

@@ -338,19 +373,19 @@ def resolve(
338373
class StopAfterPytorchLinksFoundCommand(PatchedInstallCommand):
339374
def __init__(
340375
self,
341-
computation_backend: ComputationBackend,
376+
computation_backends: Set[cb.ComputationBackend],
342377
channel: str = "stable",
343378
**kwargs: Any,
344379
) -> None:
345380
super().__init__(**kwargs)
346-
self.computation_backend = computation_backend
381+
self.computation_backends = computation_backends
347382
self.channel = channel
348383

349384
def _build_package_finder(self, *args: Any, **kwargs: Any) -> PytorchPackageFinder:
350385
package_finder = super()._build_package_finder(*args, **kwargs)
351386
return PytorchPackageFinder.from_package_finder(
352387
package_finder,
353-
computation_backend=self.computation_backend,
388+
computation_backends=self.computation_backends,
354389
channel=self.channel,
355390
)
356391

light_the_torch/cli/commands.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def _run(self, pip_install_args: List[str]) -> None:
5454

5555
class FindCommand(Command):
5656
def __init__(self, args: argparse.Namespace) -> None:
57-
self.computation_backend = args.computation_backend
57+
# TODO split by comma
58+
self.computation_backends = args.computation_backend
5859
self.channel = args.channel
5960
self.platform = args.platform
6061
self.python_version = args.python_version
@@ -63,7 +64,7 @@ def __init__(self, args: argparse.Namespace) -> None:
6364
def _run(self, pip_install_args: List[str]) -> None:
6465
links = ltt.find_links(
6566
pip_install_args,
66-
computation_backend=self.computation_backend,
67+
computation_backends=self.computation_backends,
6768
channel=self.channel,
6869
platform=self.platform,
6970
python_version=self.python_version,
@@ -88,7 +89,7 @@ def __init__(self, args: argparse.Namespace) -> None:
8889
def _run(self, pip_install_args: List[str]) -> None:
8990
links = ltt.find_links(
9091
pip_install_args,
91-
computation_backend=CPUBackend() if self.force_cpu else None,
92+
computation_backends={CPUBackend()} if self.force_cpu else None,
9293
channel=self.channel,
9394
verbose=self.verbose,
9495
)

0 commit comments

Comments
 (0)