Skip to content

Commit 2ef1454

Browse files
Ryan McKenna and copybara-github
authored and committed
Fix bad todo style in keras_api.py, and update pylintrc file to catch inconsistent quotes issues.
PiperOrigin-RevId: 861299909
1 parent d41d335 commit 2ef1454

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

.pylintrc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ unsafe-load-any-extension=no
3535
extension-pkg-whitelist=
3636

3737

38+
[STRING]
39+
40+
# This flag controls whether inconsistent-quotes generates a warning when the
41+
# character used as a quote delimiter is used inconsistently within a module.
42+
check-quote-consistency=yes
43+
3844
[MESSAGES CONTROL]
3945

4046
# Only show warnings with the listed confidence levels. Leave empty to show

jax_privacy/keras_api.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
# pylint: disable=todo-style
1716
"""API for adding DP-SGD to a Keras model.
1817
1918
Example Usage:
@@ -79,8 +78,7 @@ class DPKerasConfig:
7978
noise). You should set this value before training and only based on the
8079
privacy guarantees you have to achieve. You should not increase the
8180
delta only because of poor model performance.
82-
clipping_norm: The clipping norm for the gradients. TODO: how to choose
83-
it?
81+
clipping_norm: The clipping norm for the per-example gradients.
8482
batch_size: The batch size for the training.
8583
gradient_accumulation_steps: The number of gradient accumulation steps.
8684
This is the number of batches to accumulate before adding noise and
@@ -313,7 +311,7 @@ def _validate_model(model: keras.Model) -> None:
313311
raise ValueError(f'Model {model} is not a Keras model.')
314312
if keras.config.backend() != 'jax':
315313
raise ValueError(f'Model {model} must use Jax backend.')
316-
# TODO: Add validation that the model does not contain layers
314+
# TODO: b/415360727 - Add validation that the model does not contain layers
317315
# that are not compatible with DP-SGD, e.g. batch norm.
318316

319317

@@ -556,7 +554,7 @@ def _dp_train_step(
556554
) = self.optimizer.stateless_apply(
557555
optimizer_variables, grads, trainable_variables
558556
)
559-
# TODO: access it and update it by name.
557+
# TODO: b/415360727 - access it and update it by name.
560558
non_trainable_variables[1] = non_trainable_variables[1] + 1
561559

562560
logs, metrics_variables = self._update_metrics_variables( # pylint: disable=protected-access
@@ -640,7 +638,7 @@ def _noised_clipped_grads(
640638
optimizer_variables,
641639
metrics_variables,
642640
) = state
643-
# TODO: access it and update it by name.
641+
# TODO: b/415360727 - access it and update it by name.
644642
noise_state = non_trainable_variables[0], ()
645643
x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
646644

@@ -674,12 +672,10 @@ def _noised_clipped_grads(
674672

675673
noisy_grads, new_noise_state = privatizer.update(clipped_grad, noise_state)
676674

677-
# TODO: Investigate whether we should return mean or sum here.
678675
loss = per_example_aux.values.mean()
679676
unscaled_loss = per_example_aux.aux[0].mean()
680677
y_pred = per_example_aux.aux[1]
681678
non_trainable_variables = [new_noise_state[0]] + non_trainable_variables[1:]
682-
# TODO: Determine the correct way to aggregate metrics.
683679
new_metrics = jax.tree.map(lambda x: x.mean(axis=0), per_example_aux.aux[3])
684680

685681
aux = (unscaled_loss, y_pred, non_trainable_variables, new_metrics)

0 commit comments

Comments
 (0)