@dataclasses.dataclass(frozen=True)
class MinimumSeparationSampling(BatchSelectionStrategy):
  """Implements minimum separation sampling.

  This batch selection strategy simulates constraints common in federated
  learning scenarios.

  At every step the batch is drawn uniformly at random, without replacement,
  from the examples that are currently eligible: those whose previous
  appearance (if any) was at least `min_sep` steps ago and, when `max_part` is
  set, that have appeared fewer than `max_part` times so far. When fewer than
  `max_batch_size` examples are eligible, all of them are returned, so a batch
  may be smaller than `max_batch_size` (and may be empty).

  Formal Guarantees of the batch_iterator:
  - All batches consist of indices in the range [0, num_examples).
  - The number of examples in each batch is at most max_batch_size.
  - The separation between two appearances of any example is at least min_sep.
  - Each example appears at most max_part times.

  Attributes:
    iterations: The number of total batches to generate.
    max_batch_size: The maximum number of examples in each batch.
    min_sep: The minimum separation between two appearances of any example.
    max_part: The maximum number of times any example can appear. None means
      there is no limit on the number of appearances.
  """

  iterations: int
  max_batch_size: int
  min_sep: int
  max_part: int | None = None

  def batch_iterator(
      self, num_examples: int, rng: RngType = None
  ) -> Iterator[np.ndarray]:
    """Yields `iterations` batches of example indices.

    Args:
      num_examples: The number of examples; indices are drawn from
        [0, num_examples).
      rng: Seed or generator forwarded to `np.random.default_rng`.

    Yields:
      One 1-D array of distinct example indices per step, of size
      min(max_batch_size, number of currently eligible examples).
    """
    rng = np.random.default_rng(rng)

    # Smallest *signed* dtypes that can hold the values stored below. The
    # arguments are negated so that NumPy selects a signed type.
    index_dtype = np.min_scalar_type(-max(num_examples, 1))
    # `last_seen` stores both the sentinel -min_sep and step numbers up to
    # iterations - 1, so its dtype must cover both ranges. Sizing it from
    # `iterations` alone overflows whenever min_sep is the larger of the two.
    step_dtype = np.min_scalar_type(
        -max(self.iterations, abs(self.min_sep)) - 1
    )
    count_dtype = np.min_scalar_type(self.max_part or self.iterations)

    all_indices = np.arange(num_examples, dtype=index_dtype)
    # Initialising to -min_sep makes every example eligible at step 0.
    last_seen = np.full(num_examples, -self.min_sep, dtype=step_dtype)
    counts = np.zeros(num_examples, dtype=count_dtype)

    for step in range(self.iterations):
      eligible = last_seen <= step - self.min_sep
      if self.max_part is not None:
        eligible &= counts < self.max_part

      candidates = all_indices[eligible]
      # Shrink the batch when too few examples are eligible.
      batch_size = min(self.max_batch_size, candidates.size)
      batch = rng.choice(candidates, size=batch_size, replace=False)

      last_seen[batch] = step
      counts[batch] += 1
      yield batch
def _check_minsep_maxpart(batches, min_sep, max_part):
  """Asserts that `batches` respects the min_sep and max_part constraints.

  Checks, in order of appearance, that no batch contains a repeated index,
  that no index appears more than `max_part` times (skipped when `max_part`
  is None), and that consecutive appearances of the same index are at least
  `min_sep` steps apart.

  Args:
    batches: Sequence of batches; each batch is a sized iterable of hashable
      example indices. The position in the sequence is the step number.
    min_sep: Required minimum number of steps between two appearances.
    max_part: Maximum number of appearances per index, or None for no limit.
  """
  # Maps example -> (step of its latest appearance, appearances so far).
  history = {}
  for now, members in enumerate(batches):
    assert len(set(members)) == len(members), "Duplicate indices found in batch"
    for example in members:
      previous, uses = history.get(example, (None, 0))

      if max_part is not None:
        assert uses < max_part, (
            f"Example {example} does not satisfy {max_part=}"
        )

      if previous is not None:
        assert now - previous >= min_sep, (
            f"Example {example} does not satisfy {min_sep=}"
        )

      history[example] = (now, uses + 1)