Feat: Add DP-SGD Transformer example using Flax NNX API | Issue #120 #126
Open
debanganghosh08 wants to merge 9 commits into google-deepmind:main from debanganghosh08:feat/nnx-transformer-dp-sgd
+747 −0
Changes from all commits (9 commits)
- e7b5538 Implemented Delayed Preconditioners with alternating-phase protocol … (debanganghosh08)
- a2acc2f Implement User-Level Sampling (ULS) for Transformers with per-user av… (debanganghosh08)
- 50c04ad Refactor: Use UserSelectionStrategy and clipped_grad(keep_batch_dim=F… (debanganghosh08)
- 92b5771 style: fix remaining pylint R0917 and whitespace violations (debanganghosh08)
- b9f3a81 style: fix line length in user_level_transformer and sync nnx example (debanganghosh08)
- d2f8831 style: fix pytype import errors and finalize production standards (debanganghosh08)
- f5beecf Refactor transformer DP-SGD example and update CI workflow (debanganghosh08)
- 23e7eb7 Add requirements file for transformer examples to fix CI dependencies (debanganghosh08)
- d5a7943 chore: align dependencies with new modular CI in pyproject.toml (debanganghosh08)
@@ -0,0 +1,191 @@

```python
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example of a private adaptive optimizer with a Delayed Preconditioner (DP^2).

This example is based on the paper:
Li et al. (2023), "Differentially Private Adaptive Optimization with Delayed
Preconditioners" (ICLR 2023).

Implementation Notes:
This implementation demonstrates the "alternating-phase protocol" of DP^2.
The core idea is to introduce a trade-off between the staleness of the
preconditioner and noise reduction. The preconditioner (second-moment
estimate) is kept "stale" for a fixed number of steps (`delay_s`). During
this period, noised gradients are accumulated.

When the delay period is over, the preconditioner is updated using the
*average* of the accumulated gradients. By averaging over multiple steps,
the signal from the gradients is amplified relative to the noise, leading
to a more stable and reliable preconditioner update. This helps to prevent
the noise from overwhelming the adaptive learning rate calculation, a common
problem in standard private adaptive optimizers.
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax_privacy.clipping import clipped_grad
from jax_privacy import noise_addition
import optax
import numpy as np


# 1. Model and Data Setup
class SimpleModel(nn.Module):
  """A simple linear model."""

  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=1)(x)


# Dummy data
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
labels = jnp.array([[0.5], [1.5]])

model = SimpleModel()
key = jax.random.PRNGKey(42)
params = model.init(key, inputs)['params']
epsilon_s1 = 1e-8

# 2. Optimizer and DP^2 State
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)

# DP^2 parameters
delay_s = 5
preconditioner_v = jax.tree_util.tree_map(jnp.ones_like, params)
gradient_accumulator = jax.tree_util.tree_map(jnp.zeros_like, params)

# Privatizer for noise addition
stddev = 1.0 * 0.1  # l2_clip_norm * noise_multiplier
privatizer = noise_addition.gaussian_privatizer(
    stddev=stddev,
    prng_key=jax.random.PRNGKey(0),
)
noise_state = privatizer.init(params)


# 3. Core Logic: DP^2 Update Step
def loss_fn(p, x, y):
  """MSE loss."""
  return jnp.mean((model.apply({'params': p}, x) - y) ** 2)


def update_step(
    step_idx, params, opt_state, preconditioner_v, gradient_accumulator,
    noise_state,
):
  """Performs one update step with delayed preconditioning."""

  # Scale gradients using the (potentially stale) preconditioner.
  s_t = jax.tree_util.tree_map(
      lambda var: 1 / (jnp.sqrt(var) + epsilon_s1), preconditioner_v
  )

  def pre_clipping_transform(grads):
    return jax.tree_util.tree_map(lambda g, s: g * s, grads, s_t)

  # Compute and clip gradients.
  priv_grad_fn = clipped_grad(
      loss_fn,
      l2_clip_norm=1.0,
      batch_argnums=(1, 2),
      pre_clipping_transform=pre_clipping_transform,
  )
  clipped_grads = priv_grad_fn(params, inputs, labels)

  # Add noise.
  noised_grads, new_noise_state = privatizer.update(
      clipped_grads, noise_state
  )

  # Rescale noised gradients back to original space.
  noised_grads_rescaled = jax.tree_util.tree_map(
      lambda g, s: g / s, noised_grads, s_t
  )

  # Apply updates to parameters.
  updates, new_opt_state = optimizer.update(
      noised_grads_rescaled, opt_state, params
  )
  new_params = optax.apply_updates(params, updates)

  # --- DP^2 Preconditioner Update Logic ---
  # Accumulate the noised (but not rescaled) gradients.
  new_gradient_accumulator = jax.tree_util.tree_map(
      lambda acc, g: acc + g, gradient_accumulator, noised_grads
  )

  def update_preconditioner(operand):
    """Update preconditioner with averaged gradients and reset accumulator."""
    p_v, grad_acc = operand
    # By delaying the update, we ensure that the second-moment estimate is
    # derived from a higher signal-to-noise ratio average, preventing the
    # noise from dominating the adaptive learning rates.
    avg_grad = jax.tree_util.tree_map(lambda x: x / delay_s, grad_acc)
    new_p_v = jax.tree_util.tree_map(
        lambda v, g: 0.9 * v + 0.1 * jnp.square(g), p_v, avg_grad
    )
    new_grad_acc = jax.tree_util.tree_map(jnp.zeros_like, grad_acc)
    return new_p_v, new_grad_acc

  def keep_stale(operand):
    """Keep preconditioner stale and continue accumulating."""
    return operand

  # "Snap" update: conditionally update the preconditioner.
  new_preconditioner_v, new_gradient_accumulator = jax.lax.cond(
      (step_idx + 1) % delay_s == 0,
      update_preconditioner,
      keep_stale,
      (preconditioner_v, new_gradient_accumulator),
  )

  return (
      new_params,
      new_opt_state,
      new_preconditioner_v,
      new_gradient_accumulator,
      new_noise_state,
  )


# 4. Training Loop & Verification
v_initial = preconditioner_v
snapshots = {}
num_steps = 20

for i in range(num_steps):
  (
      params,
      opt_state,
      preconditioner_v,
      gradient_accumulator,
      noise_state,
  ) = update_step(
      i, params, opt_state, preconditioner_v, gradient_accumulator, noise_state
  )
  if i == 3:  # Step 4 (0-indexed)
    snapshots['v_step4'] = preconditioner_v
  if i == 4:  # Step 5 (0-indexed)
    snapshots['v_step5'] = preconditioner_v


# 5. Correctness Verification
def flatten_pytree(pt):
  return np.concatenate([np.ravel(x) for x in jax.tree_util.tree_leaves(pt)])


v_initial_flat = flatten_pytree(v_initial)
v_step4_flat = flatten_pytree(snapshots['v_step4'])
v_step5_flat = flatten_pytree(snapshots['v_step5'])

# Assert that the preconditioner was stale until the update step.
np.testing.assert_allclose(v_initial_flat, v_step4_flat, rtol=1e-6)
assert not np.allclose(v_step4_flat, v_step5_flat), (
    'Preconditioner did not update at step 5.'
)

print(
    '[Verification] DP2 Protocol confirmed: Preconditioner remained stale for'
    ' steps 1-4 and updated successfully at step 5.'
)
```
Review comment: this block of CI should not be here; it is unusual and not required.