Skip to content

Commit 2ddd6d9

Browse files
Ryan McKennacopybara-github
authored andcommitted
Refactor batch selection strategy to support different partitioning schemes.
PiperOrigin-RevId: 837273077
1 parent 8f574de commit 2ddd6d9

File tree

4 files changed

+129
-96
lines changed

4 files changed

+129
-96
lines changed

jax_privacy/batch_selection.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import abc
3232
import dataclasses
33+
import enum
3334
import itertools
3435
from typing import Iterator
3536

@@ -40,6 +41,38 @@
4041
RngType = np.random.Generator | int | None
4142

4243

44+
class PartitionType(enum.Enum):
45+
"""An enum specifying how examples should be assigned to groups."""
46+
INDEPENDENT = enum.auto()
47+
"""Each example will be assigned to a group independently at random."""
48+
EQUAL_SPLIT = enum.auto()
49+
"""Examples will be shuffled and then split into groups of equal size."""
50+
51+
52+
def independent_partition(
53+
num_examples: int,
54+
num_groups: int,
55+
rng: np.random.Generator,
56+
dtype: np.typing.DTypeLike
57+
) -> list[np.ndarray]:
58+
sizes = rng.multinomial(num_examples, np.ones(num_groups) / num_groups)
59+
boundaries = np.cumsum(sizes)[:-1]
60+
indices = np.random.permutation(num_examples).astype(dtype)
61+
return np.split(indices, boundaries)
62+
63+
64+
def _equal_split_partition(
65+
num_examples: int,
66+
num_groups: int,
67+
rng: np.random.Generator,
68+
dtype: np.typing.DTypeLike
69+
) -> list[np.ndarray]:
70+
indices = rng.permutation(num_examples).astype(dtype)
71+
group_size = num_examples // num_groups
72+
groups = np.array_split(indices, num_groups)
73+
return [g[:group_size] for g in groups]
74+
75+
4376
def split_and_pad_global_batch(
4477
indices: np.ndarray, minibatch_size: int, microbatch_size: int | None = None
4578
) -> list[np.ndarray]:
@@ -140,18 +173,18 @@ class CyclicPoissonSampling(BatchSelectionStrategy):
140173
>>> rng = np.random.default_rng(0)
141174
>>> b = CyclicPoissonSampling(sampling_prob=1, iterations=8, cycle_length=4)
142175
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
143-
[0 1 2] [3 4 5] [6 7 8] [ 9 10 11] [0 1 2] [3 4 5] [6 7 8] [ 9 10 11]
176+
[9 2 7] [ 4 5 11] [0 3 6] [10 8 1] [9 2 7] [ 4 5 11] [0 3 6] [10 8 1]
144177
145178
Example Usage (standard Poisson sampling) [2]:
146179
>>> b = CyclicPoissonSampling(sampling_prob=0.25, iterations=8)
147180
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
148-
[0 4 9 3 5] [] [5] [4 6 2 7] [ 5 11] [ 2 5 8 6 9 11] [9 1] [7 5 4 3]
181+
[5 6 7] [5 8 3 7 2] [ 1 5 11] [0 3] [ 5 1 3 4 10] [2] [4 5 1 3] [6]
149182
150183
Example Usage (BandMF-style sampling) [3]:
151184
>>> p = 0.5
152-
>>> b = CyclicPoissonSampling(sampling_prob=p, iterations=8, cycle_length=2)
185+
>>> b = CyclicPoissonSampling(sampling_prob=p, iterations=6, cycle_length=2)
153186
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
154-
[2 4 0] [ 9 10 11] [3 5] [9 8] [0 3 4 5] [8] [2 1 4 5] [11]
187+
[2 4] [1 8 9] [2 7 5 4] [11 1 3] [10 2 5 0 4] [ 1 11 6]
155188
156189
157190
References:
@@ -186,34 +219,32 @@ class CyclicPoissonSampling(BatchSelectionStrategy):
186219
examples into cycle_length groups, and do Poisson sampling from the groups
187220
in a round-robin fashion. cycle_length == 1 retrieves standard Poisson
188221
sampling.
189-
shuffle: For cyclic Poisson sampling, whether to shuffle the examples before
190-
discarding (see even_partition) and partitioning.
191-
even_partition: If True, we discard num_examples % cycle_length examples
192-
before partitioning in cyclic Poisson sampling. If False, we can have
193-
uneven partitions. Defaults to True for ease of analysis.
222+
partition_type: How to partition the examples into groups for before Poisson
223+
sampling. EQUAL_SPLIT is the default, and is only compatible with zero-out
224+
and replace-one adjacency notions, while INDEPENDENT is compatible
225+
with the add-remove adjacency notion.
194226
"""
195227

196228
sampling_prob: float
197229
iterations: int
198230
truncated_batch_size: int | None = None
199231
cycle_length: int = 1
200-
shuffle: bool = False
201-
even_partition: bool = True
232+
partition_type: PartitionType = PartitionType.EQUAL_SPLIT
202233

203234
def batch_iterator(
204235
self, num_examples: int, rng: RngType = None
205236
) -> Iterator[np.ndarray]:
206237
rng = np.random.default_rng(rng)
207238
dtype = np.min_scalar_type(-num_examples)
208239

209-
indices = np.arange(num_examples, dtype=dtype)
210-
if self.shuffle:
211-
rng.shuffle(indices)
212-
if self.even_partition:
213-
group_size = num_examples // self.cycle_length
214-
indices = indices[: group_size * self.cycle_length]
240+
if self.partition_type == PartitionType.INDEPENDENT:
241+
partition_fn = independent_partition
242+
elif self.partition_type == PartitionType.EQUAL_SPLIT:
243+
partition_fn = _equal_split_partition
244+
else:
245+
raise ValueError(f'Unsupported partition type: {self.partition_type}')
215246

216-
partition = np.array_split(indices, self.cycle_length)
247+
partition = partition_fn(num_examples, self.cycle_length, rng, dtype)
217248

218249
for i in range(self.iterations):
219250
current_group = partition[i % self.cycle_length]
@@ -254,20 +285,10 @@ def batch_iterator(
254285
) -> Iterator[np.ndarray]:
255286
rng = np.random.default_rng(rng)
256287
dtype = np.min_scalar_type(-num_examples)
257-
indices = np.arange(num_examples, dtype=dtype)
258-
rng.shuffle(indices)
259-
260-
bin_sizes = rng.multinomial(
261-
n=num_examples,
262-
pvals=np.ones(self.cycle_length) / self.cycle_length,
263-
)
264-
# Pad bin_sizes so that cumsum's output starts with 0.
265-
batch_cutoffs = np.cumsum(np.append(0, bin_sizes))
288+
groups = independent_partition(num_examples, self.cycle_length, rng, dtype)
266289

267290
for i in range(self.iterations):
268-
start_index = batch_cutoffs[i % self.cycle_length]
269-
end_index = batch_cutoffs[(i % self.cycle_length) + 1]
270-
yield indices[start_index:end_index]
291+
yield groups[i % self.cycle_length]
271292

272293

273294
@dataclasses.dataclass(frozen=True)
@@ -284,18 +305,19 @@ class UserSelectionStrategy:
284305
users.
285306
286307
Example Usage:
308+
>>> rng = np.random.default_rng(0)
287309
>>> base_strategy = CyclicPoissonSampling(sampling_prob=1, iterations=5)
288310
>>> strategy = UserSelectionStrategy(base_strategy, 2)
289311
>>> user_ids = np.array([0,0,0,1,1,2])
290-
>>> iterator = strategy.batch_iterator(user_ids)
312+
>>> iterator = strategy.batch_iterator(user_ids, rng)
291313
>>> print(next(iterator))
292-
[[0 1]
293-
[3 4]
294-
[5 5]]
314+
[[5 5]
315+
[0 1]
316+
[3 4]]
295317
>>> print(next(iterator))
296-
[[2 0]
297-
[3 4]
298-
[5 5]]
318+
[[5 5]
319+
[2 0]
320+
[3 4]]
299321
300322
Attributes:
301323
base_strategy: The base batch selection strategy to apply at the user level.

jax_privacy/batch_selection_test.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -76,29 +76,30 @@ def _check_all_equal(x):
7676
class BatchSelectionTest(parameterized.TestCase):
7777

7878
@parameterized.product(
79-
shuffle=[True, False],
80-
even_partition=[True, False],
79+
partition_type=[
80+
batch_selection.PartitionType.EQUAL_SPLIT,
81+
batch_selection.PartitionType.INDEPENDENT
82+
],
8183
num_examples=[10],
8284
cycle_length=[3],
8385
iterations=[5],
8486
)
8587
def test_cyclic_participation(
86-
self, *, shuffle, even_partition, num_examples, cycle_length, iterations
88+
self, *, partition_type, num_examples, cycle_length, iterations
8789
):
8890
"""Tests the use of CyclicPoissonSampling instantiated to do shuffling."""
8991
strategy = batch_selection.CyclicPoissonSampling(
9092
sampling_prob=1.0,
9193
iterations=iterations,
9294
cycle_length=cycle_length,
93-
shuffle=shuffle,
94-
even_partition=even_partition,
95+
partition_type=partition_type,
9596
)
9697
batches = list(strategy.batch_iterator(num_examples, rng=0))
9798

9899
self.assertLen(batches, iterations)
99-
min_batch_size = num_examples // cycle_length
100-
max_batch_size = min_batch_size if even_partition else min_batch_size + 1
101-
_check_batch_sizes_equal(batches, min_batch_size, max_batch_size)
100+
if partition_type == batch_selection.PartitionType.EQUAL_SPLIT:
101+
batch_size = num_examples // cycle_length
102+
_check_batch_sizes_equal(batches, batch_size, batch_size)
102103
_check_no_repeated_indices(batches[:cycle_length])
103104
_check_cyclic_property(batches, cycle_length)
104105
_check_element_range(batches, num_examples)
@@ -109,8 +110,10 @@ def test_cyclic_participation(
109110
iterations=[1000],
110111
cycle_length=[1, 3],
111112
expected_batch_size=[3],
112-
shuffle=[True, False],
113-
even_partition=[True, False],
113+
partition_type=[
114+
batch_selection.PartitionType.INDEPENDENT,
115+
batch_selection.PartitionType.EQUAL_SPLIT,
116+
],
114117
truncated_batch_size=[None, 4],
115118
)
116119
def test_poisson_sampling(
@@ -120,8 +123,7 @@ def test_poisson_sampling(
120123
iterations,
121124
cycle_length,
122125
expected_batch_size,
123-
shuffle,
124-
even_partition,
126+
partition_type,
125127
truncated_batch_size,
126128
):
127129
"""Tests for Poisson sampling, potentially cyclic and truncated."""
@@ -130,8 +132,7 @@ def test_poisson_sampling(
130132
sampling_prob=sampling_prob,
131133
iterations=iterations,
132134
cycle_length=cycle_length,
133-
shuffle=shuffle,
134-
even_partition=even_partition,
135+
partition_type=partition_type,
135136
truncated_batch_size=truncated_batch_size,
136137
)
137138
batches = list(strategy.batch_iterator(num_examples, rng=0))
@@ -140,10 +141,9 @@ def test_poisson_sampling(
140141
min_batch_size = 0
141142
if truncated_batch_size:
142143
max_batch_size = truncated_batch_size
143-
elif even_partition:
144-
max_batch_size = num_examples // cycle_length
145144
else:
146-
max_batch_size = math.ceil(num_examples / cycle_length)
145+
max_batch_size = num_examples // cycle_length
146+
147147
_check_batch_sizes_equal(batches, min_batch_size, max_batch_size)
148148
for start_index in range(0, iterations, cycle_length):
149149
_check_no_repeated_indices(
@@ -152,7 +152,7 @@ def test_poisson_sampling(
152152
_check_element_range(batches, num_examples)
153153
_check_subset_of_kb_participation(batches, cycle_length)
154154
# Make sure elements are discarded when using even partition.
155-
if even_partition:
155+
if partition_type == batch_selection.PartitionType.EQUAL_SPLIT:
156156
distinct_elements = _get_unique_elements(batches)
157157
self.assertLessEqual(
158158
distinct_elements.size, (num_examples // cycle_length) * cycle_length
@@ -176,20 +176,14 @@ def test_poisson_sampling_with_large_cycle_length(self):
176176
sampling_prob=sampling_prob,
177177
iterations=iterations,
178178
cycle_length=cycle_length,
179-
even_partition=False,
179+
partition_type=batch_selection.PartitionType.EQUAL_SPLIT,
180180
)
181181
batches = list(strategy.batch_iterator(num_examples, rng=0))
182182

183183
self.assertLen(batches, iterations)
184184
min_batch_size = 0
185-
max_batch_size = 1
185+
max_batch_size = 0
186186
_check_batch_sizes_equal(batches, min_batch_size, max_batch_size)
187-
for start_index in range(0, iterations, cycle_length):
188-
_check_no_repeated_indices(
189-
batches[start_index : start_index + cycle_length]
190-
)
191-
_check_element_range(batches, num_examples)
192-
_check_subset_of_kb_participation(batches, cycle_length)
193187

194188
@parameterized.product(
195189
num_examples=[100],

0 commit comments

Comments
 (0)