Skip to content

Commit 1d917a6

Browse files
committed
lintrunner: enable pyre
1 parent b626d17 commit 1d917a6

15 files changed

+274
-67
lines changed

.github/workflows/docs.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
build:
11-
runs-on: ubuntu-20.04
11+
runs-on: ubuntu-latest
1212
steps:
1313
- name: Setup Python
1414
uses: actions/setup-python@v3

.github/workflows/lint.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
lint:
11-
runs-on: ubuntu-20.04
11+
runs-on: ubuntu-latest
1212
steps:
1313
- name: Setup Python
1414
uses: actions/setup-python@v3
@@ -31,6 +31,8 @@ jobs:
3131
run: |
3232
set -eux
3333
34+
pyre check
35+
3436
lintrunner --force-color --all-files
3537
- name: Run Rust Lint
3638
run: |

.lintrunner.toml

+23
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,26 @@ command = [
4242
'--',
4343
'@{{PATHSFILE}}',
4444
]
45+
46+
[[linter]]
47+
code = 'PYRE'
48+
include_patterns = [
49+
'**/*.py',
50+
'**/*.pyi',
51+
]
52+
command = [
53+
'python3',
54+
'tools/linter/adapters/pyre_linter.py',
55+
'--',
56+
'@{{PATHSFILE}}'
57+
]
58+
init_command = [
59+
'python',
60+
'-m',
61+
'lintrunner_adapters',
62+
'run',
63+
'pip_init',
64+
'--dry-run={{DRYRUN}}',
65+
'pyre-check==0.9.23',
66+
]
67+
is_formatter = false

.watchmanconfig

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"root_files": [
3+
"torchft",
4+
"*.py",
5+
".pyre_configuration",
6+
".watchmanconfig"
7+
]
8+
}

tools/linter/adapters/pyre_linter.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import argparse
2+
import concurrent.futures
3+
import json
4+
import logging
5+
import os
6+
import subprocess
7+
import sys
8+
from enum import Enum
9+
from pathlib import Path
10+
from typing import Any, List, NamedTuple, Optional, Set, TypedDict
11+
12+
logger: logging.Logger = logging.getLogger(__name__)
13+
14+
15+
class LintSeverity(str, Enum):
16+
ERROR = "error"
17+
WARNING = "warning"
18+
ADVICE = "advice"
19+
DISABLED = "disabled"
20+
21+
22+
class LintMessage(NamedTuple):
23+
path: Optional[str]
24+
line: Optional[int]
25+
char: Optional[int]
26+
code: str
27+
severity: LintSeverity
28+
name: str
29+
original: Optional[str]
30+
replacement: Optional[str]
31+
description: Optional[str]
32+
33+
34+
class PyreResult(TypedDict):
35+
line: int
36+
column: int
37+
stop_line: int
38+
stop_column: int
39+
path: str
40+
code: int
41+
name: str
42+
description: str
43+
concise_description: str
44+
45+
46+
def run_pyre() -> List[PyreResult]:
47+
proc = subprocess.run(
48+
["pyre", "--output=json", "incremental"],
49+
capture_output=True,
50+
)
51+
return json.loads(proc.stdout)
52+
53+
54+
def check_pyre(
55+
filenames: Set[str],
56+
) -> List[LintMessage]:
57+
try:
58+
results = run_pyre()
59+
60+
return [
61+
LintMessage(
62+
path=result["path"],
63+
line=result["line"],
64+
char=result["column"],
65+
code="pyre",
66+
severity=LintSeverity.WARNING,
67+
name=result["name"],
68+
description=result["description"],
69+
original=None,
70+
replacement=None,
71+
)
72+
for result in results
73+
]
74+
except Exception as err:
75+
return [
76+
LintMessage(
77+
path=None,
78+
line=None,
79+
char=None,
80+
code="pyre",
81+
severity=LintSeverity.ADVICE,
82+
name="command-failed",
83+
original=None,
84+
replacement=None,
85+
description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
86+
)
87+
]
88+
89+
90+
def main() -> None:
91+
parser = argparse.ArgumentParser(
92+
description="Checks files with pyre",
93+
fromfile_prefix_chars="@",
94+
)
95+
parser.add_argument(
96+
"--verbose",
97+
action="store_true",
98+
help="verbose logging",
99+
)
100+
parser.add_argument(
101+
"filenames",
102+
nargs="+",
103+
help="paths to lint",
104+
)
105+
args = parser.parse_args()
106+
107+
logging.basicConfig(
108+
format="<%(processName)s:%(levelname)s> %(message)s",
109+
level=(
110+
logging.NOTSET
111+
if args.verbose
112+
else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
113+
),
114+
stream=sys.stderr,
115+
)
116+
117+
lint_messages = check_pyre(set(args.filenames))
118+
119+
for lint_message in lint_messages:
120+
print(json.dumps(lint_message._asdict()), flush=True)
121+
122+
123+
if __name__ == "__main__":
124+
main()

torchft/checkpointing.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,19 @@
1616
import socket
1717
import threading
1818
import urllib.request
19-
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
20-
from typing import Callable
19+
from http.server import BaseHTTPRequestHandler
20+
from typing import Callable, Generic, TypeVar
2121

2222
import torch
2323

24-
logger: logging.Logger = logging.getLogger(__name__)
24+
from torchft.http import _IPv6HTTPServer
2525

26+
logger: logging.Logger = logging.getLogger(__name__)
2627

27-
class _IPv6HTTPServer(ThreadingHTTPServer):
28-
address_family = socket.AF_INET6
29-
request_queue_size = 1024
28+
T = TypeVar("T")
3029

3130

32-
class CheckpointServer:
31+
class CheckpointServer(Generic[T]):
3332
"""
3433
This is an HTTP server that can be used to transfer checkpoints
3534
between workers.
@@ -41,7 +40,7 @@ class CheckpointServer:
4140
state_dict: a callable that returns the state dict to be transferred
4241
"""
4342

44-
def __init__(self, state_dict: Callable[[], object]) -> None:
43+
def __init__(self, state_dict: Callable[[], T]) -> None:
4544
self._checkpoint_lock = threading.Lock()
4645
self._disallowed = False
4746
self._step = -1
@@ -88,7 +87,7 @@ def err(self, msg: str) -> None:
8887
self._thread.start()
8988

9089
@classmethod
91-
def load_from_address(cls, address: str) -> object:
90+
def load_from_address(cls, address: str) -> T:
9291
"""
9392
Loads a checkpoint from the given address.
9493

torchft/data.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def __init__(
4747
dataset: data.Dataset,
4848
replica_group: int,
4949
num_replica_groups: int,
50-
*args,
5150
rank: Optional[int] = None,
5251
num_replicas: Optional[int] = None,
5352
**kwargs,
@@ -69,5 +68,8 @@ def __init__(
6968
self.global_world_size = num_replicas * num_replica_groups
7069

7170
super().__init__(
72-
dataset, *args, rank=self.global_rank, num_replicas=self.global_world_size
71+
dataset,
72+
rank=self.global_rank,
73+
num_replicas=self.global_world_size,
74+
**kwargs,
7375
)

torchft/ddp_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def allreduce_grad(tensor: torch.Tensor) -> Future[torch.Tensor]:
4444

4545
call_count += 1
4646

47-
fut = Future()
47+
fut = Future() # pyre-fixme[29]: not a function
4848
fut.set_result(tensor)
4949
return fut
5050

torchft/http.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import socket
2+
from http.server import ThreadingHTTPServer
3+
4+
5+
class _IPv6HTTPServer(ThreadingHTTPServer):
6+
address_family = socket.AF_INET6
7+
request_queue_size = 1024

0 commit comments

Comments
 (0)