Feat : Add DP-SGD Transformer example using Flax NNX API | Issue #120#126
Feat : Add DP-SGD Transformer example using Flax NNX API | Issue #120#126debanganghosh08 wants to merge 23 commits into
Conversation
7cbfbb1 to
944df7c
Compare
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law- agreed to in writing, software |
| Returns: | ||
| The content of the downloaded file as a string. | ||
| """ | ||
| with urllib.request.urlopen(url) as response: |
There was a problem hiding this comment.
add timeout to prevent indefinite blocking
There was a problem hiding this comment.
That's a good catch brother. i have now added a timeout and is definitely best practice to avoid hangs in CI/CD. I've updated download_data to include a 10-second timeout. I'm also moving the flax dependency into a proper requirements file as you suggested.
| import urllib.request | ||
|
|
||
| from flax import nnx | ||
| from flax import nnx # pytype: disable=import-error |
There was a problem hiding this comment.
No, it's not, in the cicd checks there is no flax installing dependency to when the pytype check happens, the code fails. Hence, this line is important to pass all the cicd checks.
For a long term note, we can tell the @RamSaw or @ryan112358 to add flax installing for the cicd check for no further issue.
There was a problem hiding this comment.
so try adding in the requirements txt which is located in the docs folder
There was a problem hiding this comment.
The requirements.txt in docs folder is intended to only contain requirements needed for documentation. The ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and updates .github/workflows/ci.yml to install these.
There was a problem hiding this comment.
Or you can add it to the "dev" requirements in pyproject.toml
| from absl import app | ||
| from absl import flags | ||
| import flax.linen as nn | ||
| import flax.linen as nn # pytype: disable=import-error |
There was a problem hiding this comment.
No, it's not, in the cicd checks there is no flax installing dependency to when the pytype check happens, the code fails. Hence, this line is important to pass all the cicd checks.
For a long term note, we can tell the @RamSaw or @ryan112358 to add flax installing for the cicd check for no further issue.
ryan112358
left a comment
There was a problem hiding this comment.
Looks great ,very clean - nice work! Left some comments
| import urllib.request | ||
|
|
||
| from flax import nnx | ||
| from flax import nnx # pytype: disable=import-error |
There was a problem hiding this comment.
The requirements.txt in docs folder is intended to only contain requirements needed for documentation. The ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and updates .github/workflows/ci.yml to install these.
| x: Input batch (single example or microbatch). | ||
| y: Target batch (single example or microbatch). | ||
| graphdef: The static graph definition of the NNX model. | ||
| other: Non-trainable state (e.g., RNG counts). |
There was a problem hiding this comment.
What else other than the rng counts is captured here? Is it possible to call this argument prng and have it typed as a jax.Array, then somehow wire it through to flax? I ask because when you call clipped_grad, if the loss function contains a prng key it needs special handling.
| Returns: | ||
| The scalar loss value. | ||
| """ | ||
| m = nnx.merge(graphdef, params, other) |
There was a problem hiding this comment.
Give this a descriptive name like model
| l2_clip_norm=CLIP_NORM, | ||
| batch_argnums=(1, 2), # x and y are batched | ||
| keep_batch_dim=False, # Process per-example | ||
| return_values=True # Return loss values for logging |
There was a problem hiding this comment.
You might need to pass prng_argnum here as well to ensure the random key is handled appropriately. But it might require slight refactoring of your loss function
| functools.partial(pure_loss_fn, graphdef=graphdef, other=other), | ||
| l2_clip_norm=CLIP_NORM, | ||
| batch_argnums=(1, 2), # x and y are batched | ||
| keep_batch_dim=False, # Process per-example |
There was a problem hiding this comment.
Usually we want to keep this to the default (True), unless we're doing user-level DP. If you set this to True (or remove it), can you remove the line that adds an extra batch axis in pure_loss_fn?
| grads, loss = grad_fn(params, x, y) | ||
|
|
||
| # Aggregate gradients (mean across batch) | ||
| mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads) |
There was a problem hiding this comment.
grad_fn already aggregates gradients across the batch dimension, so I think this is a bug
| # Aggregate gradients (mean across batch) | ||
| mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads) | ||
|
|
||
| # Add Privacy Noise |
There was a problem hiding this comment.
I'll leave it up to your discretion, but I think these inline comments can be removed.
| # Training loop | ||
| print(f"Training for {NUM_STEPS} steps...") | ||
| for step in range(NUM_STEPS): | ||
| batch = get_batch(data, BATCH_SIZE, CONTEXT_LENGTH) |
There was a problem hiding this comment.
In an ideal world this would use poisson sampling / jax_privacy.batch_selection. It's fine to leave a TODO for now and add it in a follow-up
| ) | ||
|
|
||
| privatizer = noise_addition.gaussian_privatizer( | ||
| stddev=CLIP_NORM, |
There was a problem hiding this comment.
The stddev should be grad_fn.sensitiivty() * noise_multiplier. can you add NOISE_MULTIPLIER to the list of constants above?
1d03537 to
9eac33d
Compare
|
Hi @ryan112358 , I've pushed an update addressing all your feedback. Here is a summary of the changes I made:
✅ Verification: The script was verified for 10 steps locally, achieving a stable loss and passing a 10.00/10 pylint check. Remind me if new changes are required! |
|
#128 might fix the ci failures easy to debug |
That's an Good approach for moving current CICD to modular DAG architecture. It is good for improving DX. |
|
@debanganghosh08 , since now the new ci pipeline and new dependency flow has been introduced, so there will ci failures from now on. As you have added the one lib in examples/req...txt it will not considered from now on. Kindly first pull the lastest changes from upstream main, then delete the examples/req..txt file and add the deps to the pyproject.toml, you can see there is optional tab and a space for [examples], kindly add it there. Now a central optional deps are managed at the root pyproject.toml file |
b6d6d66 to
d5a7943
Compare
Thanks for the heads-up and the clear guidance on the new dependency flow, @amyssnippet! I've just pushed an update aligning with the new modular CI. I pulled the latest upstream changes, migrated flax to the [project.optional-dependencies] section in pyproject.toml, and cleaned up the temporary requirements file. Everything should be in sync now! |
amyssnippet
left a comment
There was a problem hiding this comment.
i guess check the files changed tab, there are still some files visible, kindly fix them all, i already left comments
| - name: Install example requirements | ||
| run: pip install -r examples/requirements.txt |
There was a problem hiding this comment.
this block of ci should not be here, it is unusual, it is not required
| examples = [ | ||
| "flax", | ||
| ] |
There was a problem hiding this comment.
i have already created arrays to manage all optional dependencies, check it here https://github.com/google-deepmind/jax_privacy/blob/main/pyproject.toml
i have made deps in the prev task with ci, make sure you pulled the changes properly. including this file
There was a problem hiding this comment.
i guess its still available here, which is not required
|
Hello @ryan112358 and @Neerajpathak07 , I’ve updated the implementation for both the NNX Transformer (#126) and ULS Transformer (#107) examples to align with the architectural suggestions provided. I performed a side-by-side experimental benchmark to evaluate the impact of moving from a manual loop to the library's internal execution_plan abstraction. Key Refactors Implemented: Standardized Orchestration: Switched to execution_plan.BandMFExecutionPlanConfig to wire the privatizer and clipped_grad. This ensures the noise addition and sampling strategies are mathematically synchronized with the library's core mechanisms. ULS Integration: In the User-Level example, I successfully wrapped the plan.batch_selection_strategy within our UserSelectionStrategy, maintaining the required intra-user averaging while utilizing the standard batch_iterator. Production Standards: Adopted the main(argv: Sequence[str]) entry point and migrated hyperparameters to centralized constants for better readability. Both files now achieve a 10.00/10 Pylint score. Benchmarking Observation: During local 10-step runs, I noted a significant initialization overhead (~45s). This is due to the Toeplitz.optimize_banded_toeplitz step required by the BandMF strategy. While this increases the 'wall-clock' time for short CI checks, it is a fixed cost that will be fully amortized during production-scale training runs. @Neerajpathak07, thanks for pointing out the BandMF configuration, it makes the examples much more idiomatic. @ryan112358, do you agree that the increased alignment with the core library's 'Plan' API is worth the trade-off in script simplicity for these examples? |
bf44580 to
1b84638
Compare
|
This PR has been idle for 7 days. Please provide an update or review. |
c206071 to
fbdc2b3
Compare
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #126 +/- ##
=======================================
Coverage ? 73.61%
=======================================
Files ? 25
Lines ? 3025
Branches ? 0
=======================================
Hits ? 2227
Misses ? 798
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Hi @ryan112358, I have finalized the NNX Transformer example and resolved all action items from the previous review cycle. Technical Resolutions: Secure RNG handling: Refactored pure_loss_fn to isolate the prng_key. By using prng_argnum=3 in clipped_grad, the JAX tracer now correctly splits the PRNG key per-example in the microbatch, preventing the "RNG broadcasting" leak. Gradient Aggregation: Switched to keep_batch_dim=True and implemented manual jax.tree.map(jnp.sum) post-clipping. This ensures the gradients are correctly aggregated for the privatizer while respecting the per-example clipping constraints. CI/CD Hardening: Resolved a Windows-specific crash in the unit tests caused by flax pulling in uvloop. I've added a sys_platform != 'win32' marker to the test dependencies in pyproject.toml to ensure cross-platform CI stability. Production Standards:
|
| "hypothesis", | ||
| "crc32c", | ||
| "keras", | ||
| # Add the environment marker to flax here: |
| import jax | ||
| import jax.extend.backend | ||
| import jax.numpy as jnp | ||
| from jax_privacy.clipping import clipped_grad |
There was a problem hiding this comment.
prefer importing clipped_grad directly from jax_privacy (no .clipping)
| - Uses `flax.nnx` for model definition. | ||
| - Implements the "Exhaustive Split" pattern to separate parameters from | ||
| static graph definition. | ||
| - Handles rank normalization for `jax_privacy.clipping.clipped_grad` |
There was a problem hiding this comment.
prefer importing clipped_grad directly from jax_privacy (no .clipping)
| stoi: A dictionary mapping characters to integer indices. | ||
| """ | ||
| chars = sorted(list(set(text))) | ||
| stoi = {ch: i for i, ch in enumerate(chars)} |
There was a problem hiding this comment.
Let's directly return the dict here, rather than creating stoi intermediate
| """Gets a random batch of data. | ||
|
|
||
| Args: | ||
| data: The entire dataset encoded as integers. |
There was a problem hiding this comment.
I'm a little confused by this function. Is data 1D or 2D?
For DP training, we need to have a solid notion of what the unit of DP is. For this dataset, I'm not 100% what a natural unit is, maybe a single sentence, or single paragraph?
| ) | ||
| stoi = create_tokenizer(text) | ||
| vocab_size = len(stoi) | ||
| data = encode(text, stoi) |
There was a problem hiding this comment.
Let's make sure that data is 2D, and each token only appears in a single example. This is important to get valid DP guarantees. Can you split up the data either into constant sized chunks or sentence level chunks with truncation & padding?
| context_length=CONTEXT_LENGTH, | ||
| num_heads=NUM_HEADS, | ||
| num_layers=NUM_LAYERS, | ||
| rngs=rngs, |
There was a problem hiding this comment.
let's directly instantiate nnx.Rngs(0) here if it's unused below this line.
| l2_clip_norm=CLIP_NORM, | ||
| batch_argnums=(1, 2), # x and y are batched | ||
| prng_argnum=3, # Explicitly vmap the PRNG key per batch example | ||
| keep_batch_dim=True, # Return per-example gradients |
There was a problem hiding this comment.
This inline comments is not right. clipped_grad always returns per-example gradients, no matter what this is set as. But keep_batch_dim=True is the default, so I'd just delete this line all-together.
| grads, loss = grad_fn(params, x, y, prng_key) | ||
|
|
||
| # Manually aggregate the per-example clipped gradients using sum | ||
| summed_grads = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), grads) |
There was a problem hiding this comment.
this is not right, grads already sums along the batch dimension. Maybe add an assert that the shape of grads matches the shape of params to verify this.
| return np.array([stoi[ch] for ch in text], dtype=np.int32) | ||
|
|
||
|
|
||
| def get_batch( |
There was a problem hiding this comment.
This function appears unused and can be deleted
762c6b8 to
a34049a
Compare
|
Hi @ryan112358, I have completed the final polish of the NNX Transformer example and resolved all action items from your latest review. Technical Resolutions: Privacy Unit Integrity: Refactored the data loader to strictly enforce non-overlapping 2D sequences of shape (num_sequences, CONTEXT_LENGTH). This ensures each token belongs to exactly one DP unit. Gradient Logic & Correctness: Removed the redundant manual gradient summation and keep_batch_dim=True flag. I've added a shape assertion in train_step to verify that clipped_grad is correctly aggregating gradients to match parameter leaf shapes. Privacy Hardening (Zero-Batch): The training loop now processes empty batches (if sampled) by procesing zero-vectors and adding privacy noise, ensuring no metadata leakage through batch-skipping. Production Optimization: Integrated donate_argnums for efficient buffer reuse and achieved a perfect 10.0/10 Pylint score with full pyink compliance. Cleanups: Deleted the unused get_batch utility and simplified the tokenizer return. Verification Results: Verified convergence over 20 steps (Final Loss: 4.48). Ready for final review! |
|
Hi @ryan112358, I have finalized the PR by aligning user_level_transformer_example.py with the latest library standards and resolving the final CI linting failure. Technical Resolutions (Final Polish): API Migration: Refactored BandMFExecutionPlanConfig in the user-level example to use the .default() factory method and the config.plan interface. This resolves the E1123 pylint error caused by the recent constructor update in execution_plan.py. Buffer Donation: Integrated donate_argnums=(0, 1, 4) in the train_step for both Transformer examples to ensure production-level memory efficiency. Cross-Script Consistency: Both the NNX and User-Level examples now share identical configuration logic, imports, and linter compliance. Verification: Achieved a 10.0/10 Pylint score and verified functional 2D non-overlapping data units across 20 steps of training. Please provide a review. |
|
This PR has been idle for 7 days. Please provide an update or review. |
…alse) per maintainer review
…isolation, and linter compliance
…atting for user-level examples
… library API alignment
… user-level buffer donation
…an.clipped_grad, and patch batch_selection for multi-dim support
c501354 to
2206902
Compare
|
Hi @ryan112358, I have successfully rebased the PR against the latest upstream/main and resolved the conflicts in pyproject.toml. Post-Rebase Verification: Logic Integrity: Verified that the core utility patch in jax_privacy/batch_selection.py for multi-dimensional padding is intact. Performance Stability: Re-ran the JIT verification logs. The train_step continues to compile exactly once using the PADDING_MULTIPLE = 8 strategy, ensuring zero performance regression. Production Standards: All scripts maintain a 10.0/10 Pylint score and full pyink compliance. The PR is now clean and ready for your final walkthrough. Thank you for your patience during my university project break! |
|
This PR has been idle for 7 days. Please provide an update or review. |
This PR introduces a comprehensive example of training a Transformer model with Differential Privacy using the new Flax NNX API. While JAX Privacy provides robust support for Linen and Haiku, this addition provides a template for users moving toward the functional-object paradigm of NNX.
Key Technical Implementations:
✔️ Exhaustive State Partitioning: Utilizes nnx.split(model, nnx.Param, ...) to strictly separate trainable parameters from non-trainable state (RNG counts, etc.), ensuring the JAX tracer maintains leaf parity across functional boundaries.
✔️ Rank-Normalized Loss: Implements a rank-injection strategy within the pure loss function to account for vmap dimension-stripping. By forcing a singleton batch dimension during the forward pass, the model correctly generates 4D causal masks required by the attention mechanism.
✔️ Privacy-Safe State Reconstruction: Uses an internal nnx.merge pattern to ensure that mutations to RNG states during training remain local to the functional trace, preventing TraceContextError regressions.
✅ Verification: The script was validated on the Tiny Shakespeare dataset for 20 steps, achieving stable convergence under DP constraints (Default: CLIP_NORM=1.0).
Screenshot of output attached 👇
