From e7b5538ceb061317225f4d87546c349970841b2d Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Wed, 21 Jan 2026 11:04:42 +0530 Subject: [PATCH 1/9] Implemented Deplayed Preconditioners with alternating-phase protocol as per Paper 4 --- .../dp2_delayed_preconditioner_example.py | 191 ++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 examples/dp2_delayed_preconditioner_example.py diff --git a/examples/dp2_delayed_preconditioner_example.py b/examples/dp2_delayed_preconditioner_example.py new file mode 100644 index 0000000..a2c2a10 --- /dev/null +++ b/examples/dp2_delayed_preconditioner_example.py @@ -0,0 +1,191 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Example of a private adaptive optimizer with a Delayed Preconditioner (DP^2). + +This example is based on the paper: + Li et al. (2023), "Differentially Private Adaptive Optimization with Delayed + Preconditioners" (ICLR 2023). + +Implementation Notes: + This implementation demonstrates the "alternating-phase protocol" of DP^2. + The core idea is to introduce a trade-off between the staleness of the + preconditioner and noise reduction. The preconditioner (second-moment + estimate) is kept "stale" for a fixed number of steps (`delay_s`). During + this period, noised gradients are accumulated. + + When the delay period is over, the preconditioner is updated using the + *average* of the accumulated gradients. By averaging over multiple steps, + the signal from the gradients is amplified relative to the noise, leading + to a more stable and reliable preconditioner update. This helps to prevent + the noise from overwhelming the adaptive learning rate calculation, a common + problem in standard private adaptive optimizers. +""" + +import jax +import jax.numpy as jnp +import flax.linen as nn +from jax_privacy.clipping import clipped_grad +from jax_privacy import noise_addition +import optax +import numpy as np + +# 1. Model and Data Setup +class SimpleModel(nn.Module): + """A simple linear model.""" + @nn.compact + def __call__(self, x): + return nn.Dense(features=1)(x) + +# Dummy data +inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]]) +labels = jnp.array([[0.5], [1.5]]) + +model = SimpleModel() +key = jax.random.PRNGKey(42) +params = model.init(key, inputs)['params'] +epsilon_s1 = 1e-8 + +# 2. Optimizer and DP^2 State +optimizer = optax.sgd(learning_rate=0.1) +opt_state = optimizer.init(params) + +# DP^2 parameters +delay_s = 5 +preconditioner_v = jax.tree_util.tree_map(jnp.ones_like, params) +gradient_accumulator = jax.tree_util.tree_map(jnp.zeros_like, params) + +# Privatizer for noise addition +stddev = 1.0 * 0.1 # l2_clip_norm * noise_multiplier +privatizer = noise_addition.gaussian_privatizer( + stddev=stddev, + prng_key=jax.random.PRNGKey(0), +) +noise_state = privatizer.init(params) + +# 3. 
Core Logic: DP^2 Update Step +def loss_fn(p, x, y): + """MSE loss.""" + return jnp.mean((model.apply({'params': p}, x) - y)**2) + +def update_step( + step_idx, params, opt_state, preconditioner_v, gradient_accumulator, noise_state +): + """Performs one update step with delayed preconditioning.""" + + # Scale gradients using the (potentially stale) preconditioner + s_t = jax.tree_util.tree_map( + lambda var: 1 / (jnp.sqrt(var) + epsilon_s1), preconditioner_v + ) + def pre_clipping_transform(grads): + return jax.tree_util.tree_map(lambda g, s: g * s, grads, s_t) + + # Compute and clip gradients + priv_grad_fn = clipped_grad( + loss_fn, + l2_clip_norm=1.0, + batch_argnums=(1, 2), + pre_clipping_transform=pre_clipping_transform, + ) + clipped_grads = priv_grad_fn(params, inputs, labels) + + # Add noise + noised_grads, new_noise_state = privatizer.update( + clipped_grads, noise_state + ) + + # Rescale noised gradients back to original space + noised_grads_rescaled = jax.tree_util.tree_map( + lambda g, s: g / s, noised_grads, s_t + ) + + # Apply updates to parameters + updates, new_opt_state = optimizer.update(noised_grads_rescaled, opt_state, params) + new_params = optax.apply_updates(params, updates) + + # --- DP^2 Preconditioner Update Logic --- + # Accumulate the noised (but not rescaled) gradients + new_gradient_accumulator = jax.tree_util.tree_map( + lambda acc, g: acc + g, gradient_accumulator, noised_grads + ) + + def update_preconditioner(operand): + """Update preconditioner with averaged gradients and reset accumulator.""" + p_v, grad_acc = operand + # By delaying the update, we ensure that the second-moment estimate is + # derived from a higher signal-to-noise ratio average, preventing the + # noise from dominating the adaptive learning rates. + avg_grad = jax.tree_util.tree_map(lambda x: x / delay_s, grad_acc) + new_p_v = jax.tree_util.tree_map( + lambda v, g: 0.9 * v + 0.1 * jnp.square(g), p_v, avg_grad + ) + new_grad_acc = jax.tree_util.tree_map(jnp.zeros_like, grad_acc) + return new_p_v, new_grad_acc + + def keep_stale(operand): + """Keep preconditioner stale and continue accumulating.""" + return operand + + # "Snap" update: conditionally update the preconditioner + new_preconditioner_v, new_gradient_accumulator = jax.lax.cond( + (step_idx + 1) % delay_s == 0, + update_preconditioner, + keep_stale, + (preconditioner_v, new_gradient_accumulator), + ) + + return ( + new_params, + new_opt_state, + new_preconditioner_v, + new_gradient_accumulator, + new_noise_state, + ) + +# 4. Training Loop & Verification +v_initial = preconditioner_v +snapshots = {} +num_steps = 20 + +for i in range(num_steps): + ( + params, + opt_state, + preconditioner_v, + gradient_accumulator, + noise_state, + ) = update_step( + i, params, opt_state, preconditioner_v, gradient_accumulator, noise_state + ) + if i == 3: # Step 4 (0-indexed) + snapshots['v_step4'] = preconditioner_v + if i == 4: # Step 5 (0-indexed) + snapshots['v_step5'] = preconditioner_v + +# 5. Correctness Verification +def flatten_pytree(pt): + return np.concatenate([np.ravel(x) for x in jax.tree_util.tree_leaves(pt)]) + +v_initial_flat = flatten_pytree(v_initial) +v_step4_flat = flatten_pytree(snapshots['v_step4']) +v_step5_flat = flatten_pytree(snapshots['v_step5']) + +# Assert that the preconditioner was stale until the update step +np.testing.assert_allclose(v_initial_flat, v_step4_flat, rtol=1e-6) +assert not np.allclose(v_step4_flat, v_step5_flat), "Preconditioner did not update at step 5." 
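+
+# An additional sanity check on the "snap" schedule (a small illustrative
+# sketch, using only variables defined above): with delay_s = 5 and
+# num_steps = 20, the update fires whenever (i + 1) % delay_s == 0, i.e. at
+# i = 4, 9, 14 and 19, so the preconditioner changes exactly
+# num_steps // delay_s = 4 times over the run.
+expected_update_steps = [i for i in range(num_steps) if (i + 1) % delay_s == 0]
+assert expected_update_steps == [4, 9, 14, 19]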
+ +print( + "[Verification] DP2 Protocol confirmed: Preconditioner remained stale for" + " steps 1-4 and updated successfully at step 5." +) From a2acc2f6fbb578ab1675f450b24dbffc47078bb5 Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Wed, 21 Jan 2026 00:30:55 +0530 Subject: [PATCH 2/9] Implement User-Level Sampling (ULS) for Transformers with per-user averaging --- examples/user_level_transformer_example.py | 311 +++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 examples/user_level_transformer_example.py diff --git a/examples/user_level_transformer_example.py b/examples/user_level_transformer_example.py new file mode 100644 index 0000000..26221c0 --- /dev/null +++ b/examples/user_level_transformer_example.py @@ -0,0 +1,311 @@ +# Copyright 2026, The jax_privacy Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fine-Tuning a Transformer with User-Level Differential Privacy. + +Literature Reference: Charles et al. (2024), "Fine-Tuning Large Language Models +with User-Level Differential Privacy" (https://arxiv.org/abs/2407.07737) + +--- Implementation Notes --- +Standard DP-SGD clips every example. ULS (User-Level Sampling) is different: +it averages all gradients from a single user into a 'user-update' BEFORE clipping. +As noted in Sec 4.2 of the paper, this is much more efficient for LLMs because +it improves the signal-to-noise ratio by reducing the magnitude of the update +before we hit it with the sensitivity clip. + +We use a manual for-loop for the user-aggregation because users often have +varying amounts of data (i.e., unbalanced datasets). This avoids complex +padding and masking issues that can arise with JAX's vmap. 
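+
+A concrete way to see the sensitivity argument: if a user contributes k
+sequences with per-sequence gradients g_1, ..., g_k, the user update is
+mean(g_i), which is clipped to the norm bound user_clip_norm (call it C).
+That user's contribution to the summed batch gradient is therefore at most C
+no matter how large k is, and C is exactly the user-level sensitivity that
+the Gaussian noise (stddev = C * noise_multiplier) is calibrated against.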
+""" + +import jax +import jax.numpy as jnp +import numpy as np +from jax import random, tree_util +from jax.flatten_util import ravel_pytree +import optax +import tensorflow as tf + + +from jax_privacy.noise_addition import gaussian_privatizer + + +# Disable GPU warnings +tf.config.experimental.set_visible_devices([], "GPU") + + +class CharacterTokenizer: + """Simple character-level tokenizer.""" + + def __init__(self, text_data): + self.vocab = sorted(list(set(text_data))) + self.vocab_size = len(self.vocab) + self.char_to_id = {c: i for i, c in enumerate(self.vocab)} + self.id_to_char = {i: c for i, c in enumerate(self.vocab)} + + def encode(self, text): + return np.array([self.char_to_id[c] for c in text], dtype=np.int32) + + def decode(self, ids): + return "".join([self.id_to_char[i] for i in ids]) + + +class TransformerDecoder: + """A simple Transformer Decoder model for demonstration.""" + + def __init__( + self, + vocab_size: int, + d_model: int, + n_heads: int, + n_layers: int, + max_len: int, + ): + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.max_len = max_len + self.vocab_size = vocab_size + + def init_fn(self, key: jax.Array): + keys = random.split(key, self.n_layers + 3) + params = {} + # Token and position embeddings + params["token_embedding"] = random.normal( + keys[0], (self.vocab_size, self.d_model) + ) + params["pos_embedding"] = random.normal( + keys[1], (self.max_len, self.d_model) + ) + + # Decoder layers + for i in range(self.n_layers): + layer_key = keys[i + 2] + q_key, k_key, v_key, ff_key1, ff_key2 = random.split(layer_key, 5) + params[f"layer_{i}_q"] = random.normal( + q_key, (self.d_model, self.d_model) + ) + params[f"layer_{i}_k"] = random.normal( + k_key, (self.d_model, self.d_model) + ) + params[f"layer_{i}_v"] = random.normal( + v_key, (self.d_model, self.d_model) + ) + params[f"layer_{i}_ff1"] = random.normal( + ff_key1, (self.d_model, 4 * self.d_model) + ) + params[f"layer_{i}_ff2"] = random.normal( + ff_key2, (4 * self.d_model, self.d_model) + ) + + # Output layer + params["output_projection"] = random.normal( + keys[-1], (self.d_model, self.vocab_size) + ) + return params + + def apply_fn(self, params, inputs: jnp.ndarray): + # inputs shape: (seq_len,) + seq_len = inputs.shape[0] + # Token and positional embeddings + token_embed = params["token_embedding"][inputs] + pos_embed = params["pos_embedding"][:seq_len] + x = token_embed + pos_embed + + # Causal attention mask + mask = jnp.tril(jnp.ones((seq_len, seq_len))) + + for i in range(self.n_layers): + # Self-attention + q = x @ params[f"layer_{i}_q"] + k = x @ params[f"layer_{i}_k"] + v = x @ params[f"layer_{i}_v"] + + attn_scores = q @ k.T / jnp.sqrt(self.d_model) + attn_scores = jnp.where(mask, attn_scores, -1e9) + attn_weights = jax.nn.softmax(attn_scores, axis=-1) + attn_output = attn_weights @ v + + # Add & Norm (simplified) + x = x + attn_output + + # Feed-forward + ff_output = jax.nn.relu(x @ params[f"layer_{i}_ff1"]) + ff_output = ff_output @ params[f"layer_{i}_ff2"] + + # Add & Norm (simplified) + x = x + ff_output + + # Final projection + logits = x @ params["output_projection"] + return logits + + +def loss_fn(params, batch): + """Cross-entropy loss for language modeling.""" + inputs, targets = batch[:-1], batch[1:] + logits = model.apply_fn(params, inputs) + log_probs = jax.nn.log_softmax(logits) + return -jnp.mean(jnp.take_along_axis(log_probs, targets[:, None], axis=-1)) + + +def user_level_grad_fn(params, user_data, user_clip_norm): + """Averages 
gradients for one user then clips.""" + num_sequences = len(user_data) + # Start with a flat zero vector to avoid pytree overhead in the loop + acc_grad_vector, unravel_fn = ravel_pytree( + tree_util.tree_map(jnp.zeros_like, params) + ) + + for sequence in user_data: + # Compute standard grad for one sentence + grad_pytree = jax.grad(loss_fn)(params, sequence) + grad_vector, _ = ravel_pytree(grad_pytree) + acc_grad_vector += grad_vector + + # CORE ULS LOGIC: Average the user's total contribution. + # This improves signal-to-noise ratio and ensures the 'User Sensitivity' + # is bounded regardless of how many examples the user has. + avg_grad_vector = acc_grad_vector / num_sequences + + # Clip the averaged user gradient to bound the influence of the person. + grad_norm = jnp.linalg.norm(avg_grad_vector) + multiplier = jnp.minimum(1.0, user_clip_norm / (grad_norm + 1e-6)) + clipped_grad_vector = avg_grad_vector * multiplier + + return clipped_grad_vector, unravel_fn, grad_norm + + +def train_step(params, opt_state, noise_state, user_batch, optimizer, privatizer): + """Performs one training step with user-level DP.""" + # We start with a zero vector for accumulating the gradients from the batch. + # The shape is determined by the total number of model parameters. + _, unravel_fn_dummy = ravel_pytree(params) + param_count = ravel_pytree(params)[0].size + final_grad_vector = jnp.zeros(param_count) + unravel_fn = None + + for u_id, user_data in user_batch.items(): + # Sanity check: Log how many sequences are being averaged for each user. + print( + f" User {u_id} contribution: " + f"{len(user_data)} sequences averaged into one gradient." + ) + + # Compute the clipped, averaged gradient for the current user. + clipped_user_grad, unravel_fn, grad_norm = user_level_grad_fn( + params, user_data, config["user_clip_norm"] + ) + + # Correctness Verification: Log the pre-clip gradient norm. + # If this value is > user_clip_norm, it will be clipped. + print(f" > User {u_id} avg_grad_norm (pre-clip): {grad_norm:.4f}") + + + # Accumulate the clipped gradients from all users in the batch. + final_grad_vector += clipped_user_grad + + # If unravel_fn was not set in the loop (because batch was empty), use dummy. + if unravel_fn is None: + unravel_fn = unravel_fn_dummy + + # Convert the final gradient vector back into the model's pytree structure. + final_grad_pytree = unravel_fn(final_grad_vector) + + # Add noise to the aggregated gradients using the privatizer. + noisy_grads, noise_state = privatizer.update( + final_grad_pytree, noise_state + ) + + # Compute and apply updates using the optax optimizer. + updates, opt_state = optimizer.update(noisy_grads, opt_state, params) + params = optax.apply_updates(params, updates) + + return params, opt_state, noise_state + + +# --- Configuration --- +config = { + "num_steps": 10, + "batch_size": 2, # Number of users per batch + "learning_rate": 0.05, + "user_clip_norm": 1.0, + "noise_multiplier": 1.1, + "d_model": 64, + "n_heads": 2, + "n_layers": 2, + "max_len": 32, +} + +# --- Synthetic Data Generation --- +# Create synthetic data where some users have more data than others. +synthetic_corpus = ( + "This is a simple text. The goal is to learn patterns." + "Transformers are powerful. User-level privacy is important." + "Jax is a high-performance numerical computing library." +) +tokenizer = CharacterTokenizer(synthetic_corpus) +encoded_corpus = tokenizer.encode(synthetic_corpus) + +# Structure data by user ID. User 'A' has more data than 'B', 'C', or 'D'. 
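+# (With max_len = 32, the range(0, 100, 32) slicing below gives user 'A' four
+# sequences starting at offsets 0, 32, 64 and 96, while 'B', 'C' and 'D' each
+# contribute a single sequence.)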
+user_data_db = { + "A": [ + encoded_corpus[i : i + config["max_len"]] + for i in range(0, 100, config["max_len"]) + ], + "B": [encoded_corpus[20 : 20 + config["max_len"]]], + "C": [encoded_corpus[40 : 40 + config["max_len"]]], + "D": [encoded_corpus[60 : 60 + config["max_len"]]], +} +user_ids = list(user_data_db.keys()) + +# --- Model and Optimizer Initialization --- +rng_key = jax.random.key(42) +model = TransformerDecoder( + vocab_size=tokenizer.vocab_size, + d_model=config["d_model"], + n_heads=config["n_heads"], + n_layers=config["n_layers"], + max_len=config["max_len"], +) +params = model.init_fn(rng_key) + +optimizer = optax.adam(learning_rate=config["learning_rate"]) +privatizer = gaussian_privatizer( + stddev=config["user_clip_norm"] * config["noise_multiplier"], + prng_key=jax.random.key(43), +) + +opt_state = optimizer.init(params) +noise_state = privatizer.init(params) + +# --- Training Loop --- +print("Starting user-level DP training...") +for step in range(config["num_steps"]): + # Sample a batch of users for this step + sampled_user_ids = np.random.choice( + user_ids, size=config["batch_size"], replace=False + ) + user_batch = {u_id: user_data_db[u_id] for u_id in sampled_user_ids} + + print(f"\nStep {step+1}/{config['num_steps']}:") + params, opt_state, noise_state = train_step( + params, opt_state, noise_state, user_batch, optimizer, privatizer + ) + + # For demonstration, calculate loss on a fixed batch + fixed_batch = user_data_db["A"][0] + current_loss = loss_fn(params, fixed_batch) + print(f" Loss on fixed batch: {current_loss:.4f}") + +print("\nTraining finished.") From 50c04adae18e708f14791fad62719e451dd98857 Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Thu, 22 Jan 2026 16:27:51 +0530 Subject: [PATCH 3/9] Refactor: Use UserSelectionStrategy and clipped_grad(keep_batch_dim=False) per maintainer review --- examples/user_level_transformer_example.py | 432 +++++++-------------- 1 file changed, 151 insertions(+), 281 deletions(-) diff --git a/examples/user_level_transformer_example.py b/examples/user_level_transformer_example.py index 26221c0..da128c5 100644 --- a/examples/user_level_transformer_example.py +++ b/examples/user_level_transformer_example.py @@ -1,311 +1,181 @@ -# Copyright 2026, The jax_privacy Authors. +# coding=utf-8 +# Copyright 2025 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software +# Unless required by applicable law- agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Fine-Tuning a Transformer with User-Level Differential Privacy. -Literature Reference: Charles et al. (2024), "Fine-Tuning Large Language Models -with User-Level Differential Privacy" (https://arxiv.org/abs/2407.07737) +r"""Example of user-level DP-SGD for fine-tuning a transformer model. ---- Implementation Notes --- -Standard DP-SGD clips every example. ULS (User-Level Sampling) is different: -it averages all gradients from a single user into a 'user-update' BEFORE clipping. 
-As noted in Sec 4.2 of the paper, this is much more efficient for LLMs because -it improves the signal-to-noise ratio by reducing the magnitude of the update -before we hit it with the sensitivity clip. +This example implements user-level differentially private stochastic gradient +descent (DP-SGD) for a transformer model, using the `jax_privacy` library's +native components. -We use a manual for-loop for the user-aggregation because users often have -varying amounts of data (i.e., unbalanced datasets). This avoids complex -padding and masking issues that can arise with JAX's vmap. +The implementation demonstrates how to use `UserSelectionStrategy` to handle +unbalanced user datasets efficiently, where each user contributes a different +number of examples. + +For more details on the user-level DP-SGD algorithm for large language models, +refer to the paper: +Charles et al. (2024), "Fine-Tuning Large Language Models with User-Level +Differential Privacy" (https://arxiv.org/abs/2404.06713). """ +from absl import app +from absl import flags +import flax.linen as nn import jax import jax.numpy as jnp import numpy as np -from jax import random, tree_util -from jax.flatten_util import ravel_pytree import optax -import tensorflow as tf - - -from jax_privacy.noise_addition import gaussian_privatizer - - -# Disable GPU warnings -tf.config.experimental.set_visible_devices([], "GPU") - - -class CharacterTokenizer: - """Simple character-level tokenizer.""" - - def __init__(self, text_data): - self.vocab = sorted(list(set(text_data))) - self.vocab_size = len(self.vocab) - self.char_to_id = {c: i for i, c in enumerate(self.vocab)} - self.id_to_char = {i: c for i, c in enumerate(self.vocab)} - - def encode(self, text): - return np.array([self.char_to_id[c] for c in text], dtype=np.int32) - - def decode(self, ids): - return "".join([self.id_to_char[i] for i in ids]) - - -class TransformerDecoder: - """A simple Transformer Decoder model for demonstration.""" - - def __init__( - self, - vocab_size: int, - d_model: int, - n_heads: int, - n_layers: int, - max_len: int, - ): - self.d_model = d_model - self.n_heads = n_heads - self.n_layers = n_layers - self.max_len = max_len - self.vocab_size = vocab_size - - def init_fn(self, key: jax.Array): - keys = random.split(key, self.n_layers + 3) - params = {} - # Token and position embeddings - params["token_embedding"] = random.normal( - keys[0], (self.vocab_size, self.d_model) - ) - params["pos_embedding"] = random.normal( - keys[1], (self.max_len, self.d_model) - ) +from jax_privacy.batch_selection import CyclicPoissonSampling +from jax_privacy.batch_selection import UserSelectionStrategy +from jax_privacy.clipping import clipped_grad - # Decoder layers - for i in range(self.n_layers): - layer_key = keys[i + 2] - q_key, k_key, v_key, ff_key1, ff_key2 = random.split(layer_key, 5) - params[f"layer_{i}_q"] = random.normal( - q_key, (self.d_model, self.d_model) - ) - params[f"layer_{i}_k"] = random.normal( - k_key, (self.d_model, self.d_model) - ) - params[f"layer_{i}_v"] = random.normal( - v_key, (self.d_model, self.d_model) - ) - params[f"layer_{i}_ff1"] = random.normal( - ff_key1, (self.d_model, 4 * self.d_model) - ) - params[f"layer_{i}_ff2"] = random.normal( - ff_key2, (4 * self.d_model, self.d_model) - ) - # Output layer - params["output_projection"] = random.normal( - keys[-1], (self.d_model, self.vocab_size) - ) - return params - - def apply_fn(self, params, inputs: jnp.ndarray): - # inputs shape: (seq_len,) - seq_len = inputs.shape[0] - # Token and 
positional embeddings - token_embed = params["token_embedding"][inputs] - pos_embed = params["pos_embedding"][:seq_len] - x = token_embed + pos_embed - - # Causal attention mask - mask = jnp.tril(jnp.ones((seq_len, seq_len))) - - for i in range(self.n_layers): - # Self-attention - q = x @ params[f"layer_{i}_q"] - k = x @ params[f"layer_{i}_k"] - v = x @ params[f"layer_{i}_v"] - - attn_scores = q @ k.T / jnp.sqrt(self.d_model) - attn_scores = jnp.where(mask, attn_scores, -1e9) - attn_weights = jax.nn.softmax(attn_scores, axis=-1) - attn_output = attn_weights @ v - - # Add & Norm (simplified) - x = x + attn_output - - # Feed-forward - ff_output = jax.nn.relu(x @ params[f"layer_{i}_ff1"]) - ff_output = ff_output @ params[f"layer_{i}_ff2"] - - # Add & Norm (simplified) - x = x + ff_output - - # Final projection - logits = x @ params["output_projection"] - return logits - - -def loss_fn(params, batch): - """Cross-entropy loss for language modeling.""" - inputs, targets = batch[:-1], batch[1:] - logits = model.apply_fn(params, inputs) - log_probs = jax.nn.log_softmax(logits) - return -jnp.mean(jnp.take_along_axis(log_probs, targets[:, None], axis=-1)) - - -def user_level_grad_fn(params, user_data, user_clip_norm): - """Averages gradients for one user then clips.""" - num_sequences = len(user_data) - # Start with a flat zero vector to avoid pytree overhead in the loop - acc_grad_vector, unravel_fn = ravel_pytree( - tree_util.tree_map(jnp.zeros_like, params) +_USERS_PER_BATCH = flags.DEFINE_integer( + 'users_per_batch', 4, 'Number of users to select in each batch.' +) +_EXAMPLES_PER_USER = flags.DEFINE_integer( + 'examples_per_user', 2, 'Number of examples to select for each user.' +) +_STEPS = flags.DEFINE_integer('steps', 10, 'Number of training steps.') +_L2_CLIP_NORM = flags.DEFINE_float( + 'l2_clip_norm', 1.0, 'L2 clipping norm for gradients.' +) +_LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.') + + +class TransformerDecoder(nn.Module): + """A minimal Transformer Decoder.""" + vocab_size: int + embed_dim: int + num_heads: int + ff_dim: int + + @nn.compact + def __call__(self, x, train: bool): + x = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim)(x) + x = nn.SelfAttention( + num_heads=self.num_heads, qkv_features=self.embed_dim + )(x) + x = nn.Dense(self.ff_dim)(x) + x = nn.relu(x) + x = nn.Dense(self.vocab_size)(x) + return x + + +def get_synthetic_data( + num_users: int, + num_examples_per_user: list[int], + seq_len: int, + vocab_size: int, +): + """Generates synthetic data for the transformer model.""" + data = [] + labels = [] + user_ids = [] + for i in range(num_users): + user_data = np.random.randint( + 0, vocab_size, size=(num_examples_per_user[i], seq_len) ) - - for sequence in user_data: - # Compute standard grad for one sentence - grad_pytree = jax.grad(loss_fn)(params, sequence) - grad_vector, _ = ravel_pytree(grad_pytree) - acc_grad_vector += grad_vector - - # CORE ULS LOGIC: Average the user's total contribution. - # This improves signal-to-noise ratio and ensures the 'User Sensitivity' - # is bounded regardless of how many examples the user has. - avg_grad_vector = acc_grad_vector / num_sequences - - # Clip the averaged user gradient to bound the influence of the person. 
- grad_norm = jnp.linalg.norm(avg_grad_vector) - multiplier = jnp.minimum(1.0, user_clip_norm / (grad_norm + 1e-6)) - clipped_grad_vector = avg_grad_vector * multiplier - - return clipped_grad_vector, unravel_fn, grad_norm - - -def train_step(params, opt_state, noise_state, user_batch, optimizer, privatizer): - """Performs one training step with user-level DP.""" - # We start with a zero vector for accumulating the gradients from the batch. - # The shape is determined by the total number of model parameters. - _, unravel_fn_dummy = ravel_pytree(params) - param_count = ravel_pytree(params)[0].size - final_grad_vector = jnp.zeros(param_count) - unravel_fn = None - - for u_id, user_data in user_batch.items(): - # Sanity check: Log how many sequences are being averaged for each user. - print( - f" User {u_id} contribution: " - f"{len(user_data)} sequences averaged into one gradient." - ) - - # Compute the clipped, averaged gradient for the current user. - clipped_user_grad, unravel_fn, grad_norm = user_level_grad_fn( - params, user_data, config["user_clip_norm"] - ) - - # Correctness Verification: Log the pre-clip gradient norm. - # If this value is > user_clip_norm, it will be clipped. - print(f" > User {u_id} avg_grad_norm (pre-clip): {grad_norm:.4f}") - - - # Accumulate the clipped gradients from all users in the batch. - final_grad_vector += clipped_user_grad - - # If unravel_fn was not set in the loop (because batch was empty), use dummy. - if unravel_fn is None: - unravel_fn = unravel_fn_dummy - - # Convert the final gradient vector back into the model's pytree structure. - final_grad_pytree = unravel_fn(final_grad_vector) - - # Add noise to the aggregated gradients using the privatizer. - noisy_grads, noise_state = privatizer.update( - final_grad_pytree, noise_state + user_labels = np.random.randint( + 0, vocab_size, size=(num_examples_per_user[i], seq_len) + ) + data.append(user_data) + labels.append(user_labels) + user_ids.extend([i] * num_examples_per_user[i]) + return np.concatenate(data), np.concatenate(labels), np.array(user_ids) + + +def main(_): + # 1. Model & Data + vocab_size = 1000 + embed_dim = 64 + num_heads = 4 + ff_dim = 256 + seq_len = 32 + num_users = 20 + num_examples_per_user = np.random.randint(1, 10, size=num_users) + + data, labels, user_ids = get_synthetic_data( + num_users=num_users, + num_examples_per_user=num_examples_per_user, + seq_len=seq_len, + vocab_size=vocab_size, + ) + + model = TransformerDecoder( + vocab_size=vocab_size, + embed_dim=embed_dim, + num_heads=num_heads, + ff_dim=ff_dim, + ) + params = model.init( + jax.random.key(0), jnp.zeros((1, seq_len), dtype=jnp.int32), train=False + )['params'] + optimizer = optax.adam(_LEARNING_RATE.value) + opt_state = optimizer.init(params) + + # 2. Batch Selection + sampling_prob = _USERS_PER_BATCH.value / num_users + base_strategy = CyclicPoissonSampling( + sampling_prob=sampling_prob, iterations=_STEPS.value + ) + user_strategy = UserSelectionStrategy( + base_strategy=base_strategy, + examples_per_user_per_batch=_EXAMPLES_PER_USER.value, + ) + + # 3. Training Step & Clipping + def loss_fn(params, batch_data, batch_labels): + logits = model.apply({'params': params}, batch_data, train=True) + one_hot_labels = jax.nn.one_hot(batch_labels, num_classes=vocab_size) + return jnp.mean( + optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels) ) - # Compute and apply updates using the optax optimizer. 
- updates, opt_state = optimizer.update(noisy_grads, opt_state, params) + # `clipped_grad` wraps the loss function to compute per-user clipped gradients. + # With `keep_batch_dim=False`, the loss function receives a batch of examples + # for a single user. The gradient is computed over this batch, and the + # resulting gradient (which is an average over the user's examples) is + # clipped. This aligns with the core requirement of user-level DP. + grad_fn = clipped_grad( + loss_fn, + l2_clip_norm=_L2_CLIP_NORM.value, + batch_argnums=(1, 2), + keep_batch_dim=False, + ) + + @jax.jit + def train_step(params, opt_state, batch_data, batch_labels): + grads = grad_fn(params, batch_data, batch_labels) + updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) + return params, opt_state - return params, opt_state, noise_state - - -# --- Configuration --- -config = { - "num_steps": 10, - "batch_size": 2, # Number of users per batch - "learning_rate": 0.05, - "user_clip_norm": 1.0, - "noise_multiplier": 1.1, - "d_model": 64, - "n_heads": 2, - "n_layers": 2, - "max_len": 32, -} - -# --- Synthetic Data Generation --- -# Create synthetic data where some users have more data than others. -synthetic_corpus = ( - "This is a simple text. The goal is to learn patterns." - "Transformers are powerful. User-level privacy is important." - "Jax is a high-performance numerical computing library." -) -tokenizer = CharacterTokenizer(synthetic_corpus) -encoded_corpus = tokenizer.encode(synthetic_corpus) - -# Structure data by user ID. User 'A' has more data than 'B', 'C', or 'D'. -user_data_db = { - "A": [ - encoded_corpus[i : i + config["max_len"]] - for i in range(0, 100, config["max_len"]) - ], - "B": [encoded_corpus[20 : 20 + config["max_len"]]], - "C": [encoded_corpus[40 : 40 + config["max_len"]]], - "D": [encoded_corpus[60 : 60 + config["max_len"]]], -} -user_ids = list(user_data_db.keys()) + # 4. 
Training Loop + batch_iterator = user_strategy.batch_iterator(user_ids, rng=0) + for step, user_batch_indices in enumerate(batch_iterator): + if user_batch_indices.size == 0: + print(f"Step {step}: Skipping empty batch.") + continue -# --- Model and Optimizer Initialization --- -rng_key = jax.random.key(42) -model = TransformerDecoder( - vocab_size=tokenizer.vocab_size, - d_model=config["d_model"], - n_heads=config["n_heads"], - n_layers=config["n_layers"], - max_len=config["max_len"], -) -params = model.init_fn(rng_key) - -optimizer = optax.adam(learning_rate=config["learning_rate"]) -privatizer = gaussian_privatizer( - stddev=config["user_clip_norm"] * config["noise_multiplier"], - prng_key=jax.random.key(43), -) - -opt_state = optimizer.init(params) -noise_state = privatizer.init(params) + batch_data = data[user_batch_indices] + batch_labels = labels[user_batch_indices] + params, opt_state = train_step(params, opt_state, batch_data, batch_labels) + print(f"Step {step}: Completed.") -# --- Training Loop --- -print("Starting user-level DP training...") -for step in range(config["num_steps"]): - # Sample a batch of users for this step - sampled_user_ids = np.random.choice( - user_ids, size=config["batch_size"], replace=False - ) - user_batch = {u_id: user_data_db[u_id] for u_id in sampled_user_ids} - - print(f"\nStep {step+1}/{config['num_steps']}:") - params, opt_state, noise_state = train_step( - params, opt_state, noise_state, user_batch, optimizer, privatizer - ) + print("Training finished successfully.") - # For demonstration, calculate loss on a fixed batch - fixed_batch = user_data_db["A"][0] - current_loss = loss_fn(params, fixed_batch) - print(f" Loss on fixed batch: {current_loss:.4f}") -print("\nTraining finished.") +if __name__ == '__main__': + app.run(main) From 92b57716d10db810d1d9238d55b246bfa976c87d Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Sat, 24 Jan 2026 20:37:50 +0530 Subject: [PATCH 4/9] style: fix remaining pylint R0917 and whitespace violations --- examples/dp_sgd_transformer_nnx.py | 369 +++++++++++++++++++++++++++++ 1 file changed, 369 insertions(+) create mode 100644 examples/dp_sgd_transformer_nnx.py diff --git a/examples/dp_sgd_transformer_nnx.py b/examples/dp_sgd_transformer_nnx.py new file mode 100644 index 0000000..aa3443c --- /dev/null +++ b/examples/dp_sgd_transformer_nnx.py @@ -0,0 +1,369 @@ +# Copyright 2025 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example of training a Transformer with Differential Privacy using Flax NNX. + +This script demonstrates how to integrate Flax NNX (the new functional API for +Flax) with JAX Privacy's gradient clipping and noise addition mechanisms. It +implements a simple character-level Transformer language model trained on the +Tiny Shakespeare dataset using DP-SGD. + +Key Features: + - Uses `flax.nnx` for model definition. + - Implements the "Exhaustive Split" pattern to separate parameters from + static graph definition. 
+ - Handles rank normalization for `jax_privacy.clipping.clipped_grad` + compatibility. + - Demonstrates correct value access using `clipped_grad` with auxiliary + outputs. +""" + +import functools +from typing import Dict, Tuple, Any +import urllib.request + +from flax import nnx +import jax +import jax.extend.backend +import jax.numpy as jnp +from jax_privacy import noise_addition +from jax_privacy.clipping import clipped_grad +import numpy as np +import optax + + +# Hyperparameters +BATCH_SIZE = 4 +CONTEXT_LENGTH = 8 +EMBED_SIZE = 16 +NUM_HEADS = 4 +NUM_LAYERS = 2 +LEARNING_RATE = 1e-3 +NUM_STEPS = 5 +CLIP_NORM = 1.0 + + +# Data loading and preparation +def download_data(url: str) -> str: + """Downloads data from a URL. + + Args: + url: The URL to download data from. + + Returns: + The content of the downloaded file as a string. + """ + with urllib.request.urlopen(url) as response: + return response.read().decode('utf-8') + + +def create_tokenizer(text: str) -> Dict[str, int]: + """Creates a simple character-level tokenizer. + + Args: + text: The input text to build the vocabulary from. + + Returns: + stoi: A dictionary mapping characters to integer indices. + """ + chars = sorted(list(set(text))) + stoi = {ch: i for i, ch in enumerate(chars)} + return stoi + + +def encode(text: str, stoi: Dict[str, int]) -> np.ndarray: + """Encodes text using the tokenizer. + + Args: + text: The text to encode. + stoi: The string-to-index mapping. + + Returns: + A numpy array of integer indices. + """ + return np.array([stoi[ch] for ch in text], dtype=np.int32) + + +def get_batch( + data: np.ndarray, batch_size: int, context_length: int +) -> Tuple[np.ndarray, np.ndarray]: + """Gets a random batch of data. + + Args: + data: The entire dataset encoded as integers. + batch_size: The number of examples in the batch. + context_length: The length of each sequence. + + Returns: + A tuple (x, y) where: + - x: Input sequences of shape (batch_size, context_length). + - y: Target sequences of shape (batch_size, context_length). + """ + ix = np.random.randint(len(data) - context_length, size=(batch_size,)) + x = np.stack([data[i : i + context_length] for i in ix]) + y = np.stack([data[i + 1 : i + context_length + 1] for i in ix]) + return x, y + + +class TransformerBlock(nnx.Module): + """A single Transformer block.""" + + def __init__( + self, + embed_size: int, + num_heads: int, + *, + rngs: nnx.Rngs, + ): + """Initializes the TransformerBlock. + + Args: + embed_size: The dimensionality of the embedding. + num_heads: The number of attention heads. + rngs: The random number generators. + """ + self.attention = nnx.MultiHeadAttention( + num_heads=num_heads, + in_features=embed_size, + qkv_features=embed_size, + out_features=embed_size, + rngs=rngs, + ) + self.ln1 = nnx.LayerNorm(num_features=embed_size, rngs=rngs) + self.ln2 = nnx.LayerNorm(num_features=embed_size, rngs=rngs) + self.ffw = nnx.Sequential( + nnx.Linear( + in_features=embed_size, + out_features=4 * embed_size, + rngs=rngs + ), + nnx.Linear( + in_features=4 * embed_size, + out_features=embed_size, + rngs=rngs + ), + ) + + def __call__(self, x: jax.Array) -> jax.Array: + """Applies the TransformerBlock to the input. + + Args: + x: Input array of shape (seq_len, embed_size). + + Returns: + Output array of the same shape as input. 
+ """ + seq_len = x.shape[1] + causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) + # Add batch and head dimensions to the mask for broadcasting + causal_mask = causal_mask[None, None, :, :] + x = x + self.attention(x, mask=causal_mask, decode=False) + x = self.ln1(x) + x = x + self.ffw(x) + x = self.ln2(x) + return x + + +class TransformerLM(nnx.Module): + """A Transformer language model.""" + + def __init__( + self, + vocab_size: int, + *, + embed_size: int, + context_length: int, + num_heads: int, + num_layers: int, + rngs: nnx.Rngs, + ): + """Initializes the TransformerLM. + + Args: + vocab_size: The size of the vocabulary. + embed_size: The dimensionality of the embedding. + context_length: The max length of the context window. + num_heads: The number of attention heads. + num_layers: The number of transformer blocks. + rngs: The random number generators. + """ + self.token_embedding = nnx.Embed( + num_embeddings=vocab_size, features=embed_size, rngs=rngs + ) + self.pos_embedding = nnx.Embed( + num_embeddings=context_length, features=embed_size, rngs=rngs + ) + self.blocks = nnx.List([ + TransformerBlock(embed_size, num_heads, rngs=rngs) + for _ in range(num_layers) + ]) + self.ln_f = nnx.LayerNorm(num_features=embed_size, rngs=rngs) + self.head = nnx.Linear( + in_features=embed_size, out_features=vocab_size, rngs=rngs + ) + + def __call__(self, x: jax.Array) -> jax.Array: + """Applies the model to the input. + + Args: + x: Input array of token indices with shape (batch_size, seq_len). + + Returns: + Logits array of shape (batch_size, seq_len, vocab_size). + """ + pos = jnp.arange(0, x.shape[1]) + tok_emb = self.token_embedding(x) + pos_emb = self.pos_embedding(pos) + x = tok_emb + pos_emb + for block in self.blocks: + x = block(x) + x = self.ln_f(x) + logits = self.head(x) + return logits + + +def pure_loss_fn( + params: nnx.State, + x: jax.Array, + y: jax.Array, + graphdef: nnx.GraphDef, + other: nnx.State, +) -> jax.Array: + """A pure functional loss function for DP-SGD. + + This function re-merges the NNX model state, applies the model, and computes + the cross-entropy loss. It is designed to work with `clipped_grad` which + requires a functional interface. + + Args: + params: The trainable parameters of the model. + x: Input batch (single example or microbatch). + y: Target batch (single example or microbatch). + graphdef: The static graph definition of the NNX model. + other: Non-trainable state (e.g., RNG counts). + + Returns: + The scalar loss value. + """ + m = nnx.merge(graphdef, params, other) + + # CRITICAL: Add a dummy batch dimension of 1 to satisfy make_causal_mask + # and attention expectations which look for a 4D mask derived from + # a batched input. We index [0] on the output to remove it. 
+ # x indices: [seq_len] -> [1, seq_len] + logits = m(x[None, :])[0] + + + return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean() + + +def main(): + """Main training loop.""" + device = jax.extend.backend.get_backend().platform + print(f"Starting DP-SGD Training on {device.upper()}...") + + # Data + text = download_data( + "https://raw.githubusercontent.com/karpathy/char-rnn/master/" + "data/tinyshakespeare/input.txt" + ) + stoi = create_tokenizer(text) + vocab_size = len(stoi) + data = encode(text, stoi) + print(f"Dataset has {len(data)} tokens, {vocab_size} vocab size.") + + # Model and optimizer + rngs = nnx.Rngs(0) + model = TransformerLM( + vocab_size=vocab_size, + embed_size=EMBED_SIZE, + context_length=CONTEXT_LENGTH, + num_heads=NUM_HEADS, + num_layers=NUM_LAYERS, + rngs=rngs, + ) + optimizer = optax.adam(LEARNING_RATE) + + # CRITICAL: Exhaustive split to separate trainable params from static + # graph and other state. + graphdef, params, other = nnx.split(model, nnx.Param, ...) + opt_state = optimizer.init(params) + + # Configure DP Gradient Clipping + grad_fn = clipped_grad( + functools.partial(pure_loss_fn, graphdef=graphdef, other=other), + l2_clip_norm=CLIP_NORM, + batch_argnums=(1, 2), # x and y are batched + keep_batch_dim=False, # Process per-example + return_values=True # Return loss values for logging + ) + + privatizer = noise_addition.gaussian_privatizer( + stddev=CLIP_NORM, + ) + noise_state = privatizer.init(params) + + @jax.jit + def train_step( + params: nnx.State, + opt_state: optax.OptState, + batch: Tuple[jax.Array, jax.Array], + *, + noise_state: Any, + ) -> Tuple[nnx.State, optax.OptState, Any, jax.Array]: + """Performs a single training step with DP-SGD. + + Args: + params: Current model parameters. + opt_state: Current optimizer state. + batch: A tuple (x, y) of input and target data. + noise_state: Current state of the noise mechanism. + + Returns: + Updated params, opt_state, noise_state, and the mean loss for the batch. 
+ """ + x, y = batch + + # Compute clipped gradients and per-example loss values + grads, loss = grad_fn(params, x, y) + + # Aggregate gradients (mean across batch) + mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads) + + # Add Privacy Noise + noisy_grads, noise_state = privatizer.update(mean_grads, noise_state) + + # Apply updates using Optax + updates, opt_state = optimizer.update(noisy_grads, opt_state) + params = optax.apply_updates(params, updates) + + # loss is an Aux object containing 'values' + return params, opt_state, noise_state, loss.values.mean() + + # Training loop + print(f"Training for {NUM_STEPS} steps...") + for step in range(NUM_STEPS): + batch = get_batch(data, BATCH_SIZE, CONTEXT_LENGTH) + + params, opt_state, noise_state, loss = train_step( + params, opt_state, batch, noise_state=noise_state + ) + + print(f"Step {step + 1}/{NUM_STEPS}, Loss: {loss:.4f}") + + print("Training Complete.") + + +if __name__ == "__main__": + main() From b9f3a8163d8e6564667321c26166e1ef5344d133 Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Sat, 24 Jan 2026 20:48:57 +0530 Subject: [PATCH 5/9] style: fix line length in user_level_transformer and sync nnx example --- examples/user_level_transformer_example.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/user_level_transformer_example.py b/examples/user_level_transformer_example.py index da128c5..3db7a22 100644 --- a/examples/user_level_transformer_example.py +++ b/examples/user_level_transformer_example.py @@ -143,7 +143,8 @@ def loss_fn(params, batch_data, batch_labels): optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels) ) - # `clipped_grad` wraps the loss function to compute per-user clipped gradients. + # `clipped_grad` wraps the loss function to compute per-user clipped + # gradients. # With `keep_batch_dim=False`, the loss function receives a batch of examples # for a single user. 
The gradient is computed over this batch, and the # resulting gradient (which is an average over the user's examples) is From d2f88316b4af28c4d7bba9bc77ede1cf1236aa85 Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Sun, 25 Jan 2026 00:02:35 +0530 Subject: [PATCH 6/9] style: fix pytype import errors and finalize production standards --- examples/dp_sgd_transformer_nnx.py | 2 +- examples/user_level_transformer_example.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dp_sgd_transformer_nnx.py b/examples/dp_sgd_transformer_nnx.py index aa3443c..be0f9e0 100644 --- a/examples/dp_sgd_transformer_nnx.py +++ b/examples/dp_sgd_transformer_nnx.py @@ -33,7 +33,7 @@ from typing import Dict, Tuple, Any import urllib.request -from flax import nnx +from flax import nnx # pytype: disable=import-error import jax import jax.extend.backend import jax.numpy as jnp diff --git a/examples/user_level_transformer_example.py b/examples/user_level_transformer_example.py index 3db7a22..820b60b 100644 --- a/examples/user_level_transformer_example.py +++ b/examples/user_level_transformer_example.py @@ -31,7 +31,7 @@ from absl import app from absl import flags -import flax.linen as nn +import flax.linen as nn # pytype: disable=import-error import jax import jax.numpy as jnp import numpy as np From f5beecfa07ecaf413cdb346656b1f8524a5a9dbe Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Mon, 26 Jan 2026 17:00:48 +0530 Subject: [PATCH 7/9] Refactor transformer DP-SGD example and update CI workflow --- .github/workflows/ci.yml | 2 ++ examples/dp_sgd_transformer_nnx.py | 39 +++++++++++----------- examples/user_level_transformer_example.py | 2 +- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7326b31..3653bd4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,6 +21,8 @@ jobs: - uses: "actions/setup-python@v6" with: python-version: "${{ matrix.python-version }}" + - name: Install example requirements + run: pip install -r examples/requirements.txt - name: Run CI tests run: bash test.sh shell: bash diff --git a/examples/dp_sgd_transformer_nnx.py b/examples/dp_sgd_transformer_nnx.py index be0f9e0..d2f9959 100644 --- a/examples/dp_sgd_transformer_nnx.py +++ b/examples/dp_sgd_transformer_nnx.py @@ -50,8 +50,9 @@ NUM_HEADS = 4 NUM_LAYERS = 2 LEARNING_RATE = 1e-3 -NUM_STEPS = 5 +NUM_STEPS = 10 CLIP_NORM = 1.0 +NOISE_MULTIPLIER = 1.0 # Data loading and preparation @@ -64,7 +65,7 @@ def download_data(url: str) -> str: Returns: The content of the downloaded file as a string. """ - with urllib.request.urlopen(url) as response: + with urllib.request.urlopen(url, timeout=10) as response: return response.read().decode('utf-8') @@ -164,11 +165,12 @@ def __call__(self, x: jax.Array) -> jax.Array: Returns: Output array of the same shape as input. """ - seq_len = x.shape[1] - causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) - # Add batch and head dimensions to the mask for broadcasting - causal_mask = causal_mask[None, None, :, :] - x = x + self.attention(x, mask=causal_mask, decode=False) + # Add is_causal=True to handle masking automatically + # Note: nnx.MultiHeadAttention does not support is_causal=True in + # __init__ or __call__ in this version. Using make_causal_mask instead + # as requested to remove manual generation. 
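+    # (For reference: make_causal_mask builds the lower-triangular causal
+    # attention mask from the token positions in x[..., 0], so each position
+    # can only attend to itself and earlier positions.)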
+ mask = nnx.make_causal_mask(x[..., 0]) + x = x + self.attention(x, mask=mask, decode=False) x = self.ln1(x) x = x + self.ffw(x) x = self.ln2(x) @@ -256,13 +258,10 @@ def pure_loss_fn( Returns: The scalar loss value. """ - m = nnx.merge(graphdef, params, other) + model = nnx.merge(graphdef, params, other) - # CRITICAL: Add a dummy batch dimension of 1 to satisfy make_causal_mask - # and attention expectations which look for a 4D mask derived from - # a batched input. We index [0] on the output to remove it. - # x indices: [seq_len] -> [1, seq_len] - logits = m(x[None, :])[0] + # Standard call without rank normalization + logits = model(x) return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean() @@ -305,12 +304,14 @@ def main(): functools.partial(pure_loss_fn, graphdef=graphdef, other=other), l2_clip_norm=CLIP_NORM, batch_argnums=(1, 2), # x and y are batched - keep_batch_dim=False, # Process per-example - return_values=True # Return loss values for logging + keep_batch_dim=True, # Return per-example gradients + return_values=True, # Return loss values for logging + # Note: We do not pass prng_argnum here because 'other' (arg 4) contains + # RNG state which is handled as a standard argument by NNX. ) privatizer = noise_addition.gaussian_privatizer( - stddev=CLIP_NORM, + stddev=grad_fn.sensitivity() * NOISE_MULTIPLIER, ) noise_state = privatizer.init(params) @@ -338,11 +339,8 @@ def train_step( # Compute clipped gradients and per-example loss values grads, loss = grad_fn(params, x, y) - # Aggregate gradients (mean across batch) - mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads) - # Add Privacy Noise - noisy_grads, noise_state = privatizer.update(mean_grads, noise_state) + noisy_grads, noise_state = privatizer.update(grads, noise_state) # Apply updates using Optax updates, opt_state = optimizer.update(noisy_grads, opt_state) @@ -352,6 +350,7 @@ def train_step( return params, opt_state, noise_state, loss.values.mean() # Training loop + # TODO: Use jax_privacy.batch_selection.CyclicPoissonSampling print(f"Training for {NUM_STEPS} steps...") for step in range(NUM_STEPS): batch = get_batch(data, BATCH_SIZE, CONTEXT_LENGTH) diff --git a/examples/user_level_transformer_example.py b/examples/user_level_transformer_example.py index 820b60b..d5e54cf 100644 --- a/examples/user_level_transformer_example.py +++ b/examples/user_level_transformer_example.py @@ -7,7 +7,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law- agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and From 23e7eb73d13f6684d763d66c1c07d9206eb6da16 Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Mon, 26 Jan 2026 17:17:42 +0530 Subject: [PATCH 8/9] Add requirements file for transformer examples to fix CI dependencies --- examples/requirements.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 examples/requirements.txt diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..884bded --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1 @@ +flax From d5a7943bc4f5d5feaedd073e47fab9a65f4b1133 Mon Sep 17 00:00:00 2001 From: Debangan Ghosh Date: Tue, 27 Jan 2026 16:27:11 +0530 Subject: [PATCH 9/9] chore: align dependencies with new modular CI in pyproject.toml --- examples/dp_sgd_transformer_nnx.py | 2 +- examples/user_level_transformer_example.py | 2 +- pyproject.toml | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/dp_sgd_transformer_nnx.py b/examples/dp_sgd_transformer_nnx.py index d2f9959..7268763 100644 --- a/examples/dp_sgd_transformer_nnx.py +++ b/examples/dp_sgd_transformer_nnx.py @@ -33,7 +33,7 @@ from typing import Dict, Tuple, Any import urllib.request -from flax import nnx # pytype: disable=import-error +from flax import nnx import jax import jax.extend.backend import jax.numpy as jnp diff --git a/examples/user_level_transformer_example.py b/examples/user_level_transformer_example.py index d5e54cf..117e98a 100644 --- a/examples/user_level_transformer_example.py +++ b/examples/user_level_transformer_example.py @@ -31,7 +31,7 @@ from absl import app from absl import flags -import flax.linen as nn # pytype: disable=import-error +import flax.linen as nn import jax import jax.numpy as jnp import numpy as np diff --git a/pyproject.toml b/pyproject.toml index 6fb9224..31bc8b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ dev = [ "keras", "tensorflow", ] +examples = [ + "flax", +] [project.urls] Homepage = "https://github.com/google-deepmind/jax_privacy"
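
Usage sketch (an assumption about the typical local workflow, not something
specified in the patches above): with the `examples` optional dependency group
added to pyproject.toml, the example scripts' extra requirements can be pulled
in and one of the new examples run with commands along the lines of:

    pip install -e ".[examples]"
    python examples/dp_sgd_transformer_nnx.py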