Summary
The current implementation of CyclicPoissonSampling.batch_iterator simulates sampling by drawing a batch size and then selecting items via rng.choice. While this yields the correct marginal inclusion probability, it creates negative dependence between items in the same batch (selecting one item reduces the probability of selecting others).
In the Differential Privacy (DP) literature, canonical "Poisson Sampling" is defined as one independent Bernoulli trial per item. This independence is a prerequisite for standard privacy accounting (e.g., RDP amplification by sampling). The current implementation violates this assumption.
The Issue
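- Current Logic:
  - Sample a count $k$ (the batch size).
  - Select $k$ indices uniformly without replacement.
  - Result: If item $i$ is selected, it "uses up" one of the $k$ slots, slightly reducing the probability that item $j$ is selected. This results in $P(i, j) < P(i)\,P(j) = p^2$, where $P(i, j)$ is the probability that both items land in the same batch and $p$ is the sampling probability.
- Expected Logic (Standard DP Definition):
  - For every item $i$, include it if $u_i < p$, where each $u_i \sim \mathrm{Uniform}(0, 1)$ is drawn independently.
  - Result: Decisions are independent, so $P(i, j) = p^2$.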
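
To make the dependence argument concrete, here is a short derivation for any "draw a count, then choose without replacement" scheme. The notation is an assumption for illustration: $N$ is the population size, $k$ the drawn count, and $P(i, j)$ the probability that items $i$ and $j$ are both selected. The fixed-$k$ case at the end is a hypothetical special case, not a statement about how the current code draws $k$.

```latex
\begin{align*}
P(i \mid k) &= \frac{k}{N}, &
P(i, j \mid k) &= \frac{k(k-1)}{N(N-1)}, \\
P(i) &= \frac{\mathbb{E}[k]}{N}, &
P(i, j) &= \frac{\mathbb{E}[k(k-1)]}{N(N-1)}.
\end{align*}
% Illustrative special case: k is fixed with 1 <= k < N, so p = k/N.
\[
P(i, j) = \frac{k(k-1)}{N(N-1)} < \frac{k^2}{N^2} = p^2 .
\]
```

In general the gap is set by how $\mathbb{E}[k(k-1)]$ compares with $N(N-1)\,p^2$, so its size depends on the distribution from which $k$ is drawn.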
Reproduction & Evidence
I have implemented a statistical test (test_poisson_sampling_marginal_and_pairwise) to verify this behavior:
- Setup: Population size $N$, sampling probability $p$, 10,000 trials.
- Current Implementation: The estimated pairwise inclusion probability is consistently and significantly lower than $p^2$, confirming negative dependence.
- Proposed Fix: The estimated pairwise inclusion probability matches $p^2$ within statistical tolerance.
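
A minimal sketch of what such a marginal-and-pairwise check could look like is below. It is not the test from the PR: the helper names, the parameters (`N = 5`, `P = 0.4`), and the 5-standard-error tolerance are all assumptions made for illustration. It uses the standard-library `random` module instead of the NumPy generator to stay self-contained, and the fixed-size sampler is a hypothetical stand-in for a negatively dependent scheme, not the library's current code.

```python
import math
import random


def bernoulli_sample(n, p, rng):
    """Keep each index with an independent Bernoulli(p) trial (Poisson sampling)."""
    return [i for i in range(n) if rng.random() < p]


def fixed_size_sample(n, p, rng):
    """Hypothetical stand-in: always pick round(n * p) indices without replacement."""
    return rng.sample(range(n), round(n * p))


def estimate_inclusion(sampler, n, p, trials, rng, i=0, j=1):
    """Estimate the marginal P(i) and the pairwise P(i, j) over many draws."""
    hits_i = hits_ij = 0
    for _ in range(trials):
        batch = set(sampler(n, p, rng))
        if i in batch:
            hits_i += 1
            if j in batch:
                hits_ij += 1
    return hits_i / trials, hits_ij / trials


def within_tolerance(estimate, expected, trials, z=5.0):
    """True if the estimate lies within z binomial standard errors of expected."""
    stderr = math.sqrt(expected * (1.0 - expected) / trials)
    return abs(estimate - expected) <= z * stderr


# Illustrative parameters only; the values used in the actual test may differ.
N, P, TRIALS = 5, 0.4, 10_000
rng = random.Random(0)

results = {}
for name, sampler in [("bernoulli", bernoulli_sample), ("fixed_size", fixed_size_sample)]:
    marginal, pairwise = estimate_inclusion(sampler, N, P, TRIALS, rng)
    results[name] = {
        "marginal_ok": within_tolerance(marginal, P, TRIALS),
        "pairwise_ok": within_tolerance(pairwise, P * P, TRIALS),
    }
```

Both samplers should pass the marginal check, since each item is included with probability `P` in either scheme. Only the pairwise check separates them: for the fixed-size stand-in the exact pairwise probability is 2·1/(5·4) = 0.1, well below `P * P = 0.16`.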
Proposed Fix
I propose switching to vectorized Bernoulli masking. This is both theoretically correct for DP and more JAX-friendly (better vectorization).
```python
# Proposed change in batch_selection.py

# 1. Independent Bernoulli trials (correct Poisson semantics)
mask = rng.random(size=current_group.shape[0]) < self.sampling_prob
selected = current_group[mask]

# 2. Handle truncation (if applicable)
# Note: Truncation inherently introduces dependence, but this ensures
# the base sampling is correct before that constraint is applied.
if self.truncated_batch_size is not None and selected.size > self.truncated_batch_size:
    selected = rng.choice(selected, size=self.truncated_batch_size, replace=False, shuffle=False)

yield selected
```
PR Status
I have a fix ready that includes:
- The logic change in batch_selection.py.
- A new statistical unit test, test_poisson_sampling_marginal_and_pairwise, verifying marginal and pairwise probabilities.
- Verification that existing tests pass.
Labels: bug, privacy, tests