3030def find_links (
3131 pip_install_args : List [str ],
3232 computation_backend : Optional [Union [str , ComputationBackend ]] = None ,
33+ channel : str = "stable" ,
3334 platform : Optional [str ] = None ,
3435 python_version : Optional [str ] = None ,
3536 verbose : bool = False ,
@@ -43,6 +44,8 @@ def find_links(
4344 computation_backend: Computation backend, for example ``"cpu"`` or ``"cu102"``.
4445 Defaults to the available hardware of the running system preferring CUDA
4546 over CPU.
47+ channel: Channel of the PyTorch wheels. Can be one of ``"stable"`` (default),
48+ ``"test"``, and ``"nightly"``.
4649 platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
4750 to the platform of the running system.
4851 python_version: Python version, for example ``"3"`` or ``"3.7"``. Defaults to
@@ -57,9 +60,17 @@ def find_links(
5760 elif isinstance (computation_backend , str ):
5861 computation_backend = ComputationBackend .from_str (computation_backend )
5962
63+ if channel not in ("stable" , "test" , "nightly" ):
64+ raise ValueError (
65+ f"channel can be one of 'stable', 'test', or 'nightly', "
66+ f"but got { channel } instead."
67+ )
68+
6069 dists = extract_dists (pip_install_args )
6170
62- cmd = StopAfterPytorchLinksFoundCommand (computation_backend = computation_backend )
71+ cmd = StopAfterPytorchLinksFoundCommand (
72+ computation_backend = computation_backend , channel = channel
73+ )
6374 pip_install_args = adjust_pip_install_args (dists , platform , python_version )
6475 options , args = cmd .parser .parse_args (pip_install_args )
6576 try :
@@ -222,32 +233,55 @@ class PytorchLinkCollector(LinkCollector):
222233 def __init__ (
223234 self ,
224235 * args : Any ,
225- url : str = "https://download.pytorch.org/whl/torch_stable.html" ,
236+ computation_backend : ComputationBackend ,
237+ channel : str = "stable" ,
226238 ** kwargs : Any ,
227239 ) -> None :
228240 super ().__init__ (* args , ** kwargs )
241+ if channel == "stable" :
242+ url = "https://download.pytorch.org/whl/torch_stable.html"
243+ else :
244+ url = (
245+ f"https://download.pytorch.org/whl/"
246+ f"{ channel } /{ computation_backend } /torch_{ channel } .html"
247+ )
229248 self .search_scope = SearchScope .create (find_links = [url ], index_urls = [])
230249
231250 @classmethod
232251 def from_link_collector (
233- cls , link_collector : LinkCollector
252+ cls ,
253+ link_collector : LinkCollector ,
254+ computation_backend : ComputationBackend ,
255+ channel : str = "stable" ,
234256 ) -> "PytorchLinkCollector" :
235- return new_from_similar (cls , link_collector , ("session" , "search_scope" ,),)
257+ return new_from_similar (
258+ cls ,
259+ link_collector ,
260+ ("session" , "search_scope" ,),
261+ channel = channel ,
262+ computation_backend = computation_backend ,
263+ )
236264
237265
238266class PytorchPackageFinder (PackageFinder ):
239267 _candidate_prefs : PytorchCandidatePreferences
240268 _link_collector : PytorchLinkCollector
241269
242270 def __init__ (
243- self , * args : Any , computation_backend : ComputationBackend , ** kwargs : Any ,
271+ self ,
272+ * args : Any ,
273+ computation_backend : ComputationBackend ,
274+ channel : str = "stable" ,
275+ ** kwargs : Any ,
244276 ) -> None :
245277 super ().__init__ (* args , ** kwargs )
246278 self ._candidate_prefs = PytorchCandidatePreferences .from_candidate_preferences (
247279 self ._candidate_prefs , computation_backend = computation_backend
248280 )
249281 self ._link_collector = PytorchLinkCollector .from_link_collector (
250- self ._link_collector
282+ self ._link_collector ,
283+ channel = channel ,
284+ computation_backend = computation_backend ,
251285 )
252286
253287 def make_candidate_evaluator (
@@ -265,7 +299,10 @@ def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator
265299
266300 @classmethod
267301 def from_package_finder (
268- cls , package_finder : PackageFinder , computation_backend : ComputationBackend ,
302+ cls ,
303+ package_finder : PackageFinder ,
304+ computation_backend : ComputationBackend ,
305+ channel : str = "stable" ,
269306 ) -> "PytorchPackageFinder" :
270307 return new_from_similar (
271308 cls ,
@@ -279,6 +316,7 @@ def from_package_finder(
279316 "ignore_requires_python" ,
280317 ),
281318 computation_backend = computation_backend ,
319+ channel = channel ,
282320 )
283321
284322
@@ -298,14 +336,22 @@ def resolve(
298336
299337
300338class StopAfterPytorchLinksFoundCommand (PatchedInstallCommand ):
301- def __init__ (self , computation_backend : ComputationBackend , ** kwargs : Any ) -> None :
339+ def __init__ (
340+ self ,
341+ computation_backend : ComputationBackend ,
342+ channel : str = "stable" ,
343+ ** kwargs : Any ,
344+ ) -> None :
302345 super ().__init__ (** kwargs )
303346 self .computation_backend = computation_backend
347+ self .channel = channel
304348
305349 def _build_package_finder (self , * args : Any , ** kwargs : Any ) -> PytorchPackageFinder :
306350 package_finder = super ()._build_package_finder (* args , ** kwargs )
307351 return PytorchPackageFinder .from_package_finder (
308- package_finder , computation_backend = self .computation_backend
352+ package_finder ,
353+ computation_backend = self .computation_backend ,
354+ channel = self .channel ,
309355 )
310356
311357 def make_resolver (
0 commit comments