2525import dp_accounting
2626import jax
2727import jax .numpy as jnp
28- from jax_privacy .experimental import microbatching
2928import optax
29+ from optax .experimental import microbatching
3030
3131
3232PyTree : TypeAlias = chex .ArrayTree
@@ -198,6 +198,36 @@ def _normalize_fun_to_return_aux(fun, has_aux):
198198 return lambda * args , ** kwargs : (fun (* args , ** kwargs ), ())
199199
200200
201+ def _num_real_microbatches (
202+ is_padding_example : jax .Array ,
203+ microbatch_size : int | None ,
204+ ) -> int | jax .Array :
205+ """Calculates the number of non-padding microbatches.
206+
207+ The returned result is 1 + the index of the last microbatch that contains at
208+ least one non-padding example. This means that microbatches consisting of
209+ all-padding examples that do not appear at the end will be treated as a real
210+ microbatch.
211+
212+ Args:
213+ is_padding_example: A 1D array of shape (num_examples,).
214+ microbatch_size: Argument passed to `microbatch`.
215+
216+ Returns:
217+ The `true` batch size, as a scalar jax array.
218+ """
219+ if microbatch_size is None :
220+ return is_padding_example .shape [0 ]
221+ reshaped = microbatching .reshape_batch_axis (
222+ is_padding_example , microbatch_size
223+ )
224+ # Ensure there is at least one True in the array.
225+ is_real_batch = jnp .append (True , ~ reshaped .all (axis = 1 ))
226+ # We want the last real microbatch, argmax returns the first True value,
227+ # so we add increasing numbers from 0 to 1 to each index.
228+ return jnp .argmax (is_real_batch + jnp .linspace (0 , 1 , is_real_batch .size ))
229+
230+
201231def clipped_fun (
202232 fun : Callable ,
203233 has_aux : bool = False ,
@@ -294,7 +324,8 @@ def clipped_fn(*args, **kwargs):
294324 is_padding_example = kwargs .get ('is_padding_example' , None )
295325 batch_size = jax .tree .leaves (args [batch_argnums [0 ]])[0 ].shape [0 ]
296326 if is_padding_example is None :
297- kwargs ['is_padding_example' ] = jnp .zeros (batch_size , dtype = jnp .bool_ )
327+ is_padding_example = jnp .zeros (batch_size , dtype = jnp .bool_ )
328+ kwargs ['is_padding_example' ] = is_padding_example
298329
299330 def clipped_fun_one_group (* args , is_padding_example , ** kwargs ):
300331 value , aux = fun (* args , ** kwargs )
@@ -308,6 +339,7 @@ def clipped_fun_one_group(*args, is_padding_example, **kwargs):
308339 )
309340 return clipped_value , aux , l2_norm
310341
342+ num_real_mb = _num_real_microbatches (is_padding_example , microbatch_size )
311343 sum_ = microbatching .AccumulationType .SUM
312344 concat = microbatching .AccumulationType .CONCAT
313345 axes = [0 if i in batch_argnums else None for i in range (len (args ))]
@@ -322,9 +354,11 @@ def clipped_fun_one_group(*args, is_padding_example, **kwargs):
322354 batch_argnums_with_prng = batch_argnums
323355 microbatched_vmap_fun = microbatching .microbatch (
324356 jax .vmap (clipped_fun_one_group , axes , spmd_axis_name = spmd_axis_name ),
325- batch_argnums = batch_argnums_with_prng ,
357+ argnums = batch_argnums_with_prng ,
358+ argnames = 'is_padding_example' ,
326359 microbatch_size = microbatch_size ,
327- accumulation_type = (sum_ , concat , concat ),
360+ accumulator = (sum_ , concat , concat ),
361+ num_real_microbatches = num_real_mb
328362 )
329363
330364 clipped_values , aux , norms = microbatched_vmap_fun (* args , ** kwargs )
0 commit comments