|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -# pylint: disable=todo-style |
17 | 16 | """API for adding DP-SGD to a Keras model. |
18 | 17 |
|
19 | 18 | Example Usage: |
@@ -79,8 +78,7 @@ class DPKerasConfig: |
79 | 78 | noise). You should set this value before training, based only on the |
80 | 79 | privacy guarantees you need to achieve. You should not increase the |
81 | 80 | delta merely because of poor model performance. |
82 | | - clipping_norm: The clipping norm for the gradients. TODO: how to choose |
83 | | - it? |
| 81 | + clipping_norm: The clipping norm for the per-example gradients. |
84 | 82 | batch_size: The batch size for the training. |
85 | 83 | gradient_accumulation_steps: The number of gradient accumulation steps. |
86 | 84 | This is the number of batches to accumulate before adding noise and |
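For orientation, a hypothetical instantiation of `DPKerasConfig` using only the fields documented in this docstring; the concrete values are illustrative placeholders, not recommendations:

```python
# Hypothetical illustration; values are placeholders, not recommendations.
config = DPKerasConfig(
    delta=1e-5,                     # fixed by the privacy target, never tuned for accuracy
    clipping_norm=1.0,              # per-example gradient clipping norm
    batch_size=256,
    gradient_accumulation_steps=4,  # batches accumulated before noising/updating
)
```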
@@ -313,7 +311,7 @@ def _validate_model(model: keras.Model) -> None: |
313 | 311 | raise ValueError(f'Model {model} is not a Keras model.') |
314 | 312 | if keras.config.backend() != 'jax': |
315 | 313 | raise ValueError(f'Model {model} must use the Jax backend.') |
316 | | - # TODO: Add validation that the model does not contain layers |
| 314 | + # TODO: b/415360727 - Add validation that the model does not contain layers |
317 | 315 | # that are not compatible with DP-SGD, e.g. batch norm. |
318 | 316 |
|
319 | 317 |
|
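Since `_validate_model` checks the backend at call time, note that Keras 3 selects its backend before `keras` is imported. A minimal sketch of satisfying the check; the `KERAS_BACKEND` environment variable is standard Keras 3, the rest is illustrative:

```python
# Keras 3 picks the backend from KERAS_BACKEND before `keras` is imported.
import os
os.environ['KERAS_BACKEND'] = 'jax'

import keras
assert keras.config.backend() == 'jax'  # the condition _validate_model enforces
```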
@@ -556,7 +554,7 @@ def _dp_train_step( |
556 | 554 | ) = self.optimizer.stateless_apply( |
557 | 555 | optimizer_variables, grads, trainable_variables |
558 | 556 | ) |
559 | | - # TODO: access it and update it by name. |
| 557 | + # TODO: b/415360727 - access it and update it by name. |
560 | 558 | non_trainable_variables[1] = non_trainable_variables[1] + 1 |
561 | 559 |
|
562 | 560 | logs, metrics_variables = self._update_metrics_variables( # pylint: disable=protected-access |
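The `non_trainable_variables[1] + 1` update relies on positional indexing, which is exactly what the TODO flags as fragile. A hypothetical sketch of the name-based lookup it gestures at, assuming the stateless values line up one-to-one with `self.non_trainable_variables` (which is how Keras 3 passes state into stateless train steps); the helper and the path string are illustrative, not library API:

```python
# Hypothetical helper, not part of the library: update one stateless value
# by the path of its corresponding Keras variable instead of by position.
def _update_variable_by_path(self, values, target_path, fn):
  paths = [v.path for v in self.non_trainable_variables]
  return [fn(val) if path == target_path else val
          for path, val in zip(paths, values)]

# e.g. (assumed path name, for illustration only):
# non_trainable_variables = self._update_variable_by_path(
#     non_trainable_variables, 'dp_model/step_counter', lambda x: x + 1)
```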
@@ -640,7 +638,7 @@ def _noised_clipped_grads( |
640 | 638 | optimizer_variables, |
641 | 639 | metrics_variables, |
642 | 640 | ) = state |
643 | | - # TODO: access it and update it by name. |
| 641 | + # TODO: b/415360727 - access it and update it by name. |
644 | 642 | noise_state = non_trainable_variables[0], () |
645 | 643 | x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) |
646 | 644 |
|
@@ -674,12 +672,10 @@ def _noised_clipped_grads( |
674 | 672 |
|
675 | 673 | noisy_grads, new_noise_state = privatizer.update(clipped_grad, noise_state) |
676 | 674 |
|
677 | | - # TODO: Investigate whether we should return mean or sum here. |
678 | 675 | loss = per_example_aux.values.mean() |
679 | 676 | unscaled_loss = per_example_aux.aux[0].mean() |
680 | 677 | y_pred = per_example_aux.aux[1] |
681 | 678 | non_trainable_variables = [new_noise_state[0]] + non_trainable_variables[1:] |
682 | | - # TODO: Determine the correct way to aggregate metrics. |
683 | 679 | new_metrics = jax.tree.map(lambda x: x.mean(axis=0), per_example_aux.aux[3]) |
684 | 680 |
|
685 | 681 | aux = (unscaled_loss, y_pred, non_trainable_variables, new_metrics) |
|
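The `privatizer.update(clipped_grad, noise_state)` call above is the clip-then-noise core of DP-SGD. A self-contained sketch of that step (a hypothetical stand-in, not the library's actual privatizer), assuming per-example gradients with a leading batch axis and a noise stddev of `noise_multiplier * clipping_norm` per the standard Gaussian mechanism:

```python
import jax
import jax.numpy as jnp

def clip_and_noise(per_example_grads, key, clipping_norm, noise_multiplier):
  """Clips each example's gradient to `clipping_norm`, sums, adds Gaussian noise."""
  def global_norm(g):
    return jnp.sqrt(sum(jnp.sum(x ** 2) for x in jax.tree_util.tree_leaves(g)))

  # Per-example global L2 norms across all gradient leaves.
  norms = jax.vmap(global_norm)(per_example_grads)
  scale = jnp.minimum(1.0, clipping_norm / (norms + 1e-12))
  clipped = jax.tree.map(
      lambda g: g * scale.reshape((-1,) + (1,) * (g.ndim - 1)),
      per_example_grads)
  summed = jax.tree.map(lambda g: jnp.sum(g, axis=0), clipped)

  # One PRNG key per gradient leaf; noise scales with the clipping norm.
  leaves, treedef = jax.tree_util.tree_flatten(summed)
  keys = jax.random.split(key, len(leaves))
  noised = [g + noise_multiplier * clipping_norm * jax.random.normal(k, g.shape)
            for g, k in zip(leaves, keys)]
  return jax.tree_util.tree_unflatten(treedef, noised)
```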