Skip to content

Commit 8e5724c

Browse files
committed
fix: Optimize Batch Selection and Update JAX Compatibility
1 parent 704a292 commit 8e5724c

File tree

5 files changed

+40
-17
lines changed

5 files changed

+40
-17
lines changed

examples/distributed_noise_generation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from absl import flags
3737
import jax
3838
import jax.numpy as jnp
39+
# pylint: disable=g-importing-member
40+
from jax.sharding import reshard
3941
from jax_privacy import noise_addition
4042
from jax_privacy.matrix_factorization import toeplitz
4143

@@ -162,10 +164,10 @@ def run(pytree_like_model_params):
162164
t0 = time.time()
163165
compiled_run = run.lower(model_params).compile()
164166
t1 = time.time()
165-
print('[BandMF] Compilation time: %.3f seconds' % (t1 - t0))
167+
print(f'[BandMF] Compilation time: {t1 - t0:.3f} seconds')
166168
state, noisy_grad = jax.block_until_ready(compiled_run(model_params))
167169
t2 = time.time()
168-
print('[BandMF] Per-step run time: %.3f seconds' % ((t2 - t1) / steps))
170+
print(f'[BandMF] Per-step run time: {(t2 - t1) / steps:.3f} seconds')
169171

170172
return state, noisy_grad
171173

jax_privacy/batch_selection.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,9 @@ def batch_iterator(
293293

294294
@dataclasses.dataclass(frozen=True)
295295
class UserSelectionStrategy:
296-
"""A strategy that applies a base_strategy at the user level.
296+
"""Applies base_strategy at the user level, and selects multiple examples
297+
298+
per user.
297299
298300
Each batch returned by the batch_iterator is a 2D array of integer indices,
299301
where all entries in the same row are examples owned by the same user. The
@@ -354,12 +356,19 @@ def batch_iterator(
354356
num_examples = user_ids.size
355357
dtype = np.min_scalar_type(-num_examples)
356358

359+
# Precompute sorted indices and starts to avoid O(n) per user_id
360+
# in np.where.
361+
sorted_indices = np.argsort(inverse)
362+
counts = np.bincount(inverse, minlength=num_users)
363+
starts = np.r_[0, np.cumsum(counts)]
364+
357365
def create_user_generator(user_id):
358-
# TODO: b/415360727 - this where is suboptimal, as it is O(n) per user_id.
359-
owned_examples = np.where(inverse == user_id)[0].astype(dtype)
366+
start = starts[user_id]
367+
end = starts[user_id + 1]
368+
owned_examples = sorted_indices[start:end].astype(dtype)
360369
if self.shuffle_per_user:
361370
rng.shuffle(owned_examples)
362-
return itertools.cycle(list(owned_examples))
371+
return itertools.cycle(owned_examples)
363372

364373
user_generators = [create_user_generator(i) for i in range(num_users)]
365374

jax_privacy/sharding_utils.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def _ceiling_to_multiple(size: int, multiple: int) -> int:
4343
return size + multiple - remainder if remainder != 0 else size
4444

4545

46-
def flatten_with_zero_redundancy(abstract_array: jax.Array) -> jax.Array:
46+
def flatten_with_zero_redundancy(
47+
abstract_array: jax.ShapeDtypeStruct | jax.Array
48+
) -> jax.ShapeDtypeStruct:
4749
"""Return a flattened, padded, and ZeRo-sharded abstract version of x.
4850
4951
Specifically, the returned object will describe a 1D array that is
@@ -59,15 +61,11 @@ def flatten_with_zero_redundancy(abstract_array: jax.Array) -> jax.Array:
5961
A zero-redundancy abstract flattened+padded version of the input value.
6062
"""
6163
mesh = jax.typeof(abstract_array).sharding.mesh
62-
# As of JAX 0.7.0, jnp.*_like will not preserve sharding of ShapeDtypeStruct
63-
# defined w.r.t. AbstractMeshes, so we return a concrete array here.
64-
# Under JIT, this should get optimized away.
65-
# TODO: b/415360727 - Version bump to 0.7.1, swap in jax.ShapeDtypeStruct,
66-
# and add type annotations to this function.
67-
return jax.numpy.empty(
68-
_ceiling_to_multiple(abstract_array.size, mesh.size),
64+
# As of JAX 0.7.1, we can use ShapeDtypeStruct with sharding preserved.
65+
return jax.ShapeDtypeStruct(
66+
shape=(_ceiling_to_multiple(abstract_array.size, mesh.size),),
6967
dtype=abstract_array.dtype,
70-
out_sharding=jax.sharding.NamedSharding(mesh, jax.P(mesh.axis_names)),
68+
sharding=jax.sharding.NamedSharding(mesh, jax.P(mesh.axis_names)),
7169
)
7270

7371

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ requires-python = ">=3.11"
1414
dependencies = [
1515
"absl-py",
1616
"dp_accounting @ git+https://github.com/google/differential-privacy.git#subdirectory=python/dp_accounting",
17-
"jax>=0.7.0",
18-
"jaxlib>=0.7.0",
17+
"jax>=0.7.1",
18+
"jaxlib>=0.7.1",
1919
"pydantic",
2020
"numpy",
2121
"optax",

tests/sharding_utils_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@ def test_ceiling_to_multiple(self):
8080
self.assertEqual(sharding_utils._ceiling_to_multiple(4, 4), 4)
8181
self.assertEqual(sharding_utils._ceiling_to_multiple(5, 4), 8)
8282

83+
def test_flatten_zeros_like_preserves_metadata(self):
84+
sharding = jax.sharding.NamedSharding(
85+
self.mesh, jax.sharding.PartitionSpec(None, 'y')
86+
)
87+
x = jax.device_put(jnp.ones((2, 6), dtype=jnp.float32), sharding)
88+
flattened = sharding_utils.flatten_with_zero_redundancy(x)
89+
zeros = jnp.zeros_like(flattened)
90+
self.assertEqual(zeros.shape, flattened.shape)
91+
self.assertEqual(zeros.dtype, flattened.dtype)
92+
self.assertEqual(zeros.sharding.spec, flattened.sharding.spec)
93+
self.assertEqual(
94+
zeros.sharding.mesh.axis_names, flattened.sharding.mesh.axis_names
95+
)
96+
8397

8498
if __name__ == '__main__':
8599
absltest.main()

0 commit comments

Comments
 (0)