166 changes: 123 additions & 43 deletions jax_privacy/keras_api.py
@@ -25,6 +25,7 @@
import jax
import jax.numpy as jnp
import jax_privacy
from jax_privacy import batch_selection
from jax_privacy.accounting import accountants
from jax_privacy.accounting import analysis
from jax_privacy.accounting import calibrate
@@ -98,6 +99,10 @@ class DPKerasConfig:
`gradient_accumulation_steps` operates outside of the jit boundary.
The default value is None, which means that no microbatching is used,
and is equivalent to microbatch_size=batch_size.
padding_multiple: The multiple to pad batches to when using Poisson
sampling, to reduce JIT recompilation. Note: If microbatch_size is
specified, then padding_multiple % microbatch_size == 0 should be true
for optimal performance.

Review comment (Collaborator): nit: could you write "must", as below you throw an exception if it is not true.
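To make the constraint concrete, a small sketch with hypothetical values, mirroring the check added in _validate_params below:

```python
# Hypothetical values; padding_multiple is divisible by microbatch_size, so a
# padded batch of 48 rows splits into exactly 3 microbatches of 16.
microbatch_size = 16
padding_multiple = 48
if padding_multiple % microbatch_size != 0:
    raise ValueError('padding_multiple must be divisible by microbatch_size.')

# padding_multiple=32 with microbatch_size=12 would fail this check, matching
# the ValueError added to DPKerasConfig._validate_params in this change.
```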
"""

epsilon: float
@@ -111,6 +116,7 @@ class DPKerasConfig:
rescale_to_unit_norm: bool = True
microbatch_size: int | None = None
seed: int | None = None
padding_multiple: int = 32

_accountant = analysis.DpsgdTrainingAccountant(
dp_accountant_config=accountants.PldAccountantConfig(
@@ -177,6 +183,11 @@ def _validate_params(self) -> None:
f'Gradient accumulation steps {self.gradient_accumulation_steps} must'
' be positive.'
)
if self.microbatch_size is not None and self.padding_multiple % self.microbatch_size != 0:
raise ValueError(
f'padding_multiple ({self.padding_multiple}) must be divisible by '
f'microbatch_size ({self.microbatch_size}) when microbatch_size is specified.'
)
if self.noise_multiplier is not None:
if self.noise_multiplier <= 0:
raise ValueError(
@@ -325,7 +336,6 @@ def _create_fit_fn_with_validation(
The fit function with same signature as original_fit_fn but with validation
for DP-SGD training.
"""

@functools.wraps(original_fit_fn)
def fit_fn_with_validation(
Review comment (Collaborator): If we make Poisson sampling the default (which requires in-memory arrays as input), we have to make sure that all usages of this API provide the data in the expected format. Examples that provide data in another format either have to be modified, or Poisson sampling should be optional. Do you know what format examples of this API use for the data?
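One possible direction, sketched with a hypothetical helper and flag that are not part of this PR: detect whether the data supports cheap random indexing and fall back to the existing fit() path otherwise.

```python
def _supports_random_access(x) -> bool:
    # In-memory NumPy/JAX arrays expose a shape and integer indexing;
    # tf.data.Datasets and Python generators generally do not.
    return hasattr(x, 'shape') and hasattr(x, '__getitem__')

def fit(self, x=None, *args, use_poisson_sampling=True, **kwargs):
    if use_poisson_sampling and not _supports_random_access(x):
        # Data cannot be Poisson-sampled efficiently; keep the original path.
        use_poisson_sampling = False
    ...
```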

self,
@@ -335,60 +345,130 @@ def fit_fn_with_validation(
_validate_optimizer(self, self._dp_params) # pylint: disable=protected-access
fit_signature = inspect.signature(original_fit_fn)

# batch_size is not set explicitely in the fit() call if the input dataset
# is already batched. In this case, we assume that the batch sizes are
# aligned and use the batch size from the DP parameters. We will check that
# the batch sizes are aligned in the train_step function.

Review comment (Collaborator): Can you revert these formatting changes?

batch_size = (
_get_param(fit_signature, 'batch_size', *args, **kwargs)
or params.batch_size
)
# Default values are set according to the Keras documentation.
epochs = _get_param(fit_signature, 'epochs', *args, **kwargs) or 1
initial_epoch = (
_get_param(fit_signature, 'initial_epoch', *args, **kwargs) or 0
)
steps_per_epoch = _get_param(
fit_signature, 'steps_per_epoch', *args, **kwargs
)

# Note accessing self._dp_params is safe because it's added in
# _add_dp_sgd_attributes, but requires disabling pylint because this
# function is not a method within a class.
_check_dp_params_aligned_with_fit_args(
self._dp_params, # pylint: disable=protected-access
batch_size,
)

performed_optimizer_steps = (
_get_non_trainable_weight('_optimizer_steps', self).numpy().item()
)
optimizer_steps_to_perform = _calculate_optimizer_steps_to_perform_in_fit(
self._dp_params.train_size, # pylint: disable=protected-access
batch_size,
epochs,
initial_epoch,
steps_per_epoch,
)
if (
performed_optimizer_steps + optimizer_steps_to_perform
> self._dp_params.train_steps # pylint: disable=protected-access
):
raise RuntimeError(
'fit() cannot be performed because you will run out of privacy'
' budget. Currently, you have already performed'
f' {performed_optimizer_steps} optimizer training steps and you are'
f' trying to perform {optimizer_steps_to_perform} more. However, you'
f' can perform in total only {self._dp_params.train_steps} training' # pylint: disable=protected-access
' steps (optimizer updates). If you fit() the model with current'
' parameters, training steps will exceed the maximum number of'
f' training steps: {performed_optimizer_steps=} +'
f' {optimizer_steps_to_perform=} ='
f' {performed_optimizer_steps + optimizer_steps_to_perform} >'
f' total_train_steps={self._dp_params.train_steps}.' # pylint: disable=protected-access
)

clipped_grad_fn = jax_privacy.clipped_grad(
fun=self.compute_loss_and_updates,
has_aux=True,
return_values=True,
l2_clip_norm=self._dp_params.clipping_norm,
rescale_to_unit_norm=self._dp_params.rescale_to_unit_norm,
normalize_by=self._dp_params.batch_size,
batch_argnums=(3, 4, 5), # corresponding to (x, y, sample_weight)
microbatch_size=self._dp_params.microbatch_size,
)
sampling_probability = self._dp_params.batch_size / self._dp_params.train_size # pylint: disable=protected-access

strategy = batch_selection.CyclicPoissonSampling( # pylint: disable=protected-access
sampling_prob=sampling_probability,
iterations=self._dp_params.train_steps - performed_optimizer_steps,
)

def fit_with_dp_poisson_sampling(*args, x=None, y=None, is_padding_example):
Review comment (Collaborator): is_padding_example is not used, either here or below. It should be wired through to clipped_grad_fn.
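A possible way to wire it through, sketched under the assumption that a zero per-example sample weight removes a padded row's contribution; the three leading model-state arguments are placeholders for whatever compute_loss_and_updates actually takes (only positions 3, 4, 5 — x, y, sample_weight per batch_argnums above — are taken from the diff):

```python
# Inside the batch loop, after is_padding_example is computed:
sample_weight = jnp.where(is_padding_example, 0.0, 1.0)

out = clipped_grad_fn(
    trainable_variables,      # placeholder non-batch arguments
    non_trainable_variables,
    metrics_variables,
    batch_x,                  # batch_argnums position 3
    batch_y,                  # position 4
    sample_weight,            # position 5: zero for padded rows
)
```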

Review comment (Collaborator): This new function fit_with_dp_poisson_sampling and the associated logic for Poisson sampling appear to be on the right track, but contain several issues and seem to be unused.

- The function is defined but never called. The original original_fit_fn is called on L507, so this new logic is currently dead code.
- There's a return fit_with_dp_poisson_sampling statement (L505) inside the for loop. This would cause the loop to terminate after the first iteration and return the function object itself, which is likely not the intended behavior.
- The is_padding_example parameter in the function signature is shadowed by the assignment on L480.
- As noted by the TODO on L462, the clipped_grad_fn is not used.

Note: line numbers might be off.
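For concreteness, a rough sketch of the control flow this comment seems to point toward (not the PR's implementation; _run_one_dp_step is a hypothetical per-batch helper):

```python
def fit_with_dp_poisson_sampling(self, x=None, y=None, **kwargs):
    # ... pad x/y and build the sampling strategy as in the diff ...
    for batch_idx in strategy.batch_iterator(self._dp_params.train_size):
        _run_one_dp_step(self, batch_idx)   # hypothetical helper
    history = keras.callbacks.History()
    history.history = {}
    return history                          # returned after the loop, not inside it

# ...and fit_fn_with_validation would call it rather than return the function:
# return fit_with_dp_poisson_sampling(self, x=x, y=y, **kwargs)
```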

Review comment (Contributor, author): Yeah, this is a WIP for now. Let me go ahead and draft this!

if not args:
raise ValueError("Missing training data.")

x = args[0]

# TO-DO: add clipped_grad_fn usage here

if not hasattr(x, "shape"):
raise ValueError(
"DP Poisson sampling requires x to be a JAX or NumPy array "
"for efficient random access."
)
# Append a dummy example for padding
dummy_x = jnp.zeros_like(x[:1])
x = jnp.concatenate([x, dummy_x], axis=0)
if y is not None:
dummy_y = jnp.zeros_like(y[:1])
y = jnp.concatenate([y, dummy_y], axis=0)

for batch_idx in strategy.batch_iterator(self._dp_params.train_size): # pylint: disable=protected-access

# Pad indices to reduce JIT recompilations
idx = batch_selection.pad_to_multiple_of(batch_idx, self._dp_params.padding_multiple)
is_padding_example = idx == -1
batch_x = x[idx]
batch_y = y[idx] if y is not None else None
batch_data = (batch_x, batch_y, None)

trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
optimizer_variables = self.optimizer_variables
metrics_variables = self.metrics_variables
model_state = (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables)

updated_model_state = _dp_train_step(
Review comment (Collaborator): Maybe I'm missing something, but I don't see where this function is defined.

self,
model_state,
batch_data,
)
(
self.trainable_variables,
self.non_trainable_variables,
self.optimizer_variables,
self.metrics_variables,
) = updated_model_state

training_history = keras.callbacks.History()
training_history.history = {}
return fit_with_dp_poisson_sampling

return original_fit_fn(
self,
*args,
**kwargs,
)
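As a side note on the padding step in the loop above, a small self-contained illustration of the assumed behavior of batch_selection.pad_to_multiple_of — the sentinel value -1 is inferred from the idx == -1 check in the diff, not from the library's documentation:

```python
import numpy as np

# Poisson sampling yields variable-length index batches; padding them to a
# fixed multiple with a sentinel index keeps the jitted train step from
# recompiling for every distinct batch size.
batch_idx = np.array([4, 17, 23, 56, 70])        # e.g. 5 sampled examples
padding_multiple = 32
pad_len = (-len(batch_idx)) % padding_multiple   # 27 here; 0 if already aligned
padded = np.concatenate([batch_idx, np.full(pad_len, -1)])
is_padding_example = padded == -1                # marks the 27 dummy rows

# Index -1 picks the dummy example appended to x/y earlier in the function,
# and is_padding_example lets the update ignore those rows.
```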
