Skip to content

Commit 6c1efd2

Browse files
authored
add remaining tests (#17)
* get_verbosity * unrecognized argument * find * bug fix * cleanup * trigger CI * Revert "trigger CI" This reverts commit 727504e. * cleanup after rebase
1 parent ddca55f commit 6c1efd2

File tree

7 files changed

+187
-42
lines changed

7 files changed

+187
-42
lines changed

light_the_torch/_pip/find.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def adjust_pip_install_args(
7878
pip_install_args = maybe_add_option(
7979
pip_install_args, "--platform", value=platform
8080
)
81-
if platform is not None:
81+
if python_version is not None:
8282
pip_install_args = maybe_add_option(
8383
pip_install_args, "--python-version", value=python_version
8484
)

tests/cli/test_extract.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,10 @@ def test_ltt_extract_verbose(patch_extract_argv, patch_extract_dists):
4343
_, kwargs = extract_dists.call_args
4444
assert "verbose" in kwargs
4545
assert kwargs["verbose"]
46+
47+
48+
def test_extract_unrecognized_argument(patch_extract_argv):
49+
patch_extract_argv("--unrecognized-argument")
50+
51+
with exits(error=True):
52+
cli.main()

tests/cli/test_find.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,10 @@ def test_ltt_find_verbose(patch_find_argv, patch_find_links):
4343
_, kwargs = find_links.call_args
4444
assert "verbose" in kwargs
4545
assert kwargs["verbose"]
46+
47+
48+
def test_find_unrecognized_argument(patch_find_argv):
49+
patch_find_argv("--unrecognized-argument")
50+
51+
with exits(error=True):
52+
cli.main()

tests/cli/test_install.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,10 @@ def test_ltt_install_verbose(
188188
_, kwargs = find_links.call_args
189189
assert "verbose" in kwargs
190190
assert kwargs["verbose"]
191+
192+
193+
def test_install_unrecognized_argument(patch_install_argv):
194+
patch_install_argv("--unrecognized-argument")
195+
196+
with exits(error=True):
197+
cli.main()

tests/unit/pip/test_common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
1+
import itertools
2+
import optparse
3+
14
import pytest
25

36
from light_the_torch._pip import common
47

58

9+
def test_get_verbosity(subtests):
10+
verboses = tuple(range(4))
11+
quiets = tuple(range(4))
12+
13+
for verbose, quiet in itertools.product(verboses, quiets):
14+
with subtests.test(verbose=verbose, quiet=quiet):
15+
options = optparse.Values({"verbose": verbose, "quiet": quiet})
16+
verbosity = verbose - quiet
17+
18+
assert common.get_verbosity(options, verbose=True) == verbosity
19+
assert common.get_verbosity(options, verbose=False) == -1
20+
21+
622
def test_get_public_or_private_attr_public_and_private():
723
class ObjWithPublicAndPrivateAttribute:
824
attr = "public"

tests/unit/pip/test_extract.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,6 @@
55
from light_the_torch._pip.common import InternalLTTError
66

77

8-
def test_extract_dists_internal_error(mocker):
9-
mocker.patch("light_the_torch._pip.extract.run")
10-
11-
with pytest.raises(InternalLTTError):
12-
ltt.extract_dists(["foo"])
13-
14-
158
def test_StopAfterPytorchDistsFoundResolver_no_torch(mocker):
169
mocker.patch(
1710
"light_the_torch._pip.extract.PatchedResolverBase.__init__", return_value=None
@@ -21,6 +14,13 @@ def test_StopAfterPytorchDistsFoundResolver_no_torch(mocker):
2114
assert "torch" in resolver.required_pytorch_dists
2215

2316

17+
def test_extract_pytorch_internal_error(mocker):
18+
mocker.patch("light_the_torch._pip.extract.run")
19+
20+
with pytest.raises(InternalLTTError):
21+
ltt.extract_dists(["foo"])
22+
23+
2424
@pytest.mark.slow
2525
def test_extract_dists_ltt():
2626
assert ltt.extract_dists(["light-the-torch"]) == []

tests/unit/pip/test_find.py

Lines changed: 142 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,75 @@
1-
import sys
1+
import itertools
22

33
import pytest
44

55
import light_the_torch as ltt
6-
from light_the_torch import computation_backend as cb
76
from light_the_torch._pip.common import InternalLTTError
8-
from light_the_torch._pip.find import PytorchCandidatePreferences
7+
from light_the_torch._pip.find import PytorchCandidatePreferences, maybe_add_option
8+
from light_the_torch.computation_backend import ComputationBackend
99

1010

11-
def test_find_links_internal_error(mocker):
12-
mocker.patch("light_the_torch._pip.find.extract_dists", return_value=[])
13-
mocker.patch("light_the_torch._pip.find.run")
11+
@pytest.fixture
12+
def patch_extract_dists(mocker):
13+
def patch_extract_dists_(return_value=None):
14+
if return_value is None:
15+
return_value = []
16+
return mocker.patch(
17+
"light_the_torch._pip.find.extract_dists", return_value=return_value
18+
)
1419

15-
with pytest.raises(InternalLTTError):
16-
ltt.find_links(["foo"])
20+
return patch_extract_dists_
21+
22+
23+
@pytest.fixture
24+
def patch_run(mocker):
25+
def patch_run_():
26+
return mocker.patch("light_the_torch._pip.find.run")
27+
28+
return patch_run_
29+
30+
31+
@pytest.fixture
32+
def computation_backends():
33+
return ("cpu", "cu92", "cu101", "cu102")
34+
35+
36+
@pytest.fixture
37+
def platforms():
38+
return ("linux_x86_64", "macosx_10_9_x86_64", "win_amd64")
39+
40+
41+
@pytest.fixture
42+
def python_versions():
43+
return ("3.6", "3.7", "3.8")
44+
45+
46+
@pytest.fixture
47+
def wheel_properties(computation_backends, platforms, python_versions):
48+
properties = []
49+
for properties_ in itertools.product(
50+
computation_backends, platforms, python_versions
51+
):
52+
# macOS binaries don't support CUDA
53+
computation_backend, platform, _ = properties_
54+
if platform.startswith("macosx") and computation_backend != "cpu":
55+
continue
56+
57+
properties.append(
58+
dict(
59+
zip(("computation_backend", "platform", "python_version"), properties_)
60+
)
61+
)
62+
return tuple(properties)
63+
64+
65+
def test_maybe_add_option_already_set(subtests):
66+
args = ["--foo", "bar"]
67+
assert maybe_add_option(args, "--foo",) == args
68+
assert maybe_add_option(args, "-f", aliases=("--foo",)) == args
1769

1870

1971
def test_PytorchCandidatePreferences_detect_computation_backend(mocker):
20-
class GenericComputationBackend(cb.ComputationBackend):
72+
class GenericComputationBackend(ComputationBackend):
2173
@property
2274
def local_specifier(self):
2375
return "generic"
@@ -32,48 +84,104 @@ def local_specifier(self):
3284
assert candidate_prefs.computation_backend is computation_backend
3385

3486

35-
@pytest.fixture
36-
def computation_backends():
37-
strings = ["cpu"]
38-
if sys.platform.startswith("linux") or sys.platform.startswith("win"):
39-
strings.extend(("cu92", "cu101", "cu102"))
40-
return [cb.ComputationBackend.from_str(string) for string in strings]
87+
def test_find_links_internal_error(patch_extract_dists, patch_run):
88+
patch_extract_dists()
89+
patch_run()
4190

91+
with pytest.raises(InternalLTTError):
92+
ltt.find_links([])
4293

43-
@pytest.mark.slow
44-
def test_find_links_torch_smoke(subtests, computation_backends):
45-
dist = "torch"
94+
95+
def test_find_links_computation_backend(
96+
subtests, patch_extract_dists, patch_run, computation_backends
97+
):
98+
patch_extract_dists()
99+
run = patch_run()
46100

47101
for computation_backend in computation_backends:
48102
with subtests.test(computation_backend=computation_backend):
49-
assert ltt.find_links([dist], computation_backend=computation_backend)
103+
run.reset()
104+
105+
with pytest.raises(InternalLTTError):
106+
ltt.find_links([], computation_backend=computation_backend)
107+
108+
args, _ = run.call_args
109+
cmd = args[0]
110+
assert cmd.computation_backend == ComputationBackend.from_str(
111+
computation_backend
112+
)
113+
114+
115+
def test_find_links_platform(subtests, patch_extract_dists, patch_run, platforms):
116+
patch_extract_dists()
117+
run = patch_run()
118+
119+
for platform in platforms:
120+
with subtests.test(platform=platform):
121+
run.reset()
122+
123+
with pytest.raises(InternalLTTError):
124+
ltt.find_links([], platform=platform)
125+
126+
args, _ = run.call_args
127+
options = args[2]
128+
assert options.platform == platform
129+
130+
131+
def test_find_links_python_version(
132+
subtests, patch_extract_dists, patch_run, python_versions
133+
):
134+
patch_extract_dists()
135+
run = patch_run()
136+
137+
for python_version in python_versions:
138+
python_version_tuple = tuple(int(v) for v in python_version.split("."))
139+
with subtests.test(python_version=python_version):
140+
run.reset()
141+
142+
with pytest.raises(InternalLTTError):
143+
ltt.find_links([], python_version=python_version)
144+
145+
args, _ = run.call_args
146+
options = args[2]
147+
assert options.python_version == python_version_tuple
148+
149+
150+
@pytest.mark.slow
151+
def test_find_links_torch_smoke(subtests, wheel_properties):
152+
dist = "torch"
153+
154+
for properties in wheel_properties:
155+
with subtests.test(**properties):
156+
assert ltt.find_links([dist], **properties)
50157

51158

52159
@pytest.mark.slow
53-
@pytest.mark.skipif(
54-
sys.platform.startswith("win"), reason="torchaudio has no releases for Windows"
55-
)
56-
def test_find_links_torchaudio_smoke(subtests, computation_backends):
160+
def test_find_links_torchaudio_smoke(subtests, wheel_properties):
57161
dist = "torchaudio"
58162

59-
for computation_backend in computation_backends:
60-
with subtests.test(computation_backend=computation_backend):
61-
assert ltt.find_links([dist], computation_backend=computation_backend)
163+
for properties in wheel_properties:
164+
# torchaudio has no published releases for Windows
165+
if properties["platform"].startswith("win"):
166+
continue
167+
with subtests.test(**properties):
168+
a = ltt.find_links([dist], **properties)
169+
assert a
62170

63171

64172
@pytest.mark.slow
65-
def test_find_links_torchtext_smoke(subtests, computation_backends):
173+
def test_find_links_torchtext_smoke(subtests, wheel_properties):
66174
dist = "torchtext"
67175

68-
for computation_backend in computation_backends:
69-
with subtests.test(computation_backend=computation_backend):
70-
assert ltt.find_links([dist], computation_backend=computation_backend)
176+
for properties in wheel_properties:
177+
with subtests.test(**properties):
178+
assert ltt.find_links([dist], **properties)
71179

72180

73181
@pytest.mark.slow
74-
def test_find_links_torchvision_smoke(subtests, computation_backends):
182+
def test_find_links_torchvision_smoke(subtests, wheel_properties):
75183
dist = "torchvision"
76184

77-
for computation_backend in computation_backends:
78-
with subtests.test(computation_backend=computation_backend):
79-
assert ltt.find_links([dist], computation_backend=computation_backend)
185+
for properties in wheel_properties:
186+
with subtests.test(**properties):
187+
assert ltt.find_links([dist], **properties)

0 commit comments

Comments
 (0)