diff --git a/jax_privacy/keras_api.py b/jax_privacy/keras_api.py
index cc78eae..b3e67bd 100644
--- a/jax_privacy/keras_api.py
+++ b/jax_privacy/keras_api.py
@@ -25,6 +25,7 @@
 import jax
 import jax.numpy as jnp
 import jax_privacy
+from jax_privacy import batch_selection
 from jax_privacy.accounting import accountants
 from jax_privacy.accounting import analysis
 from jax_privacy.accounting import calibrate
@@ -98,6 +99,10 @@ class DPKerasConfig:
       `gradient_accumulation_steps` operates outside of the jit boundary. The
       default value is None, which means that no microbatching is used, and is
       equivalent to microbatch_size=batch_size.
+    padding_multiple: The multiple to pad batches to when using Poisson
+      sampling, to reduce JIT recompilation. Note: If microbatch_size is
+      specified, then padding_multiple % microbatch_size == 0 should be true
+      for optimal performance.
   """

   epsilon: float
@@ -111,6 +116,7 @@
   rescale_to_unit_norm: bool = True
   microbatch_size: int | None = None
   seed: int | None = None
+  padding_multiple: int = 32

   _accountant = analysis.DpsgdTrainingAccountant(
       dp_accountant_config=accountants.PldAccountantConfig(
@@ -177,6 +183,11 @@ def _validate_params(self) -> None:
           f'Gradient accumulation steps {self.gradient_accumulation_steps} must'
           ' be positive.'
       )
+    if self.microbatch_size is not None and self.padding_multiple % self.microbatch_size != 0:
+      raise ValueError(
+          f'padding_multiple ({self.padding_multiple}) must be divisible by '
+          f'microbatch_size ({self.microbatch_size}) when microbatch_size is specified.'
+      )
     if self.noise_multiplier is not None:
       if self.noise_multiplier <= 0:
         raise ValueError(
@@ -325,7 +336,6 @@ def _create_fit_fn_with_validation(
     The fit function with same signature as original_fit_fn but with
     validation for DP-SGD training.
   """
-
   @functools.wraps(original_fit_fn)
   def fit_fn_with_validation(
       self,
@@ -335,60 +345,130 @@
     _validate_optimizer(self, self._dp_params)  # pylint: disable=protected-access

     fit_signature = inspect.signature(original_fit_fn)
     # batch_size is not set explicitely in the fit() call if the input dataset
     # is already batched. In this case, we assume that the batch sizes are
     # aligned and use the batch size from the DP parameters. We will check that
     # the batch sizes are aligned in the train_step function.
     batch_size = (
         _get_param(fit_signature, 'batch_size', *args, **kwargs)
         or params.batch_size
     )
     # Default values are set according to the Keras documentation.
     epochs = _get_param(fit_signature, 'epochs', *args, **kwargs) or 1
     initial_epoch = (
         _get_param(fit_signature, 'initial_epoch', *args, **kwargs) or 0
     )
     steps_per_epoch = _get_param(
         fit_signature, 'steps_per_epoch', *args, **kwargs
     )

     # Note accessing self._dp_params is safe because it's added in
     # _add_dp_sgd_attributes, but requires disabling pylint because this
     # function is not a method within a class.
     _check_dp_params_aligned_with_fit_args(
         self._dp_params,  # pylint: disable=protected-access
         batch_size,
     )
     performed_optimizer_steps = (
         _get_non_trainable_weight('_optimizer_steps', self).numpy().item()
     )
     optimizer_steps_to_perform = _calculate_optimizer_steps_to_perform_in_fit(
         self._dp_params.train_size,  # pylint: disable=protected-access
         batch_size,
         epochs,
         initial_epoch,
         steps_per_epoch,
     )
     if (
         performed_optimizer_steps + optimizer_steps_to_perform
         > self._dp_params.train_steps  # pylint: disable=protected-access
     ):
       raise RuntimeError(
           'fit() cannot be performed because you will run out of privacy'
           ' budget. Currently, you have already performed'
           f' {performed_optimizer_steps} optimizer training steps and you are'
           f' trying to perform {optimizer_steps_to_perform} more. However, you'
           f' can perform in total only {self._dp_params.train_steps} training'  # pylint: disable=protected-access
           ' steps (optimizer updates). If you fit() the model with current'
           ' parameters, training steps will exceed the maximum number of'
           f' training steps: {performed_optimizer_steps=} +'
           f' {optimizer_steps_to_perform=} ='
           f' {performed_optimizer_steps + optimizer_steps_to_perform} >'
           f' total_train_steps={self._dp_params.train_steps}.'  # pylint: disable=protected-access
       )
+
+    clipped_grad_fn = jax_privacy.clipped_grad(
+        fun=self.compute_loss_and_updates,
+        has_aux=True,
+        return_values=True,
+        l2_clip_norm=self._dp_params.clipping_norm,
+        rescale_to_unit_norm=self._dp_params.rescale_to_unit_norm,
+        normalize_by=self._dp_params.batch_size,
+        batch_argnums=(3, 4, 5),  # corresponding to (x, y, sample_weight)
+        microbatch_size=self._dp_params.microbatch_size,
+    )
+    sampling_probability = (
+        self._dp_params.batch_size / self._dp_params.train_size  # pylint: disable=protected-access
+    )
+    strategy = batch_selection.CyclicPoissonSampling(
+        sampling_prob=sampling_probability,
+        iterations=self._dp_params.train_steps - performed_optimizer_steps,  # pylint: disable=protected-access
+    )
+
+    def fit_with_dp_poisson_sampling(*args, x=None, y=None, **unused_kwargs):
+      # x and y may be passed positionally (as with keras.Model.fit); any
+      # other fit() arguments are ignored on this code path for now.
+      if args:
+        x = args[0]
+      if len(args) > 1 and y is None:
+        y = args[1]
+      if x is None:
+        raise ValueError('Missing training data.')
+      if not hasattr(x, 'shape'):
+        raise ValueError(
+            'DP Poisson sampling requires x to be a JAX or NumPy array'
+            ' for efficient random access.'
+        )
+
+      # Append a dummy example so that padded indices (-1) resolve to it.
+      dummy_x = jnp.zeros_like(x[:1])
+      x = jnp.concatenate([x, dummy_x], axis=0)
+      if y is not None:
+        dummy_y = jnp.zeros_like(y[:1])
+        y = jnp.concatenate([y, dummy_y], axis=0)
+
+      for batch_idx in strategy.batch_iterator(self._dp_params.train_size):  # pylint: disable=protected-access
+        # Pad indices to a multiple of padding_multiple to reduce JIT
+        # recompilations; padded positions are marked with -1.
+        idx = batch_selection.pad_to_multiple_of(
+            batch_idx, self._dp_params.padding_multiple  # pylint: disable=protected-access
+        )
+        is_padding_example = idx == -1
+        batch_x = x[idx]
+        batch_y = y[idx] if y is not None else None
+        batch_data = (batch_x, batch_y, None)
+
+        model_state = (
+            self.trainable_variables,
+            self.non_trainable_variables,
+            self.optimizer_variables,
+            self.metrics_variables,
+        )
+        # TODO: use clipped_grad_fn here and exclude the rows flagged by
+        # is_padding_example from the update.
+        updated_model_state = _dp_train_step(
+            self,
+            model_state,
+            batch_data,
+        )
+        (
+            self.trainable_variables,
+            self.non_trainable_variables,
+            self.optimizer_variables,
+            self.metrics_variables,
+        ) = updated_model_state
+
+      training_history = keras.callbacks.History()
+      training_history.history = {}
+      return training_history
+
+    return fit_with_dp_poisson_sampling(*args, **kwargs)
-    return original_fit_fn(
-        *args,
-        **kwargs,
-    )
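
For context on the new padding_multiple field, here is a minimal sketch of the padding behaviour; it is not part of the patch, uses hypothetical toy sizes, and assumes that pad_to_multiple_of marks padded positions with -1 and returns an index array, as the training loop above implies, with the CyclicPoissonSampling constructor and batch_iterator signatures taken from the patch:

from jax_privacy import batch_selection

train_size = 1_000                 # hypothetical dataset size
sampling_prob = 100 / train_size   # expected batch size of about 100
padding_multiple = 32

strategy = batch_selection.CyclicPoissonSampling(
    sampling_prob=sampling_prob,
    iterations=3,
)
for batch_idx in strategy.batch_iterator(train_size):
  # Poisson sampling yields a different batch size on every step; padding the
  # index array up to a multiple of 32 keeps the set of traced shapes small,
  # so jit-compiled train steps are reused instead of recompiled.
  idx = batch_selection.pad_to_multiple_of(batch_idx, padding_multiple)
  is_padding_example = idx == -1
  print(len(batch_idx), len(idx), int(is_padding_example.sum()))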