bench: refactor keras_api.py to use Poisson Sampling
#89
base: main
@ryan112358 This was my first core PR to the repository, so I tried to adhere to the code conventions as much as I could XD
ryan112358 left a comment
Looks good, thanks for the contribution! Left a few review comments
@ryan112358 I have added a few fixes addressing the comments. But all in all, I still have a few doubts about resolving the CI error!
# is already batched. In this case, we assume that the batch sizes are
# aligned and use the batch size from the DP parameters. We will check that
# the batch sizes are aligned in the train_step function.
# batch_size is not set explicitely in the fit() call if the input dataset
Can you revert these formatting changes?
iterations=self._dp_params.train_steps - performed_optimizer_steps,
)

def fit_with_dp_poisson_sampling(*args, x=None, y=None, is_padding_example):
is_padding_example is not used, either here or below. It should be wired through to clipped_grad_fn.
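To illustrate what wiring the mask through could mean, here is a minimal NumPy sketch. It is not the repository's clipped_grad_fn; the function name masked_clipped_grad_sum and its arguments are hypothetical, and it only assumes that padding rows should contribute nothing to the clipped gradient sum.

```python
import numpy as np


def masked_clipped_grad_sum(per_example_grads, is_padding_example, l2_clip_norm):
    """Sum per-example gradients after L2 clipping, ignoring padding rows.

    Hypothetical helper for illustration only; not the library's API.
    """
    grads = np.asarray(per_example_grads, dtype=np.float64)
    keep = ~np.asarray(is_padding_example, dtype=bool)
    norms = np.linalg.norm(grads, axis=1)
    # Scale each row so that its L2 norm is at most l2_clip_norm.
    scale = np.minimum(1.0, l2_clip_norm / np.maximum(norms, 1e-12))
    clipped = grads * scale[:, None]
    # Zero out padding rows so they cannot influence the aggregate.
    return (clipped * keep[:, None]).sum(axis=0)


grads = np.array([[3.0, 4.0], [0.6, 0.8], [100.0, 100.0]])
padding = np.array([False, False, True])
result = masked_clipped_grad_sum(grads, padding, l2_clip_norm=1.0)
```

Without the mask, the third (padding) row would add a full clipped gradient to the sum even though it does not correspond to a real example.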
metrics_variables = self.metrics_variables
model_state = (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables)

updated_model_state = _dp_train_step(
Maybe I'm missing something, but I don't see where this function is defined.
"""

@functools.wraps(original_fit_fn)
def fit_fn_with_validation(
If we make Poisson sampling the default (which requires in-memory arrays as input), we have to make sure that all usages of this API provide the data in the expected format. Examples that provide data in another format either have to be modified, or Poisson sampling should be optional. Do you know what format examples of this API use for the data?
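For context on the in-memory requirement, here is a minimal NumPy sketch of Poisson sampling with padding. It is an illustration under stated assumptions, not the repository's batch_selection code: the function name poisson_sample_padded, the sampling probability, and the choice of index 0 as filler are all hypothetical. The point is that each step draws an arbitrary subset of row indices, which needs random access into x and y.

```python
import numpy as np


def poisson_sample_padded(num_examples, sampling_prob, padding_multiple, rng):
    """Return (indices, is_padding_example) for one Poisson-sampled batch.

    Hypothetical helper for illustration only; not the library's API.
    """
    # Each example is included independently with probability sampling_prob,
    # so the batch size varies from step to step.
    selected = np.flatnonzero(rng.random(num_examples) < sampling_prob)
    # Round the batch size up to a multiple of padding_multiple.
    padded_size = -(-len(selected) // padding_multiple) * padding_multiple
    # Filler rows reuse index 0; the mask marks them so they can be ignored.
    filler = np.zeros(padded_size - len(selected), dtype=selected.dtype)
    indices = np.concatenate([selected, filler])
    is_padding_example = np.arange(padded_size) >= len(selected)
    return indices, is_padding_example


rng = np.random.default_rng(0)
# Assumption: x and y are NumPy arrays held in memory.
x = np.arange(2000, dtype=np.float32).reshape(1000, 2)
y = np.arange(1000)
indices, is_padding_example = poisson_sample_padded(len(x), 0.05, 8, rng)
x_batch, y_batch = x[indices], y[indices]
```

A streaming or pre-batched dataset generally cannot serve x[indices] for arbitrary indices, which is why examples using such inputs would need to change or opt out.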
  and is equivalent to microbatch_size=batch_size.
padding_multiple: The multiple to pad batches to when using Poisson
  sampling, to reduce JIT recompilation. Note: If microbatch_size is
  specified, then padding_multiple % microbatch_size == 0 should be true
nit: could you write "must" instead of "should", since below you throw an exception if it is not true?
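A check matching the "must" wording could look like the sketch below. The function name validate_padding_multiple and the error message are hypothetical; the PR's actual validation code is not shown in this thread.

```python
def validate_padding_multiple(padding_multiple, microbatch_size=None):
    """Raise ValueError unless padding_multiple is a multiple of microbatch_size.

    Hypothetical helper for illustration only; not the library's API.
    """
    if microbatch_size is None:
        # Nothing to check when no microbatch size is specified.
        return
    if padding_multiple % microbatch_size != 0:
        raise ValueError(
            f"padding_multiple ({padding_multiple}) must be divisible by "
            f"microbatch_size ({microbatch_size})."
        )


validate_padding_multiple(32, microbatch_size=8)  # passes silently
validate_padding_multiple(32)  # passes: microbatch_size not specified
```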
iterations=self._dp_params.train_steps - performed_optimizer_steps,
)

def fit_with_dp_poisson_sampling(*args, x=None, y=None, is_padding_example):
This new function fit_with_dp_poisson_sampling and the associated logic for Poisson sampling appear to be on the right track, but contain several issues and seem to be unused.
- The function is defined but never called. The original original_fit_fn is called on L507, so this new logic is currently dead code.
- There's a return fit_with_dp_poisson_sampling statement (L505) inside the for loop. This would cause the loop to terminate after the first iteration and return the function object itself, which is likely not the intended behavior.
- The is_padding_example parameter in the function signature is shadowed by the assignment on L480.
- As noted by the TODO on L462, the clipped_grad_fn is not used.
Note: line numbers might be off.
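The return-inside-the-loop issue can be reproduced in isolation. The sketch below uses made-up names (run_steps_buggy, run_steps_fixed, step) and a trivial stand-in for the training step; it only illustrates the control-flow pattern described above, not the PR's code.

```python
def run_steps_buggy(batches):
    def step(batch):
        return sum(batch)

    for batch in batches:
        # Bug: exits on the first iteration and returns the function object
        # itself; step is never called.
        return step


def run_steps_fixed(batches):
    def step(batch):
        return sum(batch)

    results = []
    for batch in batches:
        # Call the inner function on every batch and collect the results.
        results.append(step(batch))
    return results


buggy_result = run_steps_buggy([[1, 2], [3]])
fixed_result = run_steps_fixed([[1, 2], [3]])
```

Here buggy_result is a function rather than any computed value, while fixed_result holds one entry per batch.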
Yeah, this is a WIP for now. Let me go ahead and mark this as a draft!
Fixes: #84
This pull request:
- Refactors the keras_api.py implementation to use Poisson sampling from the in-house batch_selection parity.
- Updates the data class examples for DP-SGD to a Keras model.
- Removes the previous implementation of deterministic fixed-batch sampling and creates a new function to use Poisson sampling.
Ref:-
With