Skip to content

Commit dcfe9bd

Browse files
authored
add initial support for inter PyTorch compatibility (#22)
1 parent 5e19baa commit dcfe9bd

File tree

3 files changed

+93
-1
lines changed

3 files changed

+93
-1
lines changed

light_the_torch/_pip/extract.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import contextlib
12
import re
23
from typing import Any, List, NoReturn, cast
34

45
from pip._internal.req.req_install import InstallRequirement
56
from pip._internal.req.req_set import RequirementSet
67

8+
from ..compatibility import find_compatible_torch_version
79
from .common import InternalLTTError, PatchedInstallCommand, PatchedResolverBase, run
810

911
__all__ = ["extract_dists"]
@@ -75,7 +77,18 @@ def required_pytorch_dists(self) -> List[str]:
7577
return []
7678

7779
if not any(self._pytorch_core_pattern.match(dist) for dist in dists):
78-
dists.insert(0, self.PYTORCH_CORE)
80+
torch = self.PYTORCH_CORE
81+
82+
with contextlib.suppress(RuntimeError):
83+
torch_versions = {
84+
find_compatible_torch_version(*dist.split("=="))
85+
for dist in dists
86+
if "==" in dist
87+
}
88+
if len(torch_versions) == 1:
89+
torch = f"{torch}=={torch_versions.pop()}"
90+
91+
dists.insert(0, torch)
7992

8093
return dists
8194

light_the_torch/compatibility.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Any, Optional
2+
3+
__all__ = ["find_compatible_torch_version"]
4+
5+
6+
class Version:
7+
@classmethod
8+
def from_str(cls, version: str) -> "Version":
9+
parts = version.split(".")
10+
major = int(parts[0])
11+
minor = int(parts[1]) if len(parts) > 1 else None
12+
patch = int(parts[2]) if len(parts) > 2 else None
13+
return cls(major, minor, patch)
14+
15+
def __init__(self, major: int, minor: Optional[int], patch: Optional[int]) -> None:
16+
self.major = major
17+
self.minor = minor
18+
self.patch = patch
19+
self.parts = (major, minor, patch)
20+
21+
def __eq__(self, other: Any) -> bool:
22+
if not isinstance(other, Version):
23+
return False
24+
25+
return all(
26+
[
27+
self_part == other_part
28+
for self_part, other_part in zip(self.parts, other.parts)
29+
if self_part is not None and other_part is not None
30+
]
31+
)
32+
33+
def __hash__(self) -> int:
34+
return hash(self.parts)
35+
36+
def __repr__(self) -> str:
37+
return ".".join([str(part) for part in self.parts if part is not None])
38+
39+
40+
COMPATIBILITY = {
41+
"torchvision": {
42+
Version(0, 8, 0): Version(1, 7, 0),
43+
Version(0, 7, 0): Version(1, 6, 0),
44+
Version(0, 6, 1): Version(1, 5, 1),
45+
Version(0, 6, 0): Version(1, 5, 0),
46+
Version(0, 5, 0): Version(1, 4, 0),
47+
Version(0, 4, 2): Version(1, 3, 1),
48+
Version(0, 4, 1): Version(1, 3, 0),
49+
Version(0, 4, 0): Version(1, 2, 0),
50+
Version(0, 3, 0): Version(1, 1, 0),
51+
Version(0, 2, 2): Version(1, 0, 1),
52+
}
53+
}
54+
55+
56+
def find_compatible_torch_version(dist: str, version: str) -> str:
57+
version = Version.from_str(version)
58+
dist_compatibility = COMPATIBILITY[dist]
59+
candidates = [x for x in dist_compatibility.keys() if x == version]
60+
if not candidates:
61+
raise RuntimeError(
62+
f"No compatible torch version was found for {dist}=={version}"
63+
)
64+
if len(candidates) != 1:
65+
raise RuntimeError(
66+
f"Multiple compatible torch versions were found for {dist}=={version}:\n"
67+
f"{', '.join([str(candidate) for candidate in candidates])}\n"
68+
)
69+
70+
return str(dist_compatibility[candidates[0]])

tests/unit/pip/test_extract.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ def test_StopAfterPytorchDistsFoundResolver_no_torch(mocker):
1414
assert "torch" in resolver.required_pytorch_dists
1515

1616

17+
def test_StopAfterPytorchDistsFoundResolver_torch_compatibility(mocker):
18+
mocker.patch(
19+
"light_the_torch._pip.extract.PatchedResolverBase.__init__", return_value=None
20+
)
21+
resolver = extract.StopAfterPytorchDistsFoundResolver()
22+
resolver._required_pytorch_dists = ["torchvision==0.7"]
23+
assert "torch==1.6.0" in resolver.required_pytorch_dists
24+
25+
1726
def test_extract_pytorch_internal_error(mocker):
1827
mocker.patch("light_the_torch._pip.extract.run")
1928

0 commit comments

Comments
 (0)