3030
3131import abc
3232import dataclasses
33+ import enum
3334import itertools
3435from typing import Iterator
3536
4041RngType = np .random .Generator | int | None
4142
4243
44+ class PartitionType (enum .Enum ):
45+ """An enum specifying how examples should be assigned to groups."""
46+ INDEPENDENT = enum .auto ()
47+ """Each example will be assigned to a group independently at random."""
48+ EQUAL_SPLIT = enum .auto ()
49+ """Examples will be shuffled and then split into groups of equal size."""
50+
51+
52+ def independent_partition (
53+ num_examples : int ,
54+ num_groups : int ,
55+ rng : np .random .Generator ,
56+ dtype : np .typing .DTypeLike
57+ ) -> list [np .ndarray ]:
58+ sizes = rng .multinomial (num_examples , np .ones (num_groups ) / num_groups )
59+ boundaries = np .cumsum (sizes )[:- 1 ]
60+ indices = np .random .permutation (num_examples ).astype (dtype )
61+ return np .split (indices , boundaries )
62+
63+
64+ def _equal_split_partition (
65+ num_examples : int ,
66+ num_groups : int ,
67+ rng : np .random .Generator ,
68+ dtype : np .typing .DTypeLike
69+ ) -> list [np .ndarray ]:
70+ indices = rng .permutation (num_examples ).astype (dtype )
71+ group_size = num_examples // num_groups
72+ groups = np .array_split (indices , num_groups )
73+ return [g [:group_size ] for g in groups ]
74+
75+
4376def split_and_pad_global_batch (
4477 indices : np .ndarray , minibatch_size : int , microbatch_size : int | None = None
4578) -> list [np .ndarray ]:
@@ -140,18 +173,18 @@ class CyclicPoissonSampling(BatchSelectionStrategy):
140173 >>> rng = np.random.default_rng(0)
141174 >>> b = CyclicPoissonSampling(sampling_prob=1, iterations=8, cycle_length=4)
142175 >>> print(*b.batch_iterator(12, rng=rng), sep=' ')
143- [0 1 2 ] [3 4 5 ] [6 7 8 ] [ 9 10 11 ] [0 1 2 ] [3 4 5 ] [6 7 8 ] [ 9 10 11 ]
176+ [9 2 7 ] [ 4 5 11 ] [0 3 6 ] [10 8 1 ] [9 2 7 ] [ 4 5 11 ] [0 3 6 ] [10 8 1 ]
144177
145178 Example Usage (standard Poisson sampling) [2]:
146179 >>> b = CyclicPoissonSampling(sampling_prob=0.25, iterations=8)
147180 >>> print(*b.batch_iterator(12, rng=rng), sep=' ')
148- [0 4 9 3 5 ] [] [5] [4 6 2 7 ] [ 5 11] [ 2 5 8 6 9 11 ] [9 1 ] [7 5 4 3 ]
181+ [5 6 7 ] [5 8 3 7 2 ] [ 1 5 11] [0 3] [ 5 1 3 4 10 ] [2 ] [4 5 1 3] [6 ]
149182
150183 Example Usage (BandMF-style sampling) [3]:
151184 >>> p = 0.5
152- >>> b = CyclicPoissonSampling(sampling_prob=p, iterations=8 , cycle_length=2)
185+ >>> b = CyclicPoissonSampling(sampling_prob=p, iterations=6 , cycle_length=2)
153186 >>> print(*b.batch_iterator(12, rng=rng), sep=' ')
154- [2 4 0 ] [ 9 10 11 ] [3 5 ] [9 8 ] [0 3 4 5] [8] [2 1 4 5 ] [11 ]
187+ [2 4] [1 8 9 ] [2 7 5 4 ] [11 1 3 ] [10 2 5 0 4 ] [ 1 11 6 ]
155188
156189
157190 References:
@@ -186,34 +219,32 @@ class CyclicPoissonSampling(BatchSelectionStrategy):
186219 examples into cycle_length groups, and do Poisson sampling from the groups
187220 in a round-robin fashion. cycle_length == 1 retrieves standard Poisson
188221 sampling.
189- shuffle: For cyclic Poisson sampling, whether to shuffle the examples before
190- discarding (see even_partition) and partitioning.
191- even_partition: If True, we discard num_examples % cycle_length examples
192- before partitioning in cyclic Poisson sampling. If False, we can have
193- uneven partitions. Defaults to True for ease of analysis.
222+ partition_type: How to partition the examples into groups for before Poisson
223+ sampling. EQUAL_SPLIT is the default, and is only compatible with zero-out
224+ and replace-one adjacency notions, while INDEPENDENT is compatible
225+ with the add-remove adjacency notion.
194226 """
195227
196228 sampling_prob : float
197229 iterations : int
198230 truncated_batch_size : int | None = None
199231 cycle_length : int = 1
200- shuffle : bool = False
201- even_partition : bool = True
232+ partition_type : PartitionType = PartitionType .EQUAL_SPLIT
202233
203234 def batch_iterator (
204235 self , num_examples : int , rng : RngType = None
205236 ) -> Iterator [np .ndarray ]:
206237 rng = np .random .default_rng (rng )
207238 dtype = np .min_scalar_type (- num_examples )
208239
209- indices = np . arange ( num_examples , dtype = dtype )
210- if self . shuffle :
211- rng . shuffle ( indices )
212- if self . even_partition :
213- group_size = num_examples // self . cycle_length
214- indices = indices [: group_size * self .cycle_length ]
240+ if self . partition_type == PartitionType . INDEPENDENT :
241+ partition_fn = independent_partition
242+ elif self . partition_type == PartitionType . EQUAL_SPLIT :
243+ partition_fn = _equal_split_partition
244+ else :
245+ raise ValueError ( f'Unsupported partition type: { self .partition_type } ' )
215246
216- partition = np . array_split ( indices , self .cycle_length )
247+ partition = partition_fn ( num_examples , self .cycle_length , rng , dtype )
217248
218249 for i in range (self .iterations ):
219250 current_group = partition [i % self .cycle_length ]
@@ -254,20 +285,10 @@ def batch_iterator(
254285 ) -> Iterator [np .ndarray ]:
255286 rng = np .random .default_rng (rng )
256287 dtype = np .min_scalar_type (- num_examples )
257- indices = np .arange (num_examples , dtype = dtype )
258- rng .shuffle (indices )
259-
260- bin_sizes = rng .multinomial (
261- n = num_examples ,
262- pvals = np .ones (self .cycle_length ) / self .cycle_length ,
263- )
264- # Pad bin_sizes so that cumsum's output starts with 0.
265- batch_cutoffs = np .cumsum (np .append (0 , bin_sizes ))
288+ groups = independent_partition (num_examples , self .cycle_length , rng , dtype )
266289
267290 for i in range (self .iterations ):
268- start_index = batch_cutoffs [i % self .cycle_length ]
269- end_index = batch_cutoffs [(i % self .cycle_length ) + 1 ]
270- yield indices [start_index :end_index ]
291+ yield groups [i % self .cycle_length ]
271292
272293
273294@dataclasses .dataclass (frozen = True )
@@ -284,18 +305,19 @@ class UserSelectionStrategy:
284305 users.
285306
286307 Example Usage:
308+ >>> rng = np.random.default_rng(0)
287309 >>> base_strategy = CyclicPoissonSampling(sampling_prob=1, iterations=5)
288310 >>> strategy = UserSelectionStrategy(base_strategy, 2)
289311 >>> user_ids = np.array([0,0,0,1,1,2])
290- >>> iterator = strategy.batch_iterator(user_ids)
312+ >>> iterator = strategy.batch_iterator(user_ids, rng )
291313 >>> print(next(iterator))
292- [[0 1 ]
293- [3 4 ]
294- [5 5 ]]
314+ [[5 5 ]
315+ [0 1 ]
316+ [3 4 ]]
295317 >>> print(next(iterator))
296- [[2 0 ]
297- [3 4 ]
298- [5 5 ]]
318+ [[5 5 ]
319+ [2 0 ]
320+ [3 4 ]]
299321
300322 Attributes:
301323 base_strategy: The base batch selection strategy to apply at the user level.
0 commit comments