Conversation

@Neerajpathak07
Contributor

Fixes: #84

What is the purpose of this pull request?

This pull request:

  • Refactors the keras_api.py implementation to use Poisson sampling, for parity with the in-house batch_selection logic.

  • Updates the data class examples for DP-SGD to a Keras model.

  • Removes the previous deterministic fixed-batch sampling implementation and adds a new function that uses Poisson sampling.
    Ref:

    return original_fit_fn(
        *args,
        **kwargs,
    )

    With

     return fit_with_dp_poisson_sampling
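For context, Poisson sampling selects each training example independently with probability q = batch_size / n, so the realized batch size is random rather than fixed. A minimal sketch of the selection step (NumPy; `poisson_sample_indices` is an illustrative helper, not the repository's actual API):

```python
import numpy as np

def poisson_sample_indices(n_examples, sampling_prob, rng):
    """Return indices of a Poisson-subsampled batch.

    Each example is included independently with probability
    `sampling_prob`, so the batch size follows
    Binomial(n_examples, sampling_prob) -- the distribution that the
    DP-SGD privacy accounting for Poisson subsampling assumes.
    """
    mask = rng.uniform(size=n_examples) < sampling_prob
    return np.flatnonzero(mask)

rng = np.random.default_rng(0)
# Expected batch size is 1000 * 0.01 = 10, but it varies from step to step.
batch = poisson_sample_indices(1000, 0.01, rng)
```

This is exactly what deterministic fixed-batch sampling does not provide: a fixed-size batch per step breaks the subsampling assumption behind the privacy analysis.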

@Neerajpathak07
Contributor Author

@ryan112358 This was my first core PR to the repository, so I tried to adhere to the code conventions as much as I could XD

Collaborator

@ryan112358 ryan112358 left a comment

Looks good, thanks for the contribution! Left a few review comments.

@Neerajpathak07
Contributor Author

@ryan112358 I've added a few fixes addressing the comments. But all in all, I have a few doubts about resolving the CI error!

# is already batched. In this case, we assume that the batch sizes are
# aligned and use the batch size from the DP parameters. We will check that
# the batch sizes are aligned in the train_step function.
# batch_size is not set explicitely in the fit() call if the input dataset
Collaborator

Can you revert these formatting changes?

iterations=self._dp_params.train_steps - performed_optimizer_steps,
)

def fit_with_dp_poisson_sampling(*args, x=None, y=None, is_padding_example):
Collaborator

is_padding_example is not used, either here or below. It should be wired through to clipped_grad_fn
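For illustration, one way the mask could be wired through is to zero out padding rows after per-example clipping, so padded examples contribute nothing to the gradient sum. A NumPy sketch under assumed shapes (`clipped_grad_sum` is hypothetical, not the repository's clipped_grad_fn):

```python
import numpy as np

def clipped_grad_sum(per_example_grads, is_padding_example, l2_clip):
    """Clip each example's gradient to `l2_clip` and sum, skipping padding.

    `per_example_grads`: float array of shape (batch, dim).
    `is_padding_example`: bool array of shape (batch,). Padding rows must
    be excluded from the sum, otherwise examples that are not real data
    contribute to the privatized gradient.
    """
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    scale = np.minimum(1.0, l2_clip / np.maximum(norms, 1e-12))
    keep = (~is_padding_example)[:, None]  # zero out padding rows
    return (per_example_grads * scale * keep).sum(axis=0)
```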

metrics_variables = self.metrics_variables
model_state = (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables)

updated_model_state = _dp_train_step(
Collaborator

Maybe I'm missing something, but I don't see where this function is defined

"""

@functools.wraps(original_fit_fn)
def fit_fn_with_validation(
Collaborator

If we make Poisson sampling the default (which requires in-memory arrays as input), we have to make sure that all usages of this API provide the data in the expected format. Examples that provide data in another format either have to be modified, or Poisson sampling should be made optional. Do you know what format examples of this API use for the data?
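To illustrate the format constraint: input that arrives pre-batched (e.g. what a tf.data pipeline yields as NumPy batches) would have to be flattened back into full in-memory arrays before Poisson sampling can index individual examples. A hedged sketch, assuming an iterable of (x, y) NumPy batches (`materialize_batches` is hypothetical):

```python
import numpy as np

def materialize_batches(batched_examples):
    """Flatten an iterable of (x, y) NumPy batches into full arrays.

    Poisson sampling needs random access to every example, so pre-batched
    inputs must be concatenated back into in-memory arrays -- which is why
    making it the default constrains the accepted input formats.
    """
    xs, ys = zip(*batched_examples)
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)
```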

and is equivalent to microbatch_size=batch_size.
padding_multiple: The multiple to pad batches to when using Poisson
sampling, to reduce JIT recompilation. Note: If microbatch_size is
specified, then padding_multiple % microbatch_size == 0 should be true
Collaborator

nit: could you write "must", since below you throw an exception if it is not true.
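For context, padding to a multiple works around the random batch sizes that Poisson sampling produces: without it, each new batch size would trigger a fresh JIT trace. A sketch of the padding step (NumPy; `pad_to_multiple` is illustrative, and the docstring's `padding_multiple % microbatch_size == 0` constraint would be enforced by the caller):

```python
import numpy as np

def pad_to_multiple(batch, padding_multiple):
    """Pad `batch` along axis 0 up to the next multiple of `padding_multiple`.

    Returns the padded batch and a boolean mask marking the padding rows,
    so downstream clipping can exclude them. Limiting batch shapes to a
    small set of sizes avoids recompiling the jitted train step each step.
    """
    n = batch.shape[0]
    target = -(-n // padding_multiple) * padding_multiple  # ceil to a multiple
    pad_rows = np.zeros((target - n,) + batch.shape[1:], dtype=batch.dtype)
    padded = np.concatenate([batch, pad_rows], axis=0)
    is_padding = np.arange(target) >= n
    return padded, is_padding
```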

iterations=self._dp_params.train_steps - performed_optimizer_steps,
)

def fit_with_dp_poisson_sampling(*args, x=None, y=None, is_padding_example):
Collaborator

This new function fit_with_dp_poisson_sampling and the associated logic for Poisson sampling appears to be on the right track, but contains several issues and seems to be unused.

The function is defined but never called. The original original_fit_fn is called on L507, so this new logic is currently dead code.

There's a return fit_with_dp_poisson_sampling statement (L505) inside the for loop. This would cause the loop to terminate after the first iteration and return the function object itself, which is likely not the intended behavior.

The is_padding_example parameter in the function signature is shadowed by the assignment on L480.

As noted by the TODO on L462, the clipped_grad_fn is not used.

Note: line numbers might be off

Contributor Author

Yeah, this is a WIP for now. Let me go ahead and mark this as a draft!

@Neerajpathak07 Neerajpathak07 marked this pull request as draft January 24, 2026 10:55