@@ -291,6 +291,55 @@ def batch_iterator(
291291 yield groups [i % self .cycle_length ]
292292
293293
294+ @dataclasses .dataclass (frozen = True )
295+ class MinimumSeparationSampling (BatchSelectionStrategy ):
296+ """Implements minimum separation sampling.
297+
298+ This batch selection strategy simulates constraints common in federated
299+ learning scenarios.
300+
301+ Formal Guarantees of the batch_iterator:
302+ - All batches consist of indices in the range [0, num_examples).
303+ - The number of examples in each batch is at most max_batch_size..
304+ - The separation between two appearances of any example is at least min_sep.
305+ - Each example appears at most max_part times.
306+
307+ Attributes:
308+ iterations: The number of total batches to generate.
309+ max_batch_size: The maximum number of examples in each batch.
310+ min_sep: The minimum separation between two appearances of any example.
311+ max_part: The maximum number of times any example can appear.
312+ """
313+ iterations : int
314+ max_batch_size : int
315+ min_sep : int
316+ max_part : int | None = None
317+
318+ def batch_iterator (
319+ self , num_examples : int , rng : RngType = None
320+ ) -> Iterator [np .ndarray ]:
321+ rng = np .random .default_rng (rng )
322+ dtype = np .min_scalar_type (- num_examples )
323+ dtype2 = np .min_scalar_type (- self .iterations )
324+ dtype3 = np .min_scalar_type (self .max_part )
325+
326+ all_indices = np .arange (num_examples , dtype = dtype )
327+ last_seen = np .full (num_examples , - self .min_sep , dtype = dtype2 )
328+ counts = np .zeros (num_examples , dtype = dtype3 )
329+
330+ for step in range (self .iterations ):
331+ valid_mask = last_seen <= step - self .min_sep
332+ if self .max_part is not None :
333+ valid_mask &= (counts < self .max_part )
334+
335+ candidates = all_indices [valid_mask ]
336+ current_batch_size = min (self .max_batch_size , candidates .size )
337+ batch = rng .choice (candidates , size = current_batch_size , replace = False )
338+ last_seen [batch ] = step
339+ counts [batch ] += 1
340+ yield batch
341+
342+
294343@dataclasses .dataclass (frozen = True )
295344class UserSelectionStrategy :
296345 """Applies base_strategy at the user level, and selects multiple examples per user.
0 commit comments