|
1 | | -import os |
2 | | -import re |
3 | | -import sys |
4 | | -from typing import Dict, Optional |
5 | | - |
6 | | -# IMPORTANT: this list needs to be sorted in reverse |
7 | | -VERSIONS = [ |
8 | | - dict(torch="2.8.0", torchvision="0.23.0"), # nightly |
9 | | - dict(torch="2.7.1", torchvision="0.22.1"), # stable |
10 | | - dict(torch="2.7.0", torchvision="0.22.0"), |
11 | | - dict(torch="2.6.0", torchvision="0.21.0"), |
12 | | - dict(torch="2.5.1", torchvision="0.20.1"), |
13 | | - dict(torch="2.5.0", torchvision="0.20.0"), |
14 | | - dict(torch="2.4.1", torchvision="0.19.1"), |
15 | | - dict(torch="2.4.0", torchvision="0.19.0"), |
16 | | -] |
17 | | - |
18 | | - |
19 | | -def find_latest(ver: str) -> Dict[str, str]: |
20 | | - # drop all except semantic version |
21 | | - ver = re.search(r"([\.\d]+)", ver).groups()[0] # type: ignore[union-attr] |
22 | | - # in case there remaining dot at the end - e.g "1.9.0.dev20210504" |
23 | | - ver = ver[:-1] if ver[-1] == "." else ver |
24 | | - print(f"finding ecosystem versions for: {ver}") |
25 | | - |
26 | | - # find first match |
27 | | - for option in VERSIONS: |
28 | | - if option["torch"].startswith(ver): |
29 | | - return option |
30 | | - |
31 | | - raise ValueError(f"Missing {ver} in {VERSIONS}") |
32 | | - |
33 | | - |
34 | | -def main(req: str, torch_version: Optional[str] = None) -> str: |
35 | | - if not torch_version: |
36 | | - import torch |
37 | | - |
38 | | - torch_version = torch.__version__ |
39 | | - assert torch_version, f"invalid torch: {torch_version}" |
40 | | - |
41 | | - # remove comments and strip whitespace |
42 | | - req = re.sub(rf"\s*#.*{os.linesep}", os.linesep, req).strip() |
43 | | - |
44 | | - latest = find_latest(torch_version) |
45 | | - for lib, version in latest.items(): |
46 | | - replace = f"{lib}=={version}" if version else "" |
47 | | - req = re.sub(rf"\b{lib}(?!\w).*", replace, req) |
48 | | - |
49 | | - return req |
50 | | - |
51 | | - |
52 | | -if __name__ == "__main__": |
53 | | - if len(sys.argv) == 3: |
54 | | - requirements_path, torch_version = sys.argv[1:] |
55 | | - else: |
56 | | - requirements_path, torch_version = sys.argv[1], None # type: ignore[assignment] |
57 | | - |
58 | | - with open(requirements_path) as fp: |
59 | | - requirements = fp.read() |
60 | | - requirements = main(requirements, torch_version) |
61 | | - print(requirements) # on purpose - to debug |
62 | | - with open(requirements_path, "w") as fp: |
63 | | - fp.write(requirements) |
| 1 | +# import os |
| 2 | +# import re |
| 3 | +# import sys |
| 4 | +# from typing import Dict, Optional |
| 5 | + |
| 6 | +# # IMPORTANT: this list needs to be sorted in reverse |
| 7 | +# VERSIONS = [ |
| 8 | +# dict(torch="2.8.0", torchvision="0.23.0"), # nightly |
| 9 | +# dict(torch="2.7.1", torchvision="0.22.1"), # stable |
| 10 | +# dict(torch="2.7.0", torchvision="0.22.0"), |
| 11 | +# dict(torch="2.6.0", torchvision="0.21.0"), |
| 12 | +# dict(torch="2.5.1", torchvision="0.20.1"), |
| 13 | +# dict(torch="2.5.0", torchvision="0.20.0"), |
| 14 | +# dict(torch="2.4.1", torchvision="0.19.1"), |
| 15 | +# dict(torch="2.4.0", torchvision="0.19.0"), |
| 16 | +# ] |
| 17 | + |
| 18 | + |
| 19 | +# def find_latest(ver: str) -> Dict[str, str]: |
| 20 | +# # drop all except semantic version |
| 21 | +# ver = re.search(r"([\.\d]+)", ver).groups()[0] # type: ignore[union-attr] |
| 22 | +# # in case there remaining dot at the end - e.g "1.9.0.dev20210504" |
| 23 | +# ver = ver[:-1] if ver[-1] == "." else ver |
| 24 | +# print(f"finding ecosystem versions for: {ver}") |
| 25 | + |
| 26 | +# # find first match |
| 27 | +# for option in VERSIONS: |
| 28 | +# if option["torch"].startswith(ver): |
| 29 | +# return option |
| 30 | + |
| 31 | +# raise ValueError(f"Missing {ver} in {VERSIONS}") |
| 32 | + |
| 33 | + |
| 34 | +# def main(req: str, torch_version: Optional[str] = None) -> str: |
| 35 | +# if not torch_version: |
| 36 | +# import torch |
| 37 | + |
| 38 | +# torch_version = torch.__version__ |
| 39 | +# assert torch_version, f"invalid torch: {torch_version}" |
| 40 | + |
| 41 | +# # remove comments and strip whitespace |
| 42 | +# req = re.sub(rf"\s*#.*{os.linesep}", os.linesep, req).strip() |
| 43 | + |
| 44 | +# latest = find_latest(torch_version) |
| 45 | +# for lib, version in latest.items(): |
| 46 | +# replace = f"{lib}=={version}" if version else "" |
| 47 | +# req = re.sub(rf"\b{lib}(?!\w).*", replace, req) |
| 48 | + |
| 49 | +# return req |
| 50 | + |
| 51 | + |
| 52 | +# if __name__ == "__main__": |
| 53 | +# if len(sys.argv) == 3: |
| 54 | +# requirements_path, torch_version = sys.argv[1:] |
| 55 | +# else: |
| 56 | +# requirements_path, torch_version = sys.argv[1], None # type: ignore[assignment] |
| 57 | + |
| 58 | +# with open(requirements_path) as fp: |
| 59 | +# requirements = fp.read() |
| 60 | +# requirements = main(requirements, torch_version) |
| 61 | +# print(requirements) # on purpose - to debug |
| 62 | +# with open(requirements_path, "w") as fp: |
| 63 | +# fp.write(requirements) |
0 commit comments