Skip to content

Commit 00a1e22

Browse files
Ryan McKenna, copybara-github
authored and committed
Add minimum separation batch selection strategy.
PiperOrigin-RevId: 837270564
1 parent 2ddd6d9 commit 00a1e22

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

jax_privacy/batch_selection.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,55 @@ def batch_iterator(
291291
yield groups[i % self.cycle_length]
292292

293293

294+
@dataclasses.dataclass(frozen=True)
class MinimumSeparationSampling(BatchSelectionStrategy):
  """Implements minimum separation sampling.

  This batch selection strategy simulates constraints common in federated
  learning scenarios.

  At every step, an example is eligible if it was last selected at least
  min_sep steps ago (or never) and, when max_part is set, if it has been
  selected fewer than max_part times so far. The batch is drawn at random,
  without replacement, from the eligible examples. If fewer than
  max_batch_size examples are eligible, all of them are returned, so a batch
  may be smaller than max_batch_size (possibly empty).

  Formal Guarantees of the batch_iterator:
  - All batches consist of indices in the range [0, num_examples).
  - The number of examples in each batch is at most max_batch_size.
  - The separation between two appearances of any example is at least min_sep.
  - Each example appears at most max_part times.

  Attributes:
    iterations: The number of total batches to generate.
    max_batch_size: The maximum number of examples in each batch.
    min_sep: The minimum separation between two appearances of any example.
    max_part: The maximum number of times any example can appear. None means
      the number of appearances is not limited.
  """
  iterations: int
  max_batch_size: int
  min_sep: int
  max_part: int | None = None

  def batch_iterator(
      self, num_examples: int, rng: RngType = None
  ) -> Iterator[np.ndarray]:
    """Yields `iterations` batches of example indices.

    Args:
      num_examples: Size of the index range [0, num_examples) to sample from.
      rng: Seed or generator, forwarded to `np.random.default_rng`.

    Yields:
      One array of distinct indices per step, of size at most max_batch_size.
    """
    rng = np.random.default_rng(rng)

    # Negating the bound forces a signed dtype; max(..., 1) keeps that true
    # even when num_examples is 0 (min_scalar_type(0) would be unsigned).
    index_dtype = np.min_scalar_type(-max(num_examples, 1))
    # last_seen stores step numbers in [0, iterations) *and* the initial
    # sentinel -min_sep, so the dtype must be wide enough for both. Sizing it
    # from iterations alone overflows whenever min_sep is the larger value.
    step_bound = max(self.iterations, abs(self.min_sep))
    step_dtype = np.min_scalar_type(-step_bound)

    all_indices = np.arange(num_examples, dtype=index_dtype)
    # The sentinel makes every example eligible at step 0.
    last_seen = np.full(num_examples, -self.min_sep, dtype=step_dtype)

    # Participation counts are only needed when a cap is configured. Skipping
    # them otherwise also avoids np.min_scalar_type(None), which would yield
    # an object-dtype array.
    counts = None
    if self.max_part is not None:
      counts = np.zeros(
          num_examples, dtype=np.min_scalar_type(self.max_part)
      )

    for step in range(self.iterations):
      valid_mask = last_seen <= step - self.min_sep
      if counts is not None:
        valid_mask &= counts < self.max_part

      candidates = all_indices[valid_mask]
      current_batch_size = min(self.max_batch_size, candidates.size)
      batch = rng.choice(candidates, size=current_batch_size, replace=False)
      last_seen[batch] = step
      if counts is not None:
        counts[batch] += 1
      yield batch
341+
342+
294343
@dataclasses.dataclass(frozen=True)
295344
class UserSelectionStrategy:
296345
"""Applies base_strategy at the user level, and selects multiple examples per user.

jax_privacy/batch_selection_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,25 @@ def _check_all_equal(x):
7373
assert np.all(x == x[0]), f"Elements of x are not all equal: {x}"
7474

7575

76+
def _check_minsep_maxpart(batches, min_sep, max_part):
  """Asserts that a sequence of batches respects min_sep and max_part.

  Args:
    batches: Iterable of batches; each batch is an iterable of hashable
      example indices. The position of a batch in the sequence is its step.
    min_sep: Required minimum number of steps between two appearances of the
      same example.
    max_part: Maximum number of appearances allowed per example, or None to
      skip this check.

  Raises:
    AssertionError: If a batch contains a repeated index, an example appears
      more than max_part times, or two appearances of an example are fewer
      than min_sep steps apart.
  """
  previous_step = {}
  appearances = {}
  for step, batch in enumerate(batches):
    assert len(set(batch)) == len(batch), "Duplicate indices found in batch"
    for idx in batch:
      seen_so_far = appearances.get(idx, 0)
      # The participation cap is checked before the separation constraint.
      assert max_part is None or seen_so_far < max_part, (
          f"Example {idx} does not satisfy {max_part=}"
      )
      if idx in previous_step:
        gap = step - previous_step[idx]
        assert gap >= min_sep, f"Example {idx} does not satisfy {min_sep=}"

      appearances[idx] = seen_so_far + 1
      previous_step[idx] = step
93+
94+
7695
class BatchSelectionTest(parameterized.TestCase):
7796

7897
@parameterized.product(
@@ -226,6 +245,26 @@ def test_balls_in_bins_sampling_with_large_cycle_length(self):
226245
_check_no_repeated_indices(batches[:cycle_length])
227246
_check_cyclic_property(batches, cycle_length)
228247

248+
@parameterized.product(
    min_sep=[10],
    max_part=[None, 2],
    batch_size=[5],
)
def test_minsep_sampling(self, min_sep, max_part, batch_size):
  """Checks MinimumSeparationSampling output against its constraints."""
  num_steps = 40
  # Twice as many examples as there are batch slots in total, so the pool of
  # never-selected examples alone is always larger than one batch.
  num_examples = 2 * batch_size * num_steps
  sampler = batch_selection.MinimumSeparationSampling(
      iterations=num_steps,
      max_batch_size=batch_size,
      min_sep=min_sep,
      max_part=max_part,
  )
  batches = [*sampler.batch_iterator(num_examples, rng=0)]

  _check_minsep_maxpart(batches, min_sep, max_part)
  # NOTE(review): presumably asserts every batch size lies within
  # [batch_size, batch_size]; confirm against the helper's definition.
  _check_batch_sizes_equal(batches, batch_size, batch_size)
  _check_element_range(batches, num_examples)
  _check_signed_indices(batches)
267+
229268
def test_user_selection_strategy(self):
230269
"""Tests for UserSelectionStrategy."""
231270
base_strategy = batch_selection.CyclicPoissonSampling(

0 commit comments

Comments
 (0)