11import re
2- from typing import Any , Iterable , List , NoReturn , Optional , Text , Tuple , Union , cast
2+ from typing import (
3+ Any ,
4+ Collection ,
5+ Iterable ,
6+ List ,
7+ NoReturn ,
8+ Optional ,
9+ Set ,
10+ Text ,
11+ Tuple ,
12+ Union ,
13+ cast ,
14+ )
315
416from pip ._internal .index .collector import LinkCollector
517from pip ._internal .index .package_finder import (
1325from pip ._internal .models .search_scope import SearchScope
1426from pip ._internal .req .req_install import InstallRequirement
1527from pip ._internal .req .req_set import RequirementSet
28+ from pip ._vendor .packaging .version import Version
29+
30+ import light_the_torch .computation_backend as cb
1631
17- from ..computation_backend import ComputationBackend , detect_computation_backend
1832from .common import (
1933 InternalLTTError ,
2034 PatchedInstallCommand ,
2943
3044def find_links (
3145 pip_install_args : List [str ],
32- computation_backend : Optional [Union [str , ComputationBackend ]] = None ,
46+ computation_backends : Optional [
47+ Union [cb .ComputationBackend , Collection [cb .ComputationBackend ]]
48+ ] = None ,
3349 channel : str = "stable" ,
3450 platform : Optional [str ] = None ,
3551 python_version : Optional [str ] = None ,
@@ -41,9 +57,9 @@ def find_links(
4157 Args:
4258 pip_install_args: Arguments passed to ``pip install`` that will be searched for
4359 required PyTorch distributions
44- computation_backend: Computation backend, for example ``"cpu"`` or ``"cu102"``.
45- Defaults to the available hardware of the running system preferring CUDA
46- over CPU .
60+ computation_backends: Collection of supported computation backends, for example
61+ ``"cpu"`` or ``"cu102"``. Defaults to the available hardware of the running
62+ system .
4763 channel: Channel of the PyTorch wheels. Can be one of ``"stable"`` (default),
4864 ``"test"``, and ``"nightly"``.
4965 platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
@@ -55,10 +71,12 @@ def find_links(
5571 Returns:
5672 Wheel links with given properties for all required PyTorch distributions.
5773 """
58- if computation_backend is None :
59- computation_backend = detect_computation_backend ()
60- elif isinstance (computation_backend , str ):
61- computation_backend = ComputationBackend .from_str (computation_backend )
74+ if computation_backends is None :
75+ computation_backends = cb .detect_compatible_computation_backends ()
76+ elif isinstance (computation_backends , cb .ComputationBackend ):
77+ computation_backends = {computation_backends }
78+ else :
79+ computation_backends = set (computation_backends )
6280
6381 if channel not in ("stable" , "test" , "nightly" ):
6482 raise ValueError (
@@ -69,7 +87,7 @@ def find_links(
6987 dists = extract_dists (pip_install_args )
7088
7189 cmd = StopAfterPytorchLinksFoundCommand (
72- computation_backend = computation_backend , channel = channel
90+ computation_backends = computation_backends , channel = channel
7391 )
7492 pip_install_args = adjust_pip_install_args (dists , platform , python_version )
7593 options , args = cmd .parser .parse_args (pip_install_args )
@@ -172,37 +190,43 @@ def extract_computation_backend_from_link(self, link: Link) -> Optional[str]:
172190
173191class PytorchCandidatePreferences (CandidatePreferences ):
174192 def __init__ (
175- self , * args : Any , computation_backend : ComputationBackend , ** kwargs : Any ,
193+ self ,
194+ * args : Any ,
195+ computation_backends : Set [cb .ComputationBackend ],
196+ ** kwargs : Any ,
176197 ) -> None :
177198 super ().__init__ (* args , ** kwargs )
178- self .computation_backend = computation_backend
199+ self .computation_backends = computation_backends
179200
180201 @classmethod
181202 def from_candidate_preferences (
182203 cls ,
183204 candidate_preferences : CandidatePreferences ,
184- computation_backend : ComputationBackend ,
205+ computation_backends : Set [ cb . ComputationBackend ] ,
185206 ) -> "PytorchCandidatePreferences" :
186207 return new_from_similar (
187208 cls ,
188209 candidate_preferences ,
189210 ("prefer_binary" , "allow_all_prereleases" ,),
190- computation_backend = computation_backend ,
211+ computation_backends = computation_backends ,
191212 )
192213
193214
194215class PytorchCandidateEvaluator (CandidateEvaluator ):
195216 def __init__ (
196- self , * args : Any , computation_backend : ComputationBackend , ** kwargs : Any ,
217+ self ,
218+ * args : Any ,
219+ computation_backends : Set [cb .ComputationBackend ],
220+ ** kwargs : Any ,
197221 ) -> None :
198222 super ().__init__ (* args , ** kwargs )
199- self .computation_backend = computation_backend
223+ self .computation_backends = { cb . AnyBackend (), * computation_backends }
200224
201225 @classmethod
202226 def from_candidate_evaluator (
203227 cls ,
204228 candidate_evaluator : CandidateEvaluator ,
205- computation_backend : ComputationBackend ,
229+ computation_backends : Set [ cb . ComputationBackend ] ,
206230 ) -> "PytorchCandidateEvaluator" :
207231 return new_from_similar (
208232 cls ,
@@ -215,51 +239,62 @@ def from_candidate_evaluator(
215239 "allow_all_prereleases" ,
216240 "hashes" ,
217241 ),
218- computation_backend = computation_backend ,
242+ computation_backends = computation_backends ,
243+ )
244+
245+ def _sort_key (
246+ self , candidate : InstallationCandidate
247+ ) -> Tuple [cb .ComputationBackend , Version ]:
248+ version = Version (
249+ f"{ candidate .version .major } "
250+ f".{ candidate .version .minor } "
251+ f".{ candidate .version .micro } "
219252 )
253+ computation_backend = cb .ComputationBackend .from_str (candidate .version .local )
254+ return computation_backend , version
220255
221256 def get_applicable_candidates (
222257 self , candidates : List [InstallationCandidate ]
223258 ) -> List [InstallationCandidate ]:
224259 return [
225260 candidate
226261 for candidate in super ().get_applicable_candidates (candidates )
227- if candidate .version .local == "any"
228- or candidate .version .local == self .computation_backend
262+ if candidate .version .local in self .computation_backends
229263 ]
230264
231265
232266class PytorchLinkCollector (LinkCollector ):
233267 def __init__ (
234268 self ,
235269 * args : Any ,
236- computation_backend : ComputationBackend ,
270+ computation_backends : Set [ cb . ComputationBackend ] ,
237271 channel : str = "stable" ,
238272 ** kwargs : Any ,
239273 ) -> None :
240274 super ().__init__ (* args , ** kwargs )
241275 if channel == "stable" :
242- url = "https://download.pytorch.org/whl/torch_stable.html"
276+ urls = [ "https://download.pytorch.org/whl/torch_stable.html" ]
243277 else :
244- url = (
278+ urls = [
245279 f"https://download.pytorch.org/whl/"
246- f"{ channel } /{ computation_backend } /torch_{ channel } .html"
247- )
248- self .search_scope = SearchScope .create (find_links = [url ], index_urls = [])
280+ f"{ channel } /{ backend } /torch_{ channel } .html"
281+ for backend in sorted (computation_backends , key = str )
282+ ]
283+ self .search_scope = SearchScope .create (find_links = urls , index_urls = [])
249284
250285 @classmethod
251286 def from_link_collector (
252287 cls ,
253288 link_collector : LinkCollector ,
254- computation_backend : ComputationBackend ,
289+ computation_backends : Set [ cb . ComputationBackend ] ,
255290 channel : str = "stable" ,
256291 ) -> "PytorchLinkCollector" :
257292 return new_from_similar (
258293 cls ,
259294 link_collector ,
260- ("session" , "search_scope" , ),
295+ ("session" , "search_scope" ),
261296 channel = channel ,
262- computation_backend = computation_backend ,
297+ computation_backends = computation_backends ,
263298 )
264299
265300
@@ -270,18 +305,18 @@ class PytorchPackageFinder(PackageFinder):
270305 def __init__ (
271306 self ,
272307 * args : Any ,
273- computation_backend : ComputationBackend ,
308+ computation_backends : Set [ cb . ComputationBackend ] ,
274309 channel : str = "stable" ,
275310 ** kwargs : Any ,
276311 ) -> None :
277312 super ().__init__ (* args , ** kwargs )
278313 self ._candidate_prefs = PytorchCandidatePreferences .from_candidate_preferences (
279- self ._candidate_prefs , computation_backend = computation_backend
314+ self ._candidate_prefs , computation_backends = computation_backends
280315 )
281316 self ._link_collector = PytorchLinkCollector .from_link_collector (
282317 self ._link_collector ,
283318 channel = channel ,
284- computation_backend = computation_backend ,
319+ computation_backends = computation_backends ,
285320 )
286321
287322 def make_candidate_evaluator (
@@ -290,7 +325,7 @@ def make_candidate_evaluator(
290325 candidate_evaluator = super ().make_candidate_evaluator (* args , ** kwargs )
291326 return PytorchCandidateEvaluator .from_candidate_evaluator (
292327 candidate_evaluator ,
293- computation_backend = self ._candidate_prefs .computation_backend ,
328+ computation_backends = self ._candidate_prefs .computation_backends ,
294329 )
295330
296331 def make_link_evaluator (self , * args : Any , ** kwargs : Any ) -> PytorchLinkEvaluator :
@@ -301,7 +336,7 @@ def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator
301336 def from_package_finder (
302337 cls ,
303338 package_finder : PackageFinder ,
304- computation_backend : ComputationBackend ,
339+ computation_backends : Set [ cb . ComputationBackend ] ,
305340 channel : str = "stable" ,
306341 ) -> "PytorchPackageFinder" :
307342 return new_from_similar (
@@ -315,7 +350,7 @@ def from_package_finder(
315350 "candidate_prefs" ,
316351 "ignore_requires_python" ,
317352 ),
318- computation_backend = computation_backend ,
353+ computation_backends = computation_backends ,
319354 channel = channel ,
320355 )
321356
@@ -338,19 +373,19 @@ def resolve(
338373class StopAfterPytorchLinksFoundCommand (PatchedInstallCommand ):
339374 def __init__ (
340375 self ,
341- computation_backend : ComputationBackend ,
376+ computation_backends : Set [ cb . ComputationBackend ] ,
342377 channel : str = "stable" ,
343378 ** kwargs : Any ,
344379 ) -> None :
345380 super ().__init__ (** kwargs )
346- self .computation_backend = computation_backend
381+ self .computation_backends = computation_backends
347382 self .channel = channel
348383
349384 def _build_package_finder (self , * args : Any , ** kwargs : Any ) -> PytorchPackageFinder :
350385 package_finder = super ()._build_package_finder (* args , ** kwargs )
351386 return PytorchPackageFinder .from_package_finder (
352387 package_finder ,
353- computation_backend = self .computation_backend ,
388+ computation_backends = self .computation_backends ,
354389 channel = self .channel ,
355390 )
356391
0 commit comments