Skip to content

Commit 576b2b4

Browse files
Ryan McKennacopybara-github
authored andcommitted
Add lint checks on examples/ and tests/, resolving or ignoring specific errors as as necessary. Also add pydoctest to the set of static analysis checks to github continuous integration.
PiperOrigin-RevId: 860166125
1 parent edf9b3b commit 576b2b4

File tree

6 files changed

+18
-15
lines changed

6 files changed

+18
-15
lines changed

.github/workflows/ci.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,20 @@ jobs:
2525
run: |
2626
python -m pip install --upgrade pip
2727
pip install build wheel
28-
pip install pylint pytype flake8 pylint-exit
28+
pip install pylint pytype flake8 pylint-exit pydocstyle
2929
pip install -e .[dev]
3030
pip install -r docs/requirements.txt
3131
pip install crc32c
3232
- name: Run flake8
3333
run: flake8 jax_privacy tests examples
34+
- name: Run pydocstyle
35+
run: |
36+
pydocstyle --convention=google --add-ignore=D101,D102,D103,D105,D202,D402 jax_privacy/
3437
- name: Run pylint
3538
run: |
36-
pylint --ignore=auditing.py --rcfile=.pylintrc jax_privacy || pylint-exit -efail -wfail -cfail -rfail $?
37-
# pylint --rcfile=.pylintrc examples || pylint-exit -efail -wfail -cfail -rfail $?
38-
# pylint --rcfile=.pylintrc tests -d W0212,C0114 || pylint-exit -efail -wfail -cfail -rfail $?
39+
pylint jax_privacy || pylint-exit -efail -wfail -cfail -rfail $?
40+
pylint examples || pylint-exit -efail -wfail -cfail -rfail $?
41+
pylint tests -d W0212,C0114 || pylint-exit -efail -wfail -cfail -rfail $?
3942
shell: bash
4043
- name: Run pytype
4144
run: |

examples/distributed_noise_generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,10 @@ def run(pytree_like_model_params):
162162
t0 = time.time()
163163
compiled_run = run.lower(model_params).compile()
164164
t1 = time.time()
165-
print('[BandMF] Compilation time: %.3f seconds' % (t1 - t0))
165+
print(f'[BandMF] Compilation time: {t1-t0:.3f} seconds')
166166
state, noisy_grad = jax.block_until_ready(compiled_run(model_params))
167167
t2 = time.time()
168-
print('[BandMF] Per-step run time: %.3f seconds' % ((t2 - t1) / steps))
168+
print(f'[BandMF] Per-step run time: {(t2-t1)/steps:.3f} seconds')
169169

170170
return state, noisy_grad
171171

examples/keras_api_example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from absl import app
2424

2525
os.environ["KERAS_BACKEND"] = "jax"
26-
from jax_privacy import keras_api # pylint: disable=g-import-not-at-top
26+
# pylint: disable=g-import-not-at-top,wrong-import-position
27+
from jax_privacy import keras_api
2728
import keras
2829
from keras import layers
2930
import numpy as np

jax_privacy/matrix_factorization/buffered_toeplitz.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ def check_float64_dtype(blt: 'BufferedToeplitz'):
7979

8080
@dataclasses.dataclass(kw_only=True, frozen=True)
8181
class StreamingMatrixBuilder:
82-
"""Builder to convert a BLT to a StreamingMatrix.
83-
"""
82+
"""Builder to convert a BLT to a StreamingMatrix."""
8483

8584
buf_decay: np.ndarray
8685
output_scale: np.ndarray

tests/clipping_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ def cartesian_product(**kwargs):
3131

3232
PYTREE_STRUCTS = [
3333
jax.ShapeDtypeStruct(shape=(5, 5), dtype=jnp.float32),
34-
dict(
35-
a=jax.ShapeDtypeStruct(shape=(10,), dtype=jnp.float16),
36-
b=jax.ShapeDtypeStruct(shape=(5, 5), dtype=jnp.float32),
37-
c=jax.ShapeDtypeStruct(shape=(), dtype=jnp.bfloat16),
38-
),
34+
{
35+
'a': jax.ShapeDtypeStruct(shape=(10,), dtype=jnp.float16),
36+
'b': jax.ShapeDtypeStruct(shape=(5, 5), dtype=jnp.float32),
37+
'c': jax.ShapeDtypeStruct(shape=(), dtype=jnp.bfloat16),
38+
},
3939
]
4040

4141

tests/experimental/accounting_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
PLD_ACCOUNTANT = functools.partial(
5252
dp_accounting.pld.PLDAccountant, value_discretization_interval=1e-3
5353
)
54-
RDP_ACCOUNTANT = dp_accounting.rdp.RdpAccountant
54+
RDP_ACCOUNTANT = dp_accounting.rdp.RdpAccountant # pylint: disable=invalid-name
5555

5656

5757
def _make_test_case(event_fn, accountant):

0 commit comments

Comments
 (0)