Skip to content

Commit 4a36567

Browse files
committed
Fix pre-commit to run on all files in CI.
1 parent 2b55bd5 commit 4a36567

File tree

5 files changed

+7
-4
lines changed

5 files changed

+7
-4
lines changed

.github/workflows/ci-build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
with:
4040
path: ~/.cache/pre-commit
4141
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
42-
- run: pre-commit run --show-diff-on-failure --color=always
42+
- run: pre-commit run --show-diff-on-failure --color=always --all-files
4343

4444
build:
4545
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"

jax/_src/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
11731173
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
11741174
batch_xs, batch_devs, batch_shardings, batch_cs)
11751175
else:
1176-
copy_outs = xc.batched_copy_array_to_devices_with_sharding( # type: ignore
1176+
copy_outs = xc.batched_copy_array_to_devices_with_sharding( # pytype: disable=missing-parameter
11771177
batch_xs, batch_devs, batch_shardings)
11781178
for i, copy_out in safe_zip(batch_indices, copy_outs):
11791179
assert results[i] is None

jax/_src/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import os
2323
import sys
2424
import threading
25-
from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast, TYPE_CHECKING
25+
from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast
2626

2727
from jax._src import lib
2828
from jax._src.lib import guard_lib
@@ -371,7 +371,7 @@ class _Unset: pass
371371

372372
_thread_local_state = threading.local()
373373

374-
class State(Generic[_T]):
374+
class State(Generic[_T]): # type: ignore[no-redef]
375375

376376
__slots__ = (
377377
'_name', '_value', '_update_thread_local_hook', '_update_global_hook',

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from jaxlib.mlir.dialects import nvvm
3030
from .utils import c, memref_ptr, single_thread_predicate
3131

32+
# mypy: ignore-errors
33+
3234

3335
MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]]
3436

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dataclasses
2020
import functools
2121
import math
22+
from collections.abc import Callable
2223
from typing import Iterable, Sequence, TypeVar
2324

2425
import jax

0 commit comments

Comments
 (0)