1- import sys
1+ import itertools
22
33import pytest
44
55import light_the_torch as ltt
6- from light_the_torch import computation_backend as cb
76from light_the_torch ._pip .common import InternalLTTError
8- from light_the_torch ._pip .find import PytorchCandidatePreferences
7+ from light_the_torch ._pip .find import PytorchCandidatePreferences , maybe_add_option
8+ from light_the_torch .computation_backend import ComputationBackend
99
1010
11- def test_find_links_internal_error (mocker ):
12- mocker .patch ("light_the_torch._pip.find.extract_dists" , return_value = [])
13- mocker .patch ("light_the_torch._pip.find.run" )
11+ @pytest .fixture
12+ def patch_extract_dists (mocker ):
13+ def patch_extract_dists_ (return_value = None ):
14+ if return_value is None :
15+ return_value = []
16+ return mocker .patch (
17+ "light_the_torch._pip.find.extract_dists" , return_value = return_value
18+ )
1419
15- with pytest .raises (InternalLTTError ):
16- ltt .find_links (["foo" ])
20+ return patch_extract_dists_
21+
22+
23+ @pytest .fixture
24+ def patch_run (mocker ):
25+ def patch_run_ ():
26+ return mocker .patch ("light_the_torch._pip.find.run" )
27+
28+ return patch_run_
29+
30+
31+ @pytest .fixture
32+ def computation_backends ():
33+ return ("cpu" , "cu92" , "cu101" , "cu102" )
34+
35+
36+ @pytest .fixture
37+ def platforms ():
38+ return ("linux_x86_64" , "macosx_10_9_x86_64" , "win_amd64" )
39+
40+
41+ @pytest .fixture
42+ def python_versions ():
43+ return ("3.6" , "3.7" , "3.8" )
44+
45+
46+ @pytest .fixture
47+ def wheel_properties (computation_backends , platforms , python_versions ):
48+ properties = []
49+ for properties_ in itertools .product (
50+ computation_backends , platforms , python_versions
51+ ):
52+ # macOS binaries don't support CUDA
53+ computation_backend , platform , _ = properties_
54+ if platform .startswith ("macosx" ) and computation_backend != "cpu" :
55+ continue
56+
57+ properties .append (
58+ dict (
59+ zip (("computation_backend" , "platform" , "python_version" ), properties_ )
60+ )
61+ )
62+ return tuple (properties )
63+
64+
65+ def test_maybe_add_option_already_set (subtests ):
66+ args = ["--foo" , "bar" ]
67+ assert maybe_add_option (args , "--foo" ,) == args
68+ assert maybe_add_option (args , "-f" , aliases = ("--foo" ,)) == args
1769
1870
1971def test_PytorchCandidatePreferences_detect_computation_backend (mocker ):
20- class GenericComputationBackend (cb . ComputationBackend ):
72+ class GenericComputationBackend (ComputationBackend ):
2173 @property
2274 def local_specifier (self ):
2375 return "generic"
@@ -32,48 +84,104 @@ def local_specifier(self):
3284 assert candidate_prefs .computation_backend is computation_backend
3385
3486
35- @pytest .fixture
36- def computation_backends ():
37- strings = ["cpu" ]
38- if sys .platform .startswith ("linux" ) or sys .platform .startswith ("win" ):
39- strings .extend (("cu92" , "cu101" , "cu102" ))
40- return [cb .ComputationBackend .from_str (string ) for string in strings ]
87+ def test_find_links_internal_error (patch_extract_dists , patch_run ):
88+ patch_extract_dists ()
89+ patch_run ()
4190
91+ with pytest .raises (InternalLTTError ):
92+ ltt .find_links ([])
4293
43- @pytest .mark .slow
44- def test_find_links_torch_smoke (subtests , computation_backends ):
45- dist = "torch"
94+
95+ def test_find_links_computation_backend (
96+ subtests , patch_extract_dists , patch_run , computation_backends
97+ ):
98+ patch_extract_dists ()
99+ run = patch_run ()
46100
47101 for computation_backend in computation_backends :
48102 with subtests .test (computation_backend = computation_backend ):
49- assert ltt .find_links ([dist ], computation_backend = computation_backend )
103+ run .reset ()
104+
105+ with pytest .raises (InternalLTTError ):
106+ ltt .find_links ([], computation_backend = computation_backend )
107+
108+ args , _ = run .call_args
109+ cmd = args [0 ]
110+ assert cmd .computation_backend == ComputationBackend .from_str (
111+ computation_backend
112+ )
113+
114+
115+ def test_find_links_platform (subtests , patch_extract_dists , patch_run , platforms ):
116+ patch_extract_dists ()
117+ run = patch_run ()
118+
119+ for platform in platforms :
120+ with subtests .test (platform = platform ):
121+ run .reset ()
122+
123+ with pytest .raises (InternalLTTError ):
124+ ltt .find_links ([], platform = platform )
125+
126+ args , _ = run .call_args
127+ options = args [2 ]
128+ assert options .platform == platform
129+
130+
131+ def test_find_links_python_version (
132+ subtests , patch_extract_dists , patch_run , python_versions
133+ ):
134+ patch_extract_dists ()
135+ run = patch_run ()
136+
137+ for python_version in python_versions :
138+ python_version_tuple = tuple (int (v ) for v in python_version .split ("." ))
139+ with subtests .test (python_version = python_version ):
140+ run .reset ()
141+
142+ with pytest .raises (InternalLTTError ):
143+ ltt .find_links ([], python_version = python_version )
144+
145+ args , _ = run .call_args
146+ options = args [2 ]
147+ assert options .python_version == python_version_tuple
148+
149+
150+ @pytest .mark .slow
151+ def test_find_links_torch_smoke (subtests , wheel_properties ):
152+ dist = "torch"
153+
154+ for properties in wheel_properties :
155+ with subtests .test (** properties ):
156+ assert ltt .find_links ([dist ], ** properties )
50157
51158
52159@pytest .mark .slow
53- @pytest .mark .skipif (
54- sys .platform .startswith ("win" ), reason = "torchaudio has no releases for Windows"
55- )
56- def test_find_links_torchaudio_smoke (subtests , computation_backends ):
160+ def test_find_links_torchaudio_smoke (subtests , wheel_properties ):
57161 dist = "torchaudio"
58162
59- for computation_backend in computation_backends :
60- with subtests .test (computation_backend = computation_backend ):
61- assert ltt .find_links ([dist ], computation_backend = computation_backend )
163+ for properties in wheel_properties :
164+ # torchaudio has no published releases for Windows
165+ if properties ["platform" ].startswith ("win" ):
166+ continue
167+ with subtests .test (** properties ):
168+ a = ltt .find_links ([dist ], ** properties )
169+ assert a
62170
63171
64172@pytest .mark .slow
65- def test_find_links_torchtext_smoke (subtests , computation_backends ):
173+ def test_find_links_torchtext_smoke (subtests , wheel_properties ):
66174 dist = "torchtext"
67175
68- for computation_backend in computation_backends :
69- with subtests .test (computation_backend = computation_backend ):
70- assert ltt .find_links ([dist ], computation_backend = computation_backend )
176+ for properties in wheel_properties :
177+ with subtests .test (** properties ):
178+ assert ltt .find_links ([dist ], ** properties )
71179
72180
73181@pytest .mark .slow
74- def test_find_links_torchvision_smoke (subtests , computation_backends ):
182+ def test_find_links_torchvision_smoke (subtests , wheel_properties ):
75183 dist = "torchvision"
76184
77- for computation_backend in computation_backends :
78- with subtests .test (computation_backend = computation_backend ):
79- assert ltt .find_links ([dist ], computation_backend = computation_backend )
185+ for properties in wheel_properties :
186+ with subtests .test (** properties ):
187+ assert ltt .find_links ([dist ], ** properties )
0 commit comments