Skip to content

Conversation

@MaheshThakur9152
Copy link

Summary

This PR fixes the implementation of CyclicPoissonSampling.batch_iterator to strictly align with the canonical definition of Poisson sampling used in Differential Privacy (independent Bernoulli trials).

Motivation

The previous implementation used a two-step process:

  1. Sample a batch size $k \sim \text{Binomial}(n, p)$.
  2. Select $k$ items uniformly without replacement.

While this yields the correct marginal probability, it introduces negative dependence between items in the same batch (selecting one item reduces the probability of others being selected). This violation of independence can invalidate privacy accounting assumptions (e.g., RDP amplification by sampling).

Changes

  • Logic Update: Refactored batch_selection.py to use a vectorized Bernoulli mask (rng.random() < p). This guarantees that item inclusions are statistically independent.
  • New Test: Added test_poisson_sampling_marginal_and_pairwise to tests/batch_selection_test.py. This test statistically validates that:
    • Marginal inclusion probability $\approx p$.
    • Pairwise inclusion probability $\approx p^2$ (confirming independence).

Verification

  • Ran the new statistical test test_poisson_sampling_marginal_and_pairwise, which now passes (previously failed with negative dependence).
  • Verified that existing tests pass to ensure no regressions.

Related Issues

Fixes #121

Copy link
Contributor

@amyssnippet amyssnippet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while this fix aligns with the theoretical definition of poisson sampling, using a boolean mask creates dynamic shapes how does this impact performance when the batch_iterator is used inside a jax.jit decorated training loop

# items (this preserves independent selection semantics before
# truncation).
if (
self.truncated_batch_size is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the truncation logic negate the privacy benefits of the independent bernoulli trials?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does but it still provides benefits, see https://arxiv.org/abs/2508.15089

@amyssnippet
Copy link
Contributor

these changes might respect the pov of math, but it might break the code performance

can you check and do a stress test of this for a sample model training, and the performance diff, before and after the changes??

Copy link
Contributor

@amyssnippet amyssnippet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe some pylint errors to solve in tests/batch_selection_test.py:239:0

Copy link
Contributor

@amyssnippet amyssnippet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@amyssnippet
Copy link
Contributor

@ryan112358 take a loot at this, as you are a veteran to the project

@MaheshThakur9152
Copy link
Author

@ryan112358 can you please review the changes

Copy link
Collaborator

@ryan112358 ryan112358 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution, and for critically evaluating the implementations of our core components. This is super crucial work to harden our library for real privacy applications where correctness is the first priority. I will assign a member of our team to take a closer look. But a couple high-level comments:

  1. Producing variable batch sizes is expected and fine here since we are working with pure numpy in this file. The batches can later be padded to a fixed set of sizes for efficiency under jax.jit.
  2. The current implementation is more efficient though, since it has O(B) complexity per step rather than O(N).
  3. [minor] Being able to know the batch sizes in advance, before sampling the exact examples could be useful downstream applications so we can e.g., pre compile our train step for the batch sizes we expect to see.

Of course, these points are moot if the current implementation is actually incorrect, so we will get back to you on that.

@MaheshThakur9152
Copy link
Author

Thanks @ryan112358 for the review and the context.

I agree that the shift from O(B) to O(N) is a performance cost, but as you noted, it seems unavoidable to satisfy the strict independence requirement of Poisson sampling. The previous method's negative dependence effectively leaked information between samples, which risks invalidating RDP accounting.

Regarding the point on pre-compilation: since the batch size is inherently stochastic in Poisson sampling, we lose the ability to know it deterministically in advance, but the fixed_batch_size option I added should help bridge that gap for downstream JIT usage.

@arung54
Copy link
Collaborator

arung54 commented Jan 27, 2026

Hi, I ran your test_poisson_sampling_marginal_and_pairwise without changing any of the code outside the test file and it did not fail. I'm not convinced the current implementation is incorrect.

The events that two different examples are sampled are pairwise dependent conditioned on knowing the value of k, but without conditioning they remain independent. See Lemma 1 of https://arxiv.org/pdf/2406.17298v3 for a proof that the current implementation is distributionally equivalent to Poisson sampling. We will add a reference to this proof as a comment for clarity.

Feel free to send a PR adding the test and also adding the option to pad the batch (though I think the fixed_batch_size and truncated_batch_size args are somewhat redundant here). Or, if you disagree with the above paper or think the current implementation has a bug that is not inherent to the sampling technique, feel free to reopen.

@arung54 arung54 closed this Jan 27, 2026
@MaheshThakur9152
Copy link
Author

@arung54 Understood. I reviewed the lemma and see how the variance in k cancels the negative dependence. My baseline test configuration must have been incorrect.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fix: CyclicPoissonSampling should use independent Bernoulli trials (correct Poisson semantics)

4 participants