Skip to content

Commit 9da768e

Browse files
committed
Add pyrefly type checking: pyrefly.toml config, CI integration, test runner support
Agent-Logs-Url: https://github.com/vfdev-5/flax/sessions/d862656c-0c03-46b4-971e-55ff771b19b9
1 parent f9790b2 commit 9da768e

5 files changed

Lines changed: 47 additions & 2 deletions

File tree

.github/workflows/flax_test.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ jobs:
9898
- python-version: '3.12'
9999
test-type: mypy
100100
jax-version: 'newest'
101+
- python-version: '3.12'
102+
test-type: pyrefly
103+
jax-version: 'newest'
101104
steps:
102105
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
103106
- name: Setup uv
@@ -128,6 +131,8 @@ jobs:
128131
uv run --no-sync tests/run_all_tests.sh --only-pytype
129132
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
130133
uv run --no-sync tests/run_all_tests.sh --only-mypy
134+
elif [[ "${{ matrix.test-type }}" == "pyrefly" ]]; then
135+
uv run --no-sync tests/run_all_tests.sh --only-pyrefly
131136
else
132137
echo "Unknown test type: ${{ matrix.test-type }}"
133138
exit 1

flax/linen/linear.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
189189
inputs, kernel, bias = self.promote_dtype(
190190
inputs, kernel, bias, dtype=self.dtype
191191
)
192+
assert inputs is not None and kernel is not None
192193

193194
if self.dot_general_cls is not None:
194195
dot_general = self.dot_general_cls()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ testing = [
4545
"jraph>=0.0.6dev0",
4646
"ml-collections",
4747
"mypy",
48+
"pyrefly",
4849
"opencv-python",
4950
# Set protobuf version to prevent error in
5051
# examples/mnist/train_test.py::TrainTest::test_train_and_evaluate

pyrefly.toml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Pyrefly configuration - migrated from mypy
2+
# Only type-check flax/linen/linear.py for now; expand as issues are resolved.
3+
project-includes = ["flax/linen/linear.py"]
4+
5+
preset = "legacy"
6+
ignore-missing-imports = [
7+
"tensorflow.*",
8+
"tensorboard.*",
9+
"absl.*",
10+
"jax.*",
11+
"rich.*",
12+
"jaxlib.cuda.*",
13+
"jaxlib.cpu.*",
14+
"msgpack",
15+
"numpy.*",
16+
"optax.*",
17+
"orbax.*",
18+
"opt_einsum.*",
19+
"scipy.*",
20+
"libtpu.*",
21+
"jaxlib.mlir.*",
22+
"yaml",
23+
]
24+
25+
[errors]
26+
missing-attribute = "ignore"

tests/run_all_tests.sh

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ PYTEST_OPTS=
55
RUN_DOCTEST=false
66
RUN_MYPY=false
77
RUN_PYTEST=false
8+
RUN_PYREFLY=false
89
RUN_PYTYPE=false
910
GH_VENV=false
1011

@@ -30,6 +31,9 @@ case $flag in
3031
--only-mypy)
3132
RUN_MYPY=true
3233
;;
34+
--only-pyrefly)
35+
RUN_PYREFLY=true
36+
;;
3337
--use-venv)
3438
GH_VENV=true
3539
;;
@@ -40,12 +44,13 @@ case $flag in
4044
esac
4145
done
4246

43-
# if neither --only-doctest, --only-pytest, --only-pytype, --only-mypy is set, run all tests
44-
if ! $RUN_DOCTEST && ! $RUN_PYTEST && ! $RUN_PYTYPE && ! $RUN_MYPY; then
47+
# if neither --only-doctest, --only-pytest, --only-pytype, --only-mypy, --only-pyrefly is set, run all tests
48+
if ! $RUN_DOCTEST && ! $RUN_PYTEST && ! $RUN_PYTYPE && ! $RUN_MYPY && ! $RUN_PYREFLY; then
4549
RUN_DOCTEST=true
4650
RUN_PYTEST=true
4751
RUN_PYTYPE=true
4852
RUN_MYPY=true
53+
RUN_PYREFLY=true
4954
fi
5055

5156
# Activate cached virtual env for github CI
@@ -58,6 +63,7 @@ echo "PYTEST_OPTS: $PYTEST_OPTS"
5863
echo "RUN_DOCTEST: $RUN_DOCTEST"
5964
echo "RUN_PYTEST: $RUN_PYTEST"
6065
echo "RUN_MYPY: $RUN_MYPY"
66+
echo "RUN_PYREFLY: $RUN_PYREFLY"
6167
echo "RUN_PYTYPE: $RUN_PYTYPE"
6268
echo "GH_VENV: $GH_VENV"
6369
echo "WHICH PYTHON: $(which python)"
@@ -155,5 +161,11 @@ if $RUN_MYPY; then
155161
mypy --config pyproject.toml flax/ --show-error-codes
156162
fi
157163

164+
if $RUN_PYREFLY; then
165+
echo "=== RUNNING PYREFLY ==="
166+
# Type-check using pyrefly.toml (currently scoped to flax/linen/linear.py).
167+
pyrefly check
168+
fi
169+
158170
# Return error code 0 if no real failures happened.
159171
echo "finished all tests."

0 commit comments

Comments
 (0)