bench: refactor keras_api.py to use Poisson Sampling #89
base: main
Changes from all commits
9a3017e
2d9f2d4
d7e3aea
be8307d
a461ad2
@@ -25,6 +25,7 @@
import jax
import jax.numpy as jnp
import jax_privacy
from jax_privacy import batch_selection
from jax_privacy.accounting import accountants
from jax_privacy.accounting import analysis
from jax_privacy.accounting import calibrate
@@ -98,6 +99,10 @@ class DPKerasConfig:
      `gradient_accumulation_steps` operates outside of the jit boundary.
      The default value is None, which means that no microbatching is used,
      and is equivalent to microbatch_size=batch_size.
    padding_multiple: The multiple to pad batches to when using Poisson
      sampling, to reduce JIT recompilation. Note: If microbatch_size is
      specified, then padding_multiple % microbatch_size == 0 should be true
      for optimal performance.
  """

  epsilon: float
@@ -111,6 +116,7 @@ class DPKerasConfig:
  rescale_to_unit_norm: bool = True
  microbatch_size: int | None = None
  seed: int | None = None
  padding_multiple: int = 32

  _accountant = analysis.DpsgdTrainingAccountant(
      dp_accountant_config=accountants.PldAccountantConfig(
@@ -177,6 +183,11 @@ def _validate_params(self) -> None:
          f'Gradient accumulation steps {self.gradient_accumulation_steps} must'
          ' be positive.'
      )
    if self.microbatch_size is not None and self.padding_multiple % self.microbatch_size != 0:
      raise ValueError(
          f'padding_multiple ({self.padding_multiple}) must be divisible by '
          f'microbatch_size ({self.microbatch_size}) when microbatch_size is specified.'
      )
    if self.noise_multiplier is not None:
      if self.noise_multiplier <= 0:
        raise ValueError(
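For intuition, here is a small illustrative sketch (not part of the PR) of what the new padding_multiple knob and its divisibility check buy. The helper below is only a stand-in for batch_selection.pad_to_multiple_of, and the -1 fill value is an assumption based on the `idx == -1` check later in this diff: padding every Poisson-sampled batch up to a multiple of padding_multiple keeps the set of batch shapes small, so JIT recompiles rarely, and requiring padding_multiple % microbatch_size == 0 lets the padded batch split evenly into microbatches.

import numpy as np

def pad_indices(batch_idx: np.ndarray, padding_multiple: int) -> np.ndarray:
  """Pads a Poisson-sampled index batch with -1 up to a multiple.

  Illustrative stand-in for batch_selection.pad_to_multiple_of; -1 marks
  padding examples, mirroring the `idx == -1` check in the PR.
  """
  target = -(-len(batch_idx) // padding_multiple) * padding_multiple  # ceil
  return np.pad(batch_idx, (0, target - len(batch_idx)), constant_values=-1)

# A Poisson batch of 53 real examples is padded to 64 (= 2 * 32). With
# padding_multiple=32 and microbatch_size=8 the padded batch splits into
# exactly 8 microbatches; 32 % 8 == 0 is what the new validation enforces.
padded = pad_indices(np.arange(53), padding_multiple=32)
assert len(padded) == 64 and (padded[53:] == -1).all()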
@@ -325,7 +336,6 @@ def _create_fit_fn_with_validation(
    The fit function with same signature as original_fit_fn but with validation
    for DP-SGD training.
  """

  @functools.wraps(original_fit_fn)
  def fit_fn_with_validation(
Collaborator: If we make Poisson sampling the default (which requires in-memory arrays as input), we have to make sure that all usages of this API provide the data in the expected format. Examples that provide data in another format either have to be modified, or Poisson sampling should be optional. Do you know what format examples of this API use for the data?
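Purely as an illustration of the in-memory requirement raised above (not from the PR; the dataset and shapes are made up): an example that feeds a batched tf.data.Dataset would need to be materialized into flat arrays before index-based Poisson sampling can address individual examples, roughly like this:

import numpy as np
import tensorflow as tf  # only to illustrate converting a tf.data input

# Hypothetical batched dataset, as many Keras examples provide it.
ds = tf.data.Dataset.from_tensor_slices(
    (np.random.rand(100, 8).astype("float32"), np.random.randint(0, 2, 100))
).batch(32)

# Materialize into flat in-memory arrays so x[idx] / y[idx] random access works.
x_batches, y_batches = zip(*ds.as_numpy_iterator())
x = np.concatenate(x_batches, axis=0)
y = np.concatenate(y_batches, axis=0)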
      self,
@@ -335,60 +345,130 @@ fit_fn_with_validation(
    _validate_optimizer(self, self._dp_params)  # pylint: disable=protected-access
    fit_signature = inspect.signature(original_fit_fn)

    # batch_size is not set explicitely in the fit() call if the input dataset
    # is already batched. In this case, we assume that the batch sizes are
    # aligned and use the batch size from the DP parameters. We will check that
    # the batch sizes are aligned in the train_step function.
Collaborator: Can you revert these formatting changes?
    batch_size = (
        _get_param(fit_signature, 'batch_size', *args, **kwargs)
        or params.batch_size
    )
    # Default values are set according to the Keras documentation.
    epochs = _get_param(fit_signature, 'epochs', *args, **kwargs) or 1
    initial_epoch = (
        _get_param(fit_signature, 'initial_epoch', *args, **kwargs) or 0
    )
    steps_per_epoch = _get_param(
        fit_signature, 'steps_per_epoch', *args, **kwargs
    )

    # Note accessing self._dp_params is safe because it's added in
    # _add_dp_sgd_attributes, but requires disabling pylint because this
    # function is not a method within a class.
    _check_dp_params_aligned_with_fit_args(
        self._dp_params,  # pylint: disable=protected-access
        batch_size,
    )

    performed_optimizer_steps = (
        _get_non_trainable_weight('_optimizer_steps', self).numpy().item()
    )
    optimizer_steps_to_perform = _calculate_optimizer_steps_to_perform_in_fit(
        self._dp_params.train_size,  # pylint: disable=protected-access
        batch_size,
        epochs,
        initial_epoch,
        steps_per_epoch,
    )
    if (
        performed_optimizer_steps + optimizer_steps_to_perform
        > self._dp_params.train_steps  # pylint: disable=protected-access
    ):
      raise RuntimeError(
          'fit() cannot be performed because you will run out of privacy'
          ' budget. Currently, you have already performed'
          f' {performed_optimizer_steps} optimizer training steps and you are'
          f' trying to perform {optimizer_steps_to_perform} more. However, you'
          f' can perform in total only {self._dp_params.train_steps} training'  # pylint: disable=protected-access
          ' steps (optimizer updates). If you fit() the model with current'
          ' parameters, training steps will exceed the maximum number of'
          f' training steps: {performed_optimizer_steps=} +'
          f' {optimizer_steps_to_perform=} ='
          f' {performed_optimizer_steps + optimizer_steps_to_perform} >'
          f' total_train_steps={self._dp_params.train_steps}.'  # pylint: disable=protected-access
      )

    clipped_grad_fn = jax_privacy.clipped_grad(
        fun=self.compute_loss_and_updates,
        has_aux=True,
        return_values=True,
        l2_clip_norm=self._dp_params.clipping_norm,
        rescale_to_unit_norm=self._dp_params.rescale_to_unit_norm,
        normalize_by=self._dp_params.batch_size,
        batch_argnums=(3, 4, 5),  # corresponding to (x, y, sample_weight)
        microbatch_size=self._dp_params.microbatch_size,
    )
    sampling_probability = self._dp_params.batch_size / self._dp_params.train_size  # pylint: disable=protected-access

    strategy = batch_selection.CyclicPoissonSampling(  # pylint: disable=protected-access
        sampling_prob=sampling_probability,
        iterations=self._dp_params.train_steps - performed_optimizer_steps,
    )

    def fit_with_dp_poisson_sampling(*args, x=None, y=None, is_padding_example):
Collaborator: is_padding_example is not used, either here or below. It should be wired through to clipped_grad_fn.

Collaborator: This new function fit_with_dp_poisson_sampling and the associated logic for Poisson sampling appear to be on the right track, but contain several issues and seem to be unused. The function is defined but never called; the original original_fit_fn is called on L507, so this new logic is currently dead code. There is a return fit_with_dp_poisson_sampling statement (L505) inside the for loop, which would terminate the loop after the first iteration and return the function object itself, which is likely not the intended behavior. The is_padding_example parameter in the function signature is shadowed by the assignment on L480. As noted by the TODO on L462, clipped_grad_fn is not used. Note: line numbers might be off.

Contributor (author): Yeah, this is a WIP for now. Let me go ahead and draft this!
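A minimal sketch (not the PR's implementation) of one way the padding flag could be wired through, as the comment above suggests. It assumes the third batched argument of clipped_grad_fn is treated as a per-example sample weight, which is what batch_argnums=(3, 4, 5) and the (x, y, sample_weight) comment indicate; the helper name is made up for illustration.

import jax.numpy as jnp

def build_batch_with_padding_mask(x, y, idx):
  """Hypothetical helper: turn padded indices into (x, y, sample_weight).

  Rows where idx == -1 are padding; giving them zero sample weight means they
  contribute nothing when clipped_grad_fn averages per-example gradients
  (sample_weight is batch argnum 5 in the clipped_grad call above).
  """
  sample_weight = jnp.where(idx == -1, 0.0, 1.0)
  batch_x = x[idx]
  batch_y = y[idx] if y is not None else None
  return batch_x, batch_y, sample_weight

# Example: two real rows plus two padding slots.
x = jnp.arange(12.0).reshape(4, 3)
y = jnp.array([0, 1, 0, 1])
bx, by, sw = build_batch_with_padding_mask(x, y, jnp.array([2, 0, -1, -1]))
assert bx.shape == (4, 3)
assert (sw == jnp.array([1.0, 1.0, 0.0, 0.0])).all()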
      if not args:
        raise ValueError("Missing training data.")

      x = args[0]

      # TO-DO: add clipped_grad_fn usage here

      if not hasattr(x, "shape"):
        raise ValueError(
            "DP Poisson sampling requires x to be a JAX or NumPy array "
            "for efficient random access."
        )
      # Append a dummy example for padding
      dummy_x = jnp.zeros_like(x[:1])
      x = jnp.concatenate([x, dummy_x], axis=0)
      if y is not None:
        dummy_y = jnp.zeros_like(y[:1])
        y = jnp.concatenate([y, dummy_y], axis=0)

      for batch_idx in strategy.batch_iterator(self._dp_params.train_size):  # pylint: disable=protected-access

        # Pad indices to reduce JIT recompilations
        idx = batch_selection.pad_to_multiple_of(batch_idx, self._dp_params.padding_multiple)
        is_padding_example = idx == -1
        batch_x = x[idx]
        batch_y = y[idx] if y is not None else None
        batch_data = (batch_x, batch_y, None)

        trainable_variables = self.trainable_variables
        non_trainable_variables = self.non_trainable_variables
        optimizer_variables = self.optimizer_variables
        metrics_variables = self.metrics_variables
        model_state = (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables)

        updated_model_state = _dp_train_step(
Collaborator: Maybe I'm missing something, but I don't see where this function is defined.
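For reference, a hypothetical skeleton of the contract the loop above appears to assume for _dp_train_step (the function is not defined in this PR). The state packing mirrors the call site, the gradients are a zero placeholder where a real step would use clipped_grad_fn plus calibrated Gaussian noise, and the plain SGD update is only a stand-in for the Keras optimizer state update.

import jax
import jax.numpy as jnp

def _dp_train_step(model, model_state, batch_data, learning_rate=1e-3):
  """Hypothetical skeleton; mirrors the call site above, not part of the PR."""
  trainable_vars, non_trainable_vars, optimizer_vars, metrics_vars = model_state
  x, y, sample_weight = batch_data
  del model, x, y, sample_weight  # unused in this placeholder

  # Placeholder: a real step would call clipped_grad_fn on the batch to get
  # per-example-clipped gradients and add noise scaled by
  # noise_multiplier * clipping_norm before this point.
  grads = jax.tree_util.tree_map(jnp.zeros_like, trainable_vars)

  # Stand-in parameter update; the actual code would update the optimizer
  # and metrics variables as well and return them in the same order.
  trainable_vars = jax.tree_util.tree_map(
      lambda v, g: v - learning_rate * g, trainable_vars, grads
  )
  return (trainable_vars, non_trainable_vars, optimizer_vars, metrics_vars)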
            self,
            model_state,
            batch_data,
        )
        (
            self.trainable_variables,
            self.non_trainable_variables,
            self.optimizer_variables,
            self.metrics_variables,
        ) = updated_model_state

      training_history = keras.callbacks.History()
      training_history.history = {}
      return fit_with_dp_poisson_sampling

    return original_fit_fn(
        self,
        *args,
        **kwargs,
    )
nit: could you write "must"? Below you throw an exception if it is not true.