130 changes: 53 additions & 77 deletions jax_privacy/keras_api.py
@@ -25,6 +25,7 @@
import jax
import jax.numpy as jnp
import jax_privacy
from jax_privacy import batch_selection
from jax_privacy.accounting import accountants
from jax_privacy.accounting import analysis
from jax_privacy.accounting import calibrate
@@ -98,6 +99,11 @@ class DPKerasConfig:
`gradient_accumulation_steps` operates outside of the jit boundary.
The default value is None, which means that no microbatching is used,
and is equivalent to microbatch_size=batch_size.
sampling_prob: The probability of sampling an example in each iteration.
If set, uses Poisson sampling instead of deterministic batching. The
dataset must support random access (e.g., in-memory arrays).
padding_multiple: The multiple to pad sampled batches to when using
Poisson sampling. Padding keeps the number of distinct batch shapes
small, which reduces JIT recompilation of the train step.
"""

epsilon: float
@@ -111,6 +117,8 @@ class DPKerasConfig:
rescale_to_unit_norm: bool = True
microbatch_size: int | None = None
seed: int | None = None
sampling_prob: float | None = None
padding_multiple: int = 32
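# Example (illustrative sketch, not part of this diff): configuring Poisson
# sampling so that the expected batch size matches the configured batch size.
# Only the two new fields are shown; the remaining required DPKerasConfig
# fields are elided.
#
#   params = DPKerasConfig(
#       ...,  # epsilon, batch_size=256, train_size=50_000, etc.
#       sampling_prob=256 / 50_000,  # expected batch size of 256
#       padding_multiple=32,  # pad sampled batches to a multiple of 32
#   )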

_accountant = analysis.DpsgdTrainingAccountant(
dp_accountant_config=accountants.PldAccountantConfig(
@@ -335,78 +343,59 @@ def fit_fn_with_validation(
_validate_optimizer(self, self._dp_params) # pylint: disable=protected-access
fit_signature = inspect.signature(original_fit_fn)

# batch_size is not set explicitly in the fit() call if the input dataset
# is already batched. In this case, we assume that the batch sizes are
# aligned and use the batch size from the DP parameters. We will check that
# the batch sizes are aligned in the train_step function.
batch_size = (
_get_param(fit_signature, 'batch_size', *args, **kwargs)
or params.batch_size
)
# Default values are set according to the Keras documentation.
epochs = _get_param(fit_signature, 'epochs', *args, **kwargs) or 1
initial_epoch = (
_get_param(fit_signature, 'initial_epoch', *args, **kwargs) or 0
)
steps_per_epoch = _get_param(
fit_signature, 'steps_per_epoch', *args, **kwargs
)

# Note accessing self._dp_params is safe because it's added in
# _add_dp_sgd_attributes, but requires disabling pylint because this
# function is not a method within a class.
_check_dp_params_aligned_with_fit_args(
self._dp_params, # pylint: disable=protected-access
batch_size,
)
# For DP training with Poisson sampling, we don't use the standard batching.
# We require x, y to be arrays for random access.
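# Under Poisson sampling each example is included independently with
# probability sampling_prob, so the realized batch size is random rather
# than fixed; the padding below keeps the set of batch shapes seen by the
# jitted train step small.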

performed_optimizer_steps = (
_get_non_trainable_weight('_optimizer_steps', self).numpy().item()
)
optimizer_steps_to_perform = _calculate_optimizer_steps_to_perform_in_fit(
self._dp_params.train_size, # pylint: disable=protected-access
batch_size,
epochs,
initial_epoch,
steps_per_epoch,
)
if (
performed_optimizer_steps + optimizer_steps_to_perform
> self._dp_params.train_steps # pylint: disable=protected-access
):
remaining_steps = self._dp_params.train_steps - performed_optimizer_steps # pylint: disable=protected-access
if remaining_steps <= 0:
raise RuntimeError(
'fit() cannot be performed: the privacy budget would be exhausted.'
f' You have already performed {performed_optimizer_steps} optimizer'
f' training steps and are trying to perform'
f' {optimizer_steps_to_perform} more, but only'
f' {self._dp_params.train_steps} training steps (optimizer updates)'  # pylint: disable=protected-access
' are allowed in total:'
f' {performed_optimizer_steps=} + {optimizer_steps_to_perform=} ='
f' {performed_optimizer_steps + optimizer_steps_to_perform} >'
f' total_train_steps={self._dp_params.train_steps}.'  # pylint: disable=protected-access
)
return original_fit_fn(
*args,
**kwargs,
)

return fit_fn_with_validation
def fit_with_dp_poisson_sampling(
x=None, y=None, epochs=1, verbose='auto', callbacks=None, **kwargs
):
# epochs, verbose and callbacks are accepted for fit() API compatibility;
# they are currently unused in this implementation.
if x is None or not hasattr(x, 'shape'):
raise ValueError(
'DP training requires x to be a JAX or NumPy array for efficient'
' random access during Poisson sampling.'
)
# Use the provided sampling_prob, or default to batch_size / train_size so
# that the expected batch size equals the configured batch size (e.g.
# batch_size=256 and train_size=50_000 give sampling_prob = 0.00512).
sampling_probability = (
self._dp_params.sampling_prob  # pylint: disable=protected-access
or (self._dp_params.batch_size / self._dp_params.train_size)  # pylint: disable=protected-access
)
plan = batch_selection.CyclicPoissonSampling(
sampling_prob=sampling_probability,
iterations=remaining_steps,
)
for batch_idx in plan.batch_iterator(self._dp_params.train_size):  # pylint: disable=protected-access
# Pad the sampled indices to a multiple of padding_multiple so the jitted
# train step sees only a few distinct batch shapes (fewer recompilations).
idx = batch_selection.pad_to_multiple_of(
batch_idx, self._dp_params.padding_multiple  # pylint: disable=protected-access
)
batch_x = x[idx]
batch_y = y[idx] if y is not None else None
# Assemble the (x, y, sample_weight) triple consumed by train_step; no
# sample weights are used.
batch_data = (batch_x, batch_y, None)

def _check_dp_params_aligned_with_fit_args(
dp_params: DPKerasConfig,
batch_size: int,
) -> None:
"""Checks that the DP parameters are aligned with the fit() arguments."""
if dp_params.batch_size != batch_size:
raise ValueError(
'The batch size in the DP parameters must equal the batch size passed'
f' to fit(): {dp_params.batch_size=} != {batch_size=}.'
)
# Gather the model state consumed by the functional train_step.
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
optimizer_variables = self.optimizer.variables
metrics_variables = self.metrics_variables
model_state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
)

step_logs, updated_model_state = self.train_step(model_state, batch_data)

# Keras exposes these variable collections as read-only properties, so
# write the updated values back via Variable.assign instead of rebinding
# the attributes.
for variables, values in zip(model_state, updated_model_state):
for variable, value in zip(variables, values):
variable.assign(value)

# Per-step logs (step_logs) are not yet aggregated into the history;
# return an empty History object for now.
training_history = keras.callbacks.History()
training_history.history = {}
return training_history
return fit_with_dp_poisson_sampling

return fit_fn_with_validation
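# Illustrative end-to-end sketch (not part of this diff). The public helper
# that installs the DP fit function is outside this excerpt, so `make_private`
# below is a hypothetical placeholder name; only DPKerasConfig and fit()
# mirror the code above.
#
#   params = DPKerasConfig(..., sampling_prob=256 / 50_000, padding_multiple=32)
#   dp_model = make_private(model, params)  # hypothetical entry point
#   history = dp_model.fit(x_train, y_train)  # x/y must be NumPy/JAX arrays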


_XType = chex.ArrayTree
@@ -470,7 +459,7 @@ def _dp_train_step(

dp_batch_size = self._dp_params.batch_size # pylint: disable=protected-access
actual_batch_size = jax.tree_util.tree_leaves(x)[0].shape[0]
if dp_batch_size != actual_batch_size:
if actual_batch_size < dp_batch_size:
# it is ok to throw an exception even though we are in a jit function
# because the check is based on the static values, i.e. they won't
# change between invocations, and if the condition is violated, it will
@@ -682,19 +671,6 @@ def _get_non_trainable_weight(
return next(w for w in model.non_trainable_weights if w.name == weight_name)


def _calculate_optimizer_steps_to_perform_in_fit(
train_size: int,
batch_size: int,
epochs: int,
initial_epoch: int,
steps_per_epoch: int,
) -> int:
"""Returns the number of optimizer steps that will be performed by fit."""
epochs_to_perform = epochs - initial_epoch
steps_per_epoch = steps_per_epoch or (train_size // batch_size)
return steps_per_epoch * epochs_to_perform
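# E.g. train_size=50_000, batch_size=500, epochs=3, initial_epoch=0 and
# steps_per_epoch=None gives (50_000 // 500) * (3 - 0) = 300 optimizer steps.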


def _get_random_int64() -> np.int64:
int64_info = np.iinfo(np.int64)
return np.random.randint(