Skip to content

Commit bc40c9a

Browse files
authored
Enable pyrefly on all libs (#2425)
1 parent bbc4de1 commit bc40c9a

10 files changed

Lines changed: 1275 additions & 1461 deletions

File tree

.pyrefly-baseline.json

Lines changed: 1237 additions & 1441 deletions
Large diffs are not rendered by default.

infra/pre-commit.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,14 @@ class PrecommitConfig:
571571
],
572572
),
573573
PrecommitConfig(
574-
patterns=["lib/marin/src/**/*.py", "lib/levanter/src/**/*.py"],
574+
patterns=[
575+
"lib/marin/src/**/*.py",
576+
"lib/levanter/src/**/*.py",
577+
"lib/haliax/src/**/*.py",
578+
"lib/fray/src/**/*.py",
579+
"lib/iris/src/**/*.py",
580+
"lib/zephyr/src/**/*.py",
581+
],
575582
checks=[
576583
check_pyrefly,
577584
],

lib/fray/src/fray/cluster/ray/cluster.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Ray-based cluster implementation."""
1616

17+
from typing import Any, cast
18+
1719
import asyncio
1820
import logging
1921
import os
@@ -307,8 +309,8 @@ def _get_runtime_env(self, request: JobRequest) -> dict | None:
307309
logger.info("Ray runtime env: %s", runtime_env)
308310
return runtime_env
309311

310-
def _get_entrypoint_params(self, request: JobRequest) -> dict:
311-
params = {}
312+
def _get_entrypoint_params(self, request: JobRequest) -> dict[str, Any]:
313+
params: dict[str, Any] = {}
312314

313315
if request.resources.cpu > 0:
314316
params["entrypoint_num_cpus"] = float(request.resources.cpu)
@@ -333,7 +335,7 @@ def monitor(self, job_id: JobId) -> JobInfo:
333335
job = self._jobs[job_id]
334336
if job.submission_id is None:
335337
logger.info("Job is a remote ref, monitoring is automatic, waiting.")
336-
return self.wait(job_id)
338+
return cast(JobInfo, self.wait(job_id))
337339

338340
async def stream_logs():
339341
async for line in self._job_client().tail_job_logs(job_id):

lib/fray/src/fray/job/context.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import logging
2222
import os
2323
import threading
24-
from collections.abc import Callable, Generator
24+
from collections.abc import Callable, Generator, Iterator
2525
from concurrent.futures import Future, ThreadPoolExecutor, wait
2626
from contextlib import contextmanager
2727
from contextvars import ContextVar
@@ -126,7 +126,7 @@ class ActorHandle:
126126
Provides a unified interface for calling actor methods with .remote() and .call().
127127
"""
128128

129-
def __getattr__(self, method_name: str):
129+
def __getattr__(self, method_name: str) -> "ActorMethod":
130130
"""Get a callable method wrapper for the actor."""
131131
raise NotImplementedError
132132

@@ -150,7 +150,7 @@ def __init__(self, instance: Any, lock: threading.Lock, context):
150150
self._lock = lock # Serializes all method calls
151151
self._context = context
152152

153-
def __getattr__(self, method_name: str):
153+
def __getattr__(self, method_name: str) -> "ThreadActorMethod":
154154
method = getattr(self._instance, method_name)
155155
return ThreadActorMethod(method, self._lock, self._context)
156156

@@ -183,7 +183,7 @@ class _ImmediateFuture:
183183

184184
def __init__(self, result: Any):
185185
self._result = result
186-
self._iterator = None
186+
self._iterator: Iterator[Any] | None = None
187187

188188
def result(self) -> Any:
189189
return self._result
@@ -206,7 +206,7 @@ class GeneratorFuture:
206206

207207
def __init__(self, future: Future):
208208
self._future = future
209-
self._iterator = None
209+
self._iterator: Iterator[Any] | None = None
210210

211211
def result(self) -> Any:
212212
"""Get the underlying result from the future."""
@@ -422,7 +422,7 @@ def create_actor(
422422
num_cpus: float | None = None,
423423
**kwargs,
424424
) -> ActorHandle:
425-
options = {}
425+
options: dict[str, Any] = {}
426426
if name is not None:
427427
options["name"] = name
428428
options["get_if_exists"] = get_if_exists

lib/fray/src/fray/queue/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(self, host: str = "0.0.0.0", port: int = 9999):
7777

7878
config = uvicorn.Config(self.app, host=host, port=port, log_level="error", access_log=False)
7979
self.server = uvicorn.Server(config)
80-
self.server_thread = None
80+
self.server_thread: ServerThread | None = None
8181

8282
def _create_app(self) -> FastAPI:
8383
"""Create FastAPI app with namespaced queue endpoints."""

lib/iris/src/iris/cluster/worker/worker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,11 @@ def __init__(
124124
self._config = config
125125

126126
# Setup cache directory
127+
self._temp_dir: tempfile.TemporaryDirectory[str] | None = None
127128
if cache_dir:
128129
self._cache_dir = cache_dir
129-
self._temp_dir = None
130130
elif config.cache_dir:
131131
self._cache_dir = config.cache_dir
132-
self._temp_dir = None
133132
else:
134133
# Create temporary cache
135134
self._temp_dir = tempfile.TemporaryDirectory(prefix="worker_cache_")

lib/zephyr/src/zephyr/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import re
2121
from collections.abc import Callable, Iterable, Iterator
2222
from dataclasses import dataclass, field
23-
from typing import Any, Generic, Literal, TypeVar
23+
from typing import Any, Generic, Literal, TypeVar, cast
2424

2525
import fsspec
2626
from braceexpand import braceexpand
@@ -830,7 +830,7 @@ def reduce(
830830
4950
831831
"""
832832
if global_reducer is None:
833-
global_reducer = local_reducer
833+
global_reducer = cast(Callable[[Iterator[R]], R], local_reducer)
834834

835835
return Dataset(self.source, [*self.operations, ReduceOp(local_reducer, global_reducer)])
836836

lib/zephyr/src/zephyr/plan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
WindowOp,
5252
WriteOp,
5353
)
54+
from zephyr.expr import Expr
5455
from zephyr.readers import InputFileSpec
5556

5657
if TYPE_CHECKING:
@@ -512,7 +513,7 @@ def _compute_file_pushdown(
512513
Returns:
513514
Tuple of (source_items, remaining_operations), where filter/select have been pushed down.
514515
"""
515-
filter_expr = None
516+
filter_expr: Expr | None = None
516517
select_columns = load_op.columns
517518
ops_to_skip: set[int] = set()
518519

@@ -882,7 +883,7 @@ class StageContext:
882883
total_shards: int
883884
chunk_size: int
884885
aux_shards: dict[int, list[Any]] = field(default_factory=dict)
885-
execution_context: JobContext = None
886+
execution_context: JobContext | None = None
886887

887888
def get_right_shard(self, op_index: int) -> Any:
888889
"""Get right shard for join at given op index.

lib/zephyr/src/zephyr/writers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from dataclasses import asdict, dataclass, is_dataclass
19+
from dataclasses import asdict, is_dataclass
2020
import itertools
2121
import os
2222
from collections.abc import Iterable
@@ -129,7 +129,7 @@ def infer_parquet_type(value):
129129
return pa.string()
130130

131131

132-
def infer_parquet_schema(record: dict | dataclass):
132+
def infer_parquet_schema(record: dict[str, Any] | Any):
133133
"""Infer PyArrow schema from a dictionary record."""
134134
import pyarrow as pa
135135

pyproject.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,14 @@ exclude = ["marin/", "scripts/"]
8181

8282
# Pyrefly type checker configuration
8383
[tool.pyrefly]
84-
project-includes = ["lib/marin/src/**/*.py", "lib/levanter/src/**/*.py"]
84+
project-includes = [
85+
"lib/marin/src/**/*.py",
86+
"lib/levanter/src/**/*.py",
87+
"lib/haliax/src/**/*.py",
88+
"lib/fray/src/**/*.py",
89+
"lib/iris/src/**/*.py",
90+
"lib/zephyr/src/**/*.py",
91+
]
8592

8693
# Explicitly tell Pyrefly where our editable packages live so it resolves imports
8794
# against the library sources instead of the top-level `src` directory, which only
@@ -92,6 +99,7 @@ search-path = [
9299
"lib/haliax/src",
93100
"lib/zephyr/src",
94101
"lib/fray/src",
102+
"lib/iris/src",
95103
]
96104
disable-search-path-heuristics = true
97105

@@ -107,6 +115,7 @@ project-excludes = [
107115
"examples/**", # Example code doesn't need strict typing
108116
"lib/**/crawl/**", # Crawl scripts have library typing issues with smart_open
109117
"lib/marin/src/marin/processing/classification/deduplication/vendor", # Exclude vendor stuff in dedupe
118+
"lib/iris/src/iris/rpc/*_pb2*", # Generated protobuf files
110119
]
111120

112121
# Disable specific error codes that are primarily noise from missing type stubs

0 commit comments

Comments
 (0)