Skip to content

Commit c501354

Browse files
Refactor: Implement batch padding for JIT optimization, migrate to plan.clipped_grad, and patch batch_selection for multi-dim support
1 parent 54b274e commit c501354

3 files changed

Lines changed: 70 additions & 59 deletions

File tree

examples/dp_sgd_transformer_nnx.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import jax
3939
import jax.extend.backend
4040
import jax.numpy as jnp
41-
from jax_privacy import clipped_grad
41+
from jax_privacy import batch_selection
4242
from jax_privacy.experimental import execution_plan
4343
import numpy as np
4444
import optax
@@ -56,6 +56,7 @@
5656
NOISE_MULTIPLIER = 1.0
5757
EPSILON = 10.0
5858
DELTA = 1e-6
59+
PADDING_MULTIPLE = 8
5960

6061

6162
# Data loading and preparation
@@ -295,15 +296,6 @@ def main(argv: list[str]) -> None:
295296

296297
opt_state = optimizer.init(params)
297298

298-
# Configure DP Gradient Clipping
299-
grad_fn = clipped_grad(
300-
functools.partial(pure_loss_fn, graphdef=graphdef, other=other),
301-
l2_clip_norm=CLIP_NORM,
302-
batch_argnums=(1, 2), # x and y are batched
303-
prng_argnum=3, # Explicitly vmap the PRNG key per batch example
304-
return_values=True, # Return loss values for logging
305-
)
306-
307299
# Execution Plan Configuration
308300
dataset_size = len(data)
309301
config = execution_plan.BandMFExecutionPlanConfig.default(
@@ -318,13 +310,22 @@ def main(argv: list[str]) -> None:
318310
privatizer = plan.noise_addition_transform
319311
noise_state = privatizer.init(params)
320312

313+
# Configure DP Gradient Clipping
314+
grad_fn = plan.clipped_grad(
315+
functools.partial(pure_loss_fn, graphdef=graphdef, other=other),
316+
batch_argnums=(1, 2), # x and y are batched
317+
prng_argnum=3, # Explicitly vmap the PRNG key per batch example
318+
return_values=True, # Return loss values for logging
319+
)
320+
321321
@jax.jit(donate_argnums=(0, 1, 4))
322322
def train_step(
323323
params: nnx.State,
324324
opt_state: optax.OptState,
325325
batch: Tuple[jax.Array, jax.Array],
326326
prng_key: jax.Array,
327327
noise_state: Any,
328+
is_padding_example: jax.Array,
328329
) -> Tuple[nnx.State, optax.OptState, Any, jax.Array]:
329330
"""Performs a single training step with DP-SGD.
330331
@@ -334,20 +335,19 @@ def train_step(
334335
batch: A tuple (x, y) of input and target data.
335336
prng_key: A pseudorandom number generator key.
336337
noise_state: Current state of the noise mechanism.
338+
is_padding_example: Boolean mask indicating padding rows.
337339
338340
Returns:
339341
Updated params, opt_state, noise_state, and the mean loss for the batch.
340342
"""
343+
print(f"DEBUG: Compiling train_step for batch size {batch[0].shape[0]}")
341344
x, y = batch
342345

343-
# Handle zero-sized batch explicitly to avoid tracing crash in optax
344-
if x.shape[0] == 0:
345-
grads = jax.tree_util.tree_map(jnp.zeros_like, params)
346-
mean_loss = jnp.array(0.0)
347-
else:
348-
# Compute clipped gradients and per-example loss values
349-
grads, loss = grad_fn(params, x, y, prng_key)
350-
mean_loss = loss.values.mean()
346+
# Compute clipped gradients and per-example loss values
347+
grads, loss = grad_fn(
348+
params, x, y, prng_key, is_padding_example=is_padding_example
349+
)
350+
mean_loss = loss.values.mean()
351351

352352
assert all(
353353
g.shape == p.shape
@@ -371,22 +371,17 @@ def train_step(
371371
iterator = plan.batch_selection_strategy.batch_iterator(dataset_size)
372372
prng_key = jax.random.key(42)
373373
for step, batch_indices in enumerate(iterator):
374-
if step >= NUM_STEPS:
375-
break
376-
377-
# Construct batch from indices
378-
if len(batch_indices) == 0:
379-
x = np.zeros((0, CONTEXT_LENGTH), dtype=np.int32)
380-
y = np.zeros((0, CONTEXT_LENGTH), dtype=np.int32)
381-
else:
382-
batch_seqs = data[batch_indices]
383-
x = batch_seqs[:, :-1]
384-
y = batch_seqs[:, 1:]
374+
idx = batch_selection.pad_to_multiple_of(batch_indices, PADDING_MULTIPLE)
375+
is_padding_example = idx == -1
376+
safe_idx = np.where(idx == -1, 0, idx)
377+
batch_seqs = data[safe_idx]
378+
x = batch_seqs[:, :-1]
379+
y = batch_seqs[:, 1:]
385380
batch = (x, y)
386381

387382
prng_key, subkey = jax.random.split(prng_key)
388383
params, opt_state, noise_state, loss = train_step(
389-
params, opt_state, batch, subkey, noise_state
384+
params, opt_state, batch, subkey, noise_state, is_padding_example
390385
)
391386

392387
print(f"Step {step + 1}/{NUM_STEPS}, Loss: {loss:.4f}")

examples/user_level_transformer_example.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
import jax.numpy as jnp
3737
import numpy as np
3838
import optax
39+
from jax_privacy import batch_selection
3940
from jax_privacy.batch_selection import UserSelectionStrategy
40-
from jax_privacy.clipping import clipped_grad
4141
from jax_privacy.experimental import execution_plan
4242

4343

@@ -49,6 +49,7 @@
4949
LEARNING_RATE = 1e-3
5050
EPSILON = 10.0
5151
DELTA = 1e-6
52+
PADDING_MULTIPLE = 8
5253

5354

5455
class TransformerDecoder(nn.Module):
@@ -142,25 +143,26 @@ def main(argv: list[str]) -> None:
142143
# We need the grad_fn first.
143144

144145
# 3. Training Step & Clipping
145-
def loss_fn(params, batch_data, batch_labels):
146-
logits = model.apply({'params': params}, batch_data, train=True)
147-
one_hot_labels = jax.nn.one_hot(batch_labels, num_classes=vocab_size)
146+
def loss_fn(params, x, y, prng_key=None):
147+
del prng_key # Unused
148+
logits = model.apply({'params': params}, x, train=True)
149+
one_hot_labels = jax.nn.one_hot(y, num_classes=vocab_size)
148150
return jnp.mean(
149151
optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
150152
)
151153

152-
grad_fn = clipped_grad(
153-
loss_fn,
154-
l2_clip_norm=L2_CLIP_NORM,
155-
batch_argnums=(1, 2),
156-
keep_batch_dim=False,
157-
)
158-
159154
# Create Plan
160155
plan = config.plan
161156
privatizer = plan.noise_addition_transform
162157
noise_state = privatizer.init(params)
163158

159+
grad_fn = plan.clipped_grad(
160+
loss_fn,
161+
batch_argnums=(1, 2),
162+
prng_argnum=3,
163+
return_values=True,
164+
)
165+
164166
# Wrap the plan's strategy with UserSelectionStrategy
165167
# We assume plan.batch_selection_strategy is compatible
166168
# (CyclicPoissonSampling)
@@ -170,32 +172,37 @@ def loss_fn(params, batch_data, batch_labels):
170172
)
171173

172174
@jax.jit(donate_argnums=(0, 1, 4))
173-
def train_step(params, opt_state, batch_data, batch_labels, noise_state):
174-
grads = grad_fn(params, batch_data, batch_labels)
175+
def train_step(
176+
params, opt_state, x, y, noise_state, prng_key, is_padding_example
177+
):
178+
print(f'DEBUG: Compiling train_step for batch size {x.shape[0]}')
179+
grads, loss = grad_fn(
180+
params, x, y, prng_key, is_padding_example=is_padding_example
181+
)
175182

176183
# Add Privacy Noise (Using plan's privatizer)
177184
noisy_grads, noise_state = privatizer.update(grads, noise_state)
178185

179186
updates, opt_state = optimizer.update(noisy_grads, opt_state, params)
180187
params = optax.apply_updates(params, updates)
181-
return params, opt_state, noise_state
188+
return params, opt_state, noise_state, loss.values.mean()
182189

183190
# 4. Training Loop
184191
start_time = time.time()
185192
batch_iterator = user_strategy.batch_iterator(user_ids, rng=0)
193+
prng_key = jax.random.key(42)
186194
for step, user_batch_indices in enumerate(batch_iterator):
187-
if user_batch_indices.size == 0:
188-
print(f'Step {step}: Skipping empty batch.')
189-
continue
190-
191-
batch_data = data[user_batch_indices]
192-
batch_labels = labels[user_batch_indices]
193-
194-
# Calculate and print loss
195-
loss_val = loss_fn(params, batch_data, batch_labels)
196-
197-
params, opt_state, noise_state = train_step(
198-
params, opt_state, batch_data, batch_labels, noise_state
195+
idx = batch_selection.pad_to_multiple_of(
196+
user_batch_indices, PADDING_MULTIPLE
197+
)
198+
is_padding_example = idx[:, 0] == -1
199+
safe_idx = np.where(idx == -1, 0, idx)
200+
x = data[safe_idx]
201+
y = labels[safe_idx]
202+
203+
prng_key, subkey = jax.random.split(prng_key)
204+
params, opt_state, noise_state, loss_val = train_step(
205+
params, opt_state, x, y, noise_state, subkey, is_padding_example
199206
)
200207
print(f'Step {step}: Loss: {loss_val:.4f}')
201208

jax_privacy/batch_selection.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,22 @@ def pad_to_multiple_of(indices: np.ndarray, multiple: int) -> np.ndarray:
157157
Returns:
158158
A new 1D array of indices padded with -1.
159159
"""
160-
if indices.ndim > 1:
161-
raise ValueError('pad_to_multiple_of currently expects 1D indices.')
162160
if multiple <= 0:
163161
raise ValueError(f'Padding multiple must be positive, got {multiple}.')
164162
curr_size = indices.shape[0]
165163
pad_size = (multiple - curr_size) % multiple
166-
new_indices = np.full(curr_size + pad_size, -1, dtype=indices.dtype)
164+
if pad_size == 0:
165+
return indices
166+
167+
pad_shape = (pad_size,) + indices.shape[1:]
168+
if indices.ndim == 1:
169+
new_indices = np.full(curr_size + pad_size, -1, dtype=indices.dtype)
170+
else:
171+
new_indices = np.full(pad_shape, -1, dtype=indices.dtype)
172+
173+
if indices.ndim > 1:
174+
return np.concatenate([indices, new_indices], axis=0)
175+
167176
new_indices[:curr_size] = indices
168177
return new_indices
169178

0 commit comments

Comments
 (0)