|
| 1 | +from typing import Any, Optional |
| 2 | + |
| 3 | +__all__ = ["find_compatible_torch_version"] |
| 4 | + |
| 5 | + |
| 6 | +class Version: |
| 7 | + @classmethod |
| 8 | + def from_str(cls, version: str) -> "Version": |
| 9 | + parts = version.split(".") |
| 10 | + major = int(parts[0]) |
| 11 | + minor = int(parts[1]) if len(parts) > 1 else None |
| 12 | + patch = int(parts[2]) if len(parts) > 2 else None |
| 13 | + return cls(major, minor, patch) |
| 14 | + |
| 15 | + def __init__(self, major: int, minor: Optional[int], patch: Optional[int]) -> None: |
| 16 | + self.major = major |
| 17 | + self.minor = minor |
| 18 | + self.patch = patch |
| 19 | + self.parts = (major, minor, patch) |
| 20 | + |
| 21 | + def __eq__(self, other: Any) -> bool: |
| 22 | + if not isinstance(other, Version): |
| 23 | + return False |
| 24 | + |
| 25 | + return all( |
| 26 | + [ |
| 27 | + self_part == other_part |
| 28 | + for self_part, other_part in zip(self.parts, other.parts) |
| 29 | + if self_part is not None and other_part is not None |
| 30 | + ] |
| 31 | + ) |
| 32 | + |
| 33 | + def __hash__(self) -> int: |
| 34 | + return hash(self.parts) |
| 35 | + |
| 36 | + def __repr__(self) -> str: |
| 37 | + return ".".join([str(part) for part in self.parts if part is not None]) |
| 38 | + |
| 39 | + |
| 40 | +COMPATIBILITY = { |
| 41 | + "torchvision": { |
| 42 | + Version(0, 8, 0): Version(1, 7, 0), |
| 43 | + Version(0, 7, 0): Version(1, 6, 0), |
| 44 | + Version(0, 6, 1): Version(1, 5, 1), |
| 45 | + Version(0, 6, 0): Version(1, 5, 0), |
| 46 | + Version(0, 5, 0): Version(1, 4, 0), |
| 47 | + Version(0, 4, 2): Version(1, 3, 1), |
| 48 | + Version(0, 4, 1): Version(1, 3, 0), |
| 49 | + Version(0, 4, 0): Version(1, 2, 0), |
| 50 | + Version(0, 3, 0): Version(1, 1, 0), |
| 51 | + Version(0, 2, 2): Version(1, 0, 1), |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | + |
| 56 | +def find_compatible_torch_version(dist: str, version: str) -> str: |
| 57 | + version = Version.from_str(version) |
| 58 | + dist_compatibility = COMPATIBILITY[dist] |
| 59 | + candidates = [x for x in dist_compatibility.keys() if x == version] |
| 60 | + if not candidates: |
| 61 | + raise RuntimeError( |
| 62 | + f"No compatible torch version was found for {dist}=={version}" |
| 63 | + ) |
| 64 | + if len(candidates) != 1: |
| 65 | + raise RuntimeError( |
| 66 | + f"Multiple compatible torch versions were found for {dist}=={version}:\n" |
| 67 | + f"{', '.join([str(candidate) for candidate in candidates])}\n" |
| 68 | + ) |
| 69 | + |
| 70 | + return str(dist_compatibility[candidates[0]]) |
0 commit comments