
Commit b137c9b

Ryan McKenna authored and copybara-github committed
Remove experimental/microbatching.py in favor of optax.experimental.microbatching.
PiperOrigin-RevId: 859625717
1 parent a6b6bdf · commit b137c9b

8 files changed: +130 −634 lines

docs/core_library.rst

Lines changed: 0 additions & 9 deletions
@@ -25,12 +25,3 @@ Experimental Modules
 
   experimental.execution_plan
   experimental.compilation_utils
-
-
-Other References
-----------------
-.. autosummary::
-  :toctree: _autosummary_output
-  :nosignatures:
-
-  experimental.microbatching

jax_privacy/batch_selection.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@
 import itertools
 from typing import Iterator
 
-from jax_privacy.experimental import microbatching
+from jax_privacy import sharding_utils
 import numpy as np
 
 
@@ -109,7 +109,7 @@ def split_and_pad_global_batch(
   minibatch_shape = (minibatch_size,) + indices.shape[1:]
   last_minibatch = np.full(minibatch_shape, -1, dtype=indices.dtype)
   last_minibatch[: minibatches[-1].shape[0]] = minibatches[-1]
-  permutation = microbatching.compute_early_stopping_order(
+  permutation = sharding_utils.compute_early_stopping_order(
       minibatch_size, microbatch_size
   )
   minibatches[-1] = last_minibatch[permutation]
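
The padding-and-permute step above, in isolation: the short final minibatch is padded with -1 sentinels, then reordered. Below is a minimal numpy sketch with hypothetical values; the permutation shown is a stand-in, since the real ordering comes from sharding_utils.compute_early_stopping_order and depends on how the minibatch is later split into microbatches. Presumably the real ordering places all-padding microbatches at the tail so downstream microbatched work can stop early.

import numpy as np

minibatch_size = 4
last = np.array([7, 8, 9])            # final minibatch, one example short
padded = np.full(minibatch_size, -1, dtype=last.dtype)
padded[: last.shape[0]] = last        # -> [ 7,  8,  9, -1]
permutation = np.array([0, 2, 1, 3])  # stand-in; not the real ordering
print(padded[permutation])            # -> [ 7,  9,  8, -1]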

jax_privacy/clipping.py

Lines changed: 38 additions & 4 deletions
@@ -25,8 +25,8 @@
 import dp_accounting
 import jax
 import jax.numpy as jnp
-from jax_privacy.experimental import microbatching
 import optax
+from optax.experimental import microbatching
 
 
 PyTree: TypeAlias = chex.ArrayTree
@@ -198,6 +198,36 @@ def _normalize_fun_to_return_aux(fun, has_aux):
   return lambda *args, **kwargs: (fun(*args, **kwargs), ())
 
 
+def _num_real_microbatches(
+    is_padding_example: jax.Array,
+    microbatch_size: int | None,
+) -> int | jax.Array:
+  """Calculates the number of non-padding microbatches.
+
+  The returned result is 1 + the index of the last microbatch that contains
+  at least one non-padding example. This means that an all-padding
+  microbatch that does not appear at the end is still treated as a real
+  microbatch.
+
+  Args:
+    is_padding_example: A 1D boolean array of shape (num_examples,).
+    microbatch_size: Argument passed to `microbatch`.
+
+  Returns:
+    The `true` batch size, as a scalar jax array.
+  """
+  if microbatch_size is None:
+    return is_padding_example.shape[0]
+  reshaped = microbatching.reshape_batch_axis(
+      is_padding_example, microbatch_size
+  )
+  # Ensure there is at least one True in the array.
+  is_real_batch = jnp.append(True, ~reshaped.all(axis=1))
+  # We want the last real microbatch; argmax returns the first True value,
+  # so we add an increasing ramp from 0 to 1 to make the last True largest.
+  return jnp.argmax(is_real_batch + jnp.linspace(0, 1, is_real_batch.size))
+
+
 def clipped_fun(
     fun: Callable,
     has_aux: bool = False,
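
To see the sentinel-plus-ramp trick concretely, here is a worked example in plain jnp with hypothetical values. It inlines the reshape that microbatching.reshape_batch_axis is assumed to perform, i.e. splitting the batch axis into (num_microbatches, microbatch_size):

import jax.numpy as jnp

# 8 examples in 4 microbatches of 2; only the first 3 examples are real.
is_padding_example = jnp.array(
    [False, False, False, True, True, True, True, True]
)
reshaped = is_padding_example.reshape(4, 2)  # one row per microbatch
is_real_batch = jnp.append(True, ~reshaped.all(axis=1))
# -> [True, True, True, False, False]; the leading True is a sentinel, so
# entry i corresponds to microbatch i - 1.
ramp = jnp.linspace(0, 1, is_real_batch.size)  # [0, .25, .5, .75, 1]
# True entries score in [1, 2], False entries in [0, 1], and later entries
# score higher, so argmax picks the *last* True: 1 + last real index.
print(jnp.argmax(is_real_batch + ramp))  # 2: microbatches 0 and 1 are real

With the sentinel in front, the result is exactly 1 + the index of the last microbatch containing a real example, or 0 if every microbatch is padding.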
@@ -294,7 +324,8 @@ def clipped_fn(*args, **kwargs):
     is_padding_example = kwargs.get('is_padding_example', None)
     batch_size = jax.tree.leaves(args[batch_argnums[0]])[0].shape[0]
     if is_padding_example is None:
-      kwargs['is_padding_example'] = jnp.zeros(batch_size, dtype=jnp.bool_)
+      is_padding_example = jnp.zeros(batch_size, dtype=jnp.bool_)
+    kwargs['is_padding_example'] = is_padding_example
 
     def clipped_fun_one_group(*args, is_padding_example, **kwargs):
       value, aux = fun(*args, **kwargs)
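
This small reshuffle is load-bearing: previously only kwargs was updated, so when the caller omitted the mask, the local is_padding_example stayed None and could not feed the new _num_real_microbatches call added below. A schematic of the fixed flow (batch_size is hypothetical):

import jax.numpy as jnp

batch_size = 4
is_padding_example = None  # caller omitted the mask
if is_padding_example is None:
  is_padding_example = jnp.zeros(batch_size, dtype=jnp.bool_)
kwargs = {'is_padding_example': is_padding_example}
# The local is now always a boolean array, safe to pass onward.
print(is_padding_example)  # [False False False False]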
@@ -308,6 +339,7 @@ def clipped_fun_one_group(*args, is_padding_example, **kwargs):
       )
       return clipped_value, aux, l2_norm
 
+    num_real_mb = _num_real_microbatches(is_padding_example, microbatch_size)
     sum_ = microbatching.AccumulationType.SUM
     concat = microbatching.AccumulationType.CONCAT
     axes = [0 if i in batch_argnums else None for i in range(len(args))]
@@ -322,9 +354,11 @@
     batch_argnums_with_prng = batch_argnums
     microbatched_vmap_fun = microbatching.microbatch(
         jax.vmap(clipped_fun_one_group, axes, spmd_axis_name=spmd_axis_name),
-        batch_argnums=batch_argnums_with_prng,
+        argnums=batch_argnums_with_prng,
+        argnames='is_padding_example',
         microbatch_size=microbatch_size,
-        accumulation_type=(sum_, concat, concat),
+        accumulator=(sum_, concat, concat),
+        num_real_microbatches=num_real_mb
     )
 
     clipped_values, aux, norms = microbatched_vmap_fun(*args, **kwargs)
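
Taken together, the renamed arguments say: vmap over the examples in each microbatch, SUM the clipped values across microbatches, CONCAT the aux outputs and norms, and stop once num_real_mb microbatches have been processed. The sketch below is a hand-rolled stand-in for these mechanics, not the optax implementation (whose signature is not shown beyond this call site); manual_microbatch, per_example, and the toy min-based "clipping" are all hypothetical.

import jax
import jax.numpy as jnp

def manual_microbatch(fn, batch, microbatch_size, num_real_microbatches):
  """Hand-rolled stand-in for SUM-accumulated microbatching (illustration)."""
  num_mb = batch.shape[0] // microbatch_size
  mbs = batch.reshape(num_mb, microbatch_size, *batch.shape[1:])
  # Early stopping: skip trailing all-padding microbatches. Here the count
  # is a static int; the real helper returns a traced scalar, so the actual
  # implementation must stop via a bounded loop rather than Python slicing.
  outs = [fn(mb) for mb in mbs[:num_real_microbatches]]
  return sum(outs)  # SUM accumulation; CONCAT would jnp.concatenate(outs)

per_example = jax.vmap(lambda x: jnp.minimum(x, 1.0))  # toy "clipping"
batch = jnp.array([0.5, 2.0, 3.0, 0.25, 0.0, 0.0])     # last mb is padding
print(manual_microbatch(lambda mb: per_example(mb).sum(), batch, 2, 2))
# -> 1.5 + 1.25 = 2.75; the padding microbatch never runs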
