Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions jax_privacy/batch_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,55 @@ def batch_iterator(
yield groups[i % self.cycle_length]


@dataclasses.dataclass(frozen=True)
class MinimumSeparationSampling(BatchSelectionStrategy):
  """Implements minimum separation sampling.

  This batch selection strategy simulates constraints common in federated
  learning scenarios.

  At every step the batch is drawn uniformly at random, without replacement,
  from the examples that are currently eligible: those that were never selected
  or were last selected at least min_sep steps ago and, if max_part is set,
  have been selected fewer than max_part times so far. When fewer than
  max_batch_size examples are eligible, all of them are selected, so a batch
  can be smaller than max_batch_size (possibly empty).

  Formal Guarantees of the batch_iterator:
    - All batches consist of indices in the range [0, num_examples).
    - The number of examples in each batch is at most max_batch_size.
    - The separation between two appearances of any example is at least
      min_sep.
    - Each example appears at most max_part times.

  Attributes:
    iterations: The number of total batches to generate.
    max_batch_size: The maximum number of examples in each batch.
    min_sep: The minimum separation between two appearances of any example.
    max_part: The maximum number of times any example can appear. None (the
      default) places no limit on the number of appearances.
  """
  iterations: int
  max_batch_size: int
  min_sep: int
  max_part: int | None = None

  def batch_iterator(
      self, num_examples: int, rng: RngType = None
  ) -> Iterator[np.ndarray]:
    """Yields `iterations` batches of example indices.

    Args:
      num_examples: The number of examples to select from; indices are drawn
        from [0, num_examples).
      rng: Seed or generator forwarded to np.random.default_rng.

    Yields:
      One array of distinct example indices per step, holding at most
      max_batch_size entries.
    """
    rng = np.random.default_rng(rng)
    # Negating the bound makes min_scalar_type pick a signed integer type.
    index_dtype = np.min_scalar_type(-num_examples)
    # last_seen holds step numbers in [0, iterations) as well as the initial
    # sentinel -min_sep, so its dtype has to cover both magnitudes. Sizing it
    # from the iteration count alone overflows the sentinel when min_sep is
    # the larger of the two.
    step_dtype = np.min_scalar_type(-max(self.iterations, self.min_sep))
    # Without a participation cap an example can appear at most once per step.
    count_dtype = np.min_scalar_type(self.max_part or self.iterations)

    all_indices = np.arange(num_examples, dtype=index_dtype)
    # The -min_sep sentinel makes every example eligible at step 0.
    last_seen = np.full(num_examples, -self.min_sep, dtype=step_dtype)
    counts = np.zeros(num_examples, dtype=count_dtype)

    for step in range(self.iterations):
      valid_mask = last_seen <= step - self.min_sep
      if self.max_part is not None:
        valid_mask &= counts < self.max_part

      candidates = all_indices[valid_mask]
      # Shrink the batch rather than fail when too few examples are eligible.
      current_batch_size = min(self.max_batch_size, candidates.size)
      batch = rng.choice(candidates, size=current_batch_size, replace=False)
      last_seen[batch] = step
      counts[batch] += 1
      yield batch


@dataclasses.dataclass(frozen=True)
class UserSelectionStrategy:
"""Applies base_strategy at the user level, and selects multiple examples per user.
Expand Down
39 changes: 39 additions & 0 deletions jax_privacy/batch_selection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,25 @@ def _check_all_equal(x):
assert np.all(x == x[0]), f"Elements of x are not all equal: {x}"


def _check_minsep_maxpart(batches, min_sep, max_part):
last_seen = {}
counts = {}
for step, batch in enumerate(batches):
assert len(batch) == len(set(batch)), "Duplicate indices found in batch"
for idx in batch:
count = counts.get(idx, 0)
if max_part is not None:
assert count < max_part, f"Example {idx} does not satisfy {max_part=}"

if idx in last_seen:
prev_step = last_seen[idx]
dist = step - prev_step
assert dist >= min_sep, f"Example {idx} does not satisfy {min_sep=}"

last_seen[idx] = step
counts[idx] = count + 1


class BatchSelectionTest(parameterized.TestCase):

@parameterized.product(
Expand Down Expand Up @@ -226,6 +245,26 @@ def test_balls_in_bins_sampling_with_large_cycle_length(self):
_check_no_repeated_indices(batches[:cycle_length])
_check_cyclic_property(batches, cycle_length)

@parameterized.product(
    min_sep=[10],
    max_part=[None, 2],
    batch_size=[5],
)
def test_minsep_sampling(self, min_sep, max_part, batch_size):
  """Tests for MinimumSeparationSampling."""
  num_steps = 40
  # Twice as many examples as batch slots over the whole run, so at least
  # batch_size never-selected examples remain available at every step.
  num_examples = 2 * num_steps * batch_size
  sampler = batch_selection.MinimumSeparationSampling(
      iterations=num_steps,
      max_batch_size=batch_size,
      min_sep=min_sep,
      max_part=max_part,
  )
  batches = list(sampler.batch_iterator(num_examples, rng=0))

  _check_minsep_maxpart(batches, min_sep, max_part)
  # Passing batch_size for both bounds presumably pins every batch to exactly
  # that size -- confirm against the helper's definition.
  _check_batch_sizes_equal(batches, batch_size, batch_size)
  _check_element_range(batches, num_examples)
  _check_signed_indices(batches)

def test_user_selection_strategy(self):
"""Tests for UserSelectionStrategy."""
base_strategy = batch_selection.CyclicPoissonSampling(
Expand Down