Fix Poisson sampling #122
Conversation
…mpling; add statistical tests
…endent Bernoulli semantics
amyssnippet left a comment
While this fix aligns with the theoretical definition of Poisson sampling, using a boolean mask creates dynamic shapes. How does this impact performance when the batch_iterator is used inside a jax.jit-decorated training loop?
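For context, one common way to avoid recompilation from dynamic shapes is to pad each sampled batch to a fixed capacity before it reaches the jitted step. A minimal sketch, assuming this is done outside the library; pad_batch, train_step, and capacity are hypothetical names, not part of this PR:

```python
import numpy as np
import jax
import jax.numpy as jnp

def pad_batch(indices, capacity):
    """Hypothetical helper: pad a variable-size index batch to a fixed capacity."""
    padded = np.zeros(capacity, dtype=np.int32)
    weights = np.zeros(capacity, dtype=np.float32)
    n = min(len(indices), capacity)
    padded[:n] = indices[:n]
    weights[:n] = 1.0  # 1 marks real rows, 0 marks padding
    return padded, weights

@jax.jit
def train_step(batch, weights):
    # Padded rows are zeroed out by the weight mask, so jit only ever
    # sees one static shape and compiles once.
    return jnp.sum(batch * weights)

padded, weights = pad_batch(np.array([3, 17, 42]), capacity=8)
loss = train_step(padded, weights)  # compiled once for shape (8,)
```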
```python
# items (this preserves independent selection semantics before
# truncation).
if (
    self.truncated_batch_size is not None
```
Does the truncation logic negate the privacy benefits of the independent Bernoulli trials?
It does, but it still provides benefits; see https://arxiv.org/abs/2508.15089
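For illustration, a minimal sketch of truncation applied after independent Bernoulli selection, matching the comment in the quoted diff; the rng, n, p values and the standalone truncated_batch_size variable are assumptions for this example, not the library's API:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, truncated_batch_size = 1_000, 0.05, 40  # illustrative values

# Independent Bernoulli trials first...
selected = np.flatnonzero(rng.random(n) < p)
# ...then truncate: keep a uniformly random subset of the overflow, so
# selection stays independent *before* truncation (as the diff comment notes).
if truncated_batch_size is not None and len(selected) > truncated_batch_size:
    selected = rng.choice(selected, size=truncated_batch_size, replace=False)
```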
These changes might be right from the math point of view, but they might hurt performance. Can you check by running a stress test on a sample model training, and report the performance diff before and after the changes?
… tests for padding
amyssnippet left a comment
There may be some pylint errors to fix in tests/batch_selection_test.py:239:0.
amyssnippet left a comment
LGTM
@ryan112358 take a look at this, as you are a veteran of the project.
@ryan112358 can you please review the changes?
Thank you for the contribution, and for critically evaluating the implementations of our core components. This is super crucial work to harden our library for real privacy applications where correctness is the first priority. I will assign a member of our team to take a closer look. But a couple of high-level comments:
- Producing variable batch sizes is expected and fine here since we are working with pure numpy in this file. The batches can later be padded to a fixed set of sizes for efficiency under jax.jit.
- The current implementation is more efficient though, since it has O(B) complexity per step rather than O(N) (see the sketch after this comment).
- [minor] Being able to know the batch sizes in advance, before sampling the exact examples, could be useful for downstream applications, e.g., so we can pre-compile our train step for the batch sizes we expect to see.
Of course, these points are moot if the current implementation is actually incorrect, so we will get back to you on that.
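For concreteness, a sketch of the O(B)-style two-step sampler referenced above, with illustrative names and values; numpy's choice is used here for clarity, while an O(k) distinct-sampling routine (e.g., Floyd's algorithm) is what attains the stated per-step bound:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 1_000_000, 1e-4  # illustrative values; expected batch size B = n * p = 100

# Two-step scheme: draw the batch size, then pick that many distinct items.
# Only k item indices end up in the batch, vs. n coin flips for a full mask.
k = rng.binomial(n, p)
batch = rng.choice(n, size=k, replace=False)
```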
Thanks @ryan112358 for the review and the context. I agree that the shift from O(B) to O(N) is a performance cost, but as you noted, it seems unavoidable to satisfy the strict independence requirement of Poisson sampling. The previous method's negative dependence effectively leaked information between samples, which risks invalidating RDP accounting. Regarding the point on pre-compilation: since the batch size is inherently stochastic in Poisson sampling, we lose the ability to know it deterministically in advance, but the …
Hi, I ran your test_poisson_sampling_marginal_and_pairwise without changing any of the code outside the test file and it did not fail. I'm not convinced the current implementation is incorrect. The events that two different examples are sampled are pairwise dependent conditioned on knowing the value of k, but without conditioning they remain independent. See Lemma 1 of https://arxiv.org/pdf/2406.17298v3 for a proof that the current implementation is distributionally equivalent to Poisson sampling. We will add a reference to this proof as a comment for clarity. Feel free to send a PR adding the test and also adding the option to pad the batch (though I think the fixed_batch_size and truncated_batch_size args are somewhat redundant here). Or, if you disagree with the above paper or think the current implementation has a bug that is not inherent to the sampling technique, feel free to reopen.
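As an informal sanity check of that equivalence, one can estimate the unconditional co-inclusion frequency of two fixed items under the two-step scheme and compare it to p². A sketch with illustrative values, not code from the repo:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, trials = 20, 0.3, 200_000
co_included = 0
for _ in range(trials):
    k = rng.binomial(n, p)                      # random batch size
    batch = rng.choice(n, size=k, replace=False)  # k distinct items, uniform
    # Count rounds where items 0 and 1 are *both* sampled.
    co_included += (0 in batch) and (1 in batch)

# Under true Poisson sampling this frequency should approach p**2 = 0.09,
# which is what the lemma's distributional equivalence predicts.
print(co_included / trials, p**2)
```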
@arung54 Understood. I reviewed the lemma and see how the variance in …
Summary
This PR fixes the implementation of CyclicPoissonSampling.batch_iterator to strictly align with the canonical definition of Poisson sampling used in Differential Privacy (independent Bernoulli trials).
Motivation
The previous implementation used a two-step process:
1. Draw a random batch size k.
2. Select k items uniformly at random without replacement.
While this yields the correct marginal probability, it introduces negative dependence between items in the same batch (selecting one item reduces the probability of others being selected). This violation of independence can invalidate privacy accounting assumptions (e.g., RDP amplification by sampling).
Changes
- Rewrote CyclicPoissonSampling.batch_iterator in batch_selection.py to use a vectorized Bernoulli mask (rng.random() < p). This guarantees that item inclusions are statistically independent; a sketch follows below.
- Added test_poisson_sampling_marginal_and_pairwise to tests/batch_selection_test.py. This test statistically validates that:
  - each item's marginal inclusion probability matches p, and
  - inclusions of distinct items are pairwise independent.
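A minimal sketch of the mask-based selection described above, with illustrative variable names and values:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 1_000, 0.05  # illustrative values

mask = rng.random(n) < p      # one independent Bernoulli(p) trial per item
batch = np.flatnonzero(mask)  # variable-size batch, E[len(batch)] = n * p
```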
Verification
- Ran test_poisson_sampling_marginal_and_pairwise, which now passes (it previously failed with negative dependence).
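The general shape of such a check, as a hedged sketch; thresholds and sizes here are illustrative, not the actual test:

```python
import numpy as np

rng = np.random.default_rng(0)
n_items, p, rounds = 50, 0.2, 100_000
masks = rng.random((rounds, n_items)) < p  # one Bernoulli mask per round

marginal = masks.mean(axis=0)                      # each entry should be ≈ p
pair_cov = np.cov(masks[:, 0], masks[:, 1])[0, 1]  # should be ≈ 0
assert np.allclose(marginal, p, atol=0.01)
assert abs(pair_cov) < 3e-3
```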
Related Issues
Fixes #121