
Conversation

@debanganghosh08

This PR introduces a comprehensive example of training a Transformer model with Differential Privacy using the new Flax NNX API. While JAX Privacy provides robust support for Linen and Haiku, this addition offers a template for users moving toward the functional-object paradigm of NNX.

Key Technical Implementations:

✔️ Exhaustive State Partitioning: Utilizes nnx.split(model, nnx.Param, ...) to strictly separate trainable parameters from non-trainable state (RNG counts, etc.), ensuring the JAX tracer maintains leaf parity across functional boundaries.

✔️ Rank-Normalized Loss: Accounts for the batch dimension that vmap strips during per-example processing by re-adding a singleton batch axis inside the pure loss function, so the forward pass still builds the 4D causal masks the attention mechanism requires.

✔️ Privacy-Safe State Reconstruction: Uses an internal nnx.merge pattern so that mutations to RNG state during training stay local to the functional trace, preventing TraceContextError regressions (a minimal sketch of the split/merge pattern follows below).
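For concreteness, here is a minimal sketch of the split/merge pattern described above. It uses a hypothetical TinyModel and plain jax.grad rather than the DP-clipped gradient function from this PR, so names and shapes are illustrative only:

import jax
import jax.numpy as jnp
from flax import nnx


class TinyModel(nnx.Module):  # hypothetical stand-in for the Transformer
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(8, 8, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)


model = TinyModel(rngs=nnx.Rngs(0))

# Partition trainable parameters from everything else (RNG counts, etc.).
graphdef, params, other = nnx.split(model, nnx.Param, ...)


def pure_loss_fn(params, x, y, graphdef, other):
  # Reconstruct the model inside the functional trace; mutations stay local.
  model = nnx.merge(graphdef, params, other)
  preds = model(x)
  return jnp.mean((preds - y) ** 2)


# Ordinary (non-private) gradients, just to show the functional wiring.
grads = jax.grad(pure_loss_fn)(
    params, jnp.ones((4, 8)), jnp.ones((4, 8)), graphdef, other)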

Verification: The script was validated on the Tiny Shakespeare dataset for 20 steps, achieving stable convergence under DP constraints (Default: CLIP_NORM=1.0).

Screenshot of the training output attached 👇

@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from 7cbfbb1 to 944df7c on January 24, 2026 14:49
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law- agreed to in writing, software
Contributor

fix typo

Returns:
The content of the downloaded file as a string.
"""
with urllib.request.urlopen(url) as response:
Contributor

add timeout to prevent indefinite blocking

Author

That's a good catch. I've added a timeout, which is definitely best practice to avoid hangs in CI/CD: download_data now uses a 10-second timeout. I'm also moving the flax dependency into a proper requirements file as you suggested.
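For reference, a minimal version of that change could look like the following (the exact download_data signature is assumed):

import urllib.request


def download_data(url: str, timeout: float = 10.0) -> str:
  """Downloads a text file, failing fast instead of blocking indefinitely."""
  with urllib.request.urlopen(url, timeout=timeout) as response:
    return response.read().decode("utf-8")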

import urllib.request

from flax import nnx
from flax import nnx # pytype: disable=import-error
Contributor

this line is unusual

Author

It isn't unusual here: the CI checks don't install flax, so when the pytype check runs the import fails. This directive is needed for the code to pass all the CI checks.
As a longer-term fix, we could ask @RamSaw or @ryan112358 to install flax in the CI job so the comment is no longer needed.

Contributor

Then try adding it to the requirements.txt located in the docs folder.

Collaborator

The requirements.txt in the docs folder is intended to only contain requirements needed for documentation. The ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and update .github/workflows/ci.yml to install it.

Collaborator

Or you can add it to the "dev" requirements in pyproject.toml

from absl import app
from absl import flags
import flax.linen as nn
import flax.linen as nn # pytype: disable=import-error
Contributor

same comment applies here

Author

Same as above: the CI checks don't install flax, so the pytype check fails without this directive.
Longer term, we could ask @RamSaw or @ryan112358 to install flax in the CI job so it's no longer needed.

Collaborator

@ryan112358 ryan112358 left a comment

Looks great, very clean, nice work! Left some comments.

import urllib.request

from flax import nnx
from flax import nnx # pytype: disable=import-error
Collaborator

The requirements.txt in the docs folder is intended to only contain requirements needed for documentation. The ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and update .github/workflows/ci.yml to install it.

x: Input batch (single example or microbatch).
y: Target batch (single example or microbatch).
graphdef: The static graph definition of the NNX model.
other: Non-trainable state (e.g., RNG counts).
Collaborator

What else other than the rng counts is captured here? Is it possible to call this argument prng and have it typed as a jax.Array, then somehow wire it through to flax? I ask because when you call clipped_grad, if the loss function contains a prng key it needs special handling.

Returns:
The scalar loss value.
"""
m = nnx.merge(graphdef, params, other)
Collaborator

Give this a descriptive name like model

l2_clip_norm=CLIP_NORM,
batch_argnums=(1, 2), # x and y are batched
keep_batch_dim=False, # Process per-example
return_values=True # Return loss values for logging
Collaborator

You might need to pass prng_argnum here as well to ensure the random key is handled appropriately. But it might require slight refactoring of your loss function

functools.partial(pure_loss_fn, graphdef=graphdef, other=other),
l2_clip_norm=CLIP_NORM,
batch_argnums=(1, 2), # x and y are batched
keep_batch_dim=False, # Process per-example
Collaborator

Usually we want to keep this to the default (True), unless we're doing user-level DP. If you set this to True (or remove it), can you remove the line that adds an extra batch axis in pure_loss_fn?
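To illustrate the suggestion (a sketch with assumed names and shapes, not the PR's exact code): with keep_batch_dim=True each per-example slice keeps a leading batch axis of size 1, so the loss can pass its inputs straight to the model without re-adding an axis.

import optax
from flax import nnx


def pure_loss_fn(params, x, y, graphdef, other):
  # x and y arrive with a leading singleton batch axis when keep_batch_dim=True,
  # so no jnp.expand_dims is needed before the forward pass.
  model = nnx.merge(graphdef, params, other)
  logits = model(x)
  return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()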

grads, loss = grad_fn(params, x, y)

# Aggregate gradients (mean across batch)
mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads)
Collaborator

grad_fn already aggregates gradients across the batch dimension, so I think this is a bug
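In other words, the corrected step would feed the clipped gradients straight into the update. A sketch assuming an optax-style optimizer; the noise-addition call is elided since its interface isn't shown here:

import optax


def dp_sgd_step(grad_fn, optimizer, opt_state, params, x, y):
  grads, loss = grad_fn(params, x, y)  # already aggregated over the batch
  # (apply the Gaussian privatizer to `grads` here, before the update)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state, loss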

# Aggregate gradients (mean across batch)
mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads)

# Add Privacy Noise
Collaborator

I'll leave it up to your discretion, but I think these inline comments can be removed.

# Training loop
print(f"Training for {NUM_STEPS} steps...")
for step in range(NUM_STEPS):
batch = get_batch(data, BATCH_SIZE, CONTEXT_LENGTH)
Collaborator

In an ideal world this would use Poisson sampling / jax_privacy.batch_selection. It's fine to leave a TODO for now and add it in a follow-up.

)

privatizer = noise_addition.gaussian_privatizer(
stddev=CLIP_NORM,
Collaborator

The stddev should be grad_fn.sensitivity() * noise_multiplier. Can you add NOISE_MULTIPLIER to the list of constants above?
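Roughly, something along these lines (the constant's value is illustrative only):

NOISE_MULTIPLIER = 1.1  # illustrative value; choose it together with the privacy accountant

privatizer = noise_addition.gaussian_privatizer(
    stddev=grad_fn.sensitivity() * NOISE_MULTIPLIER,
)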

@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from 1d03537 to 9eac33d on January 26, 2026 11:35
@debanganghosh08
Author

Hi @ryan112358,

I've pushed an update addressing all your feedback. Here is a summary of the changes I made:

  1. CI/CD Infrastructure: Moved the flax dependency to examples/requirements.txt and updated .github/workflows/ci.yml. This ensures all examples pass pytype without manual disable comments.

  2. NNX Causal Masking: Refactored TransformerBlock to use nnx.make_causal_mask(x[..., 0]).
    I explored the is_causal keyword, but as noted, it isn't currently supported in the nnx.MultiHeadAttention version we are using. This approach handles the rank requirements cleanly (a simplified sketch follows this list).

  3. Gradient Aggregation Fix: Set keep_batch_dim=True in clipped_grad and removed the manual jnp.mean aggregation in the training step to prevent double-averaging.

  4. Privacy Parameters: Integrated the NOISE_MULTIPLIER constant and updated the privatizer to scale based on grad_fn.sensitivity().

  5. Refinement: I renamed internal variables for clarity (e.g., model instead of m), added a timeout to the data loader, and included a TODO for moving to Poisson sampling.
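For context on item 2, the masking approach looks roughly like this (a simplified sketch with illustrative layer sizes, not the exact TransformerBlock in the PR):

from flax import nnx


class TransformerBlock(nnx.Module):
  def __init__(self, embed_dim: int, num_heads: int, rngs: nnx.Rngs):
    self.attn = nnx.MultiHeadAttention(
        num_heads=num_heads, in_features=embed_dim, decode=False, rngs=rngs)
    self.norm = nnx.LayerNorm(embed_dim, rngs=rngs)

  def __call__(self, x):  # x: [batch, seq_len, embed_dim]
    # make_causal_mask expects a [batch, seq_len] array, hence x[..., 0];
    # it returns the 4D [batch, 1, seq_len, seq_len] mask attention needs.
    mask = nnx.make_causal_mask(x[..., 0])
    return self.norm(x + self.attn(x, mask=mask))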

Verification: The script was verified for 10 steps locally, achieving a stable loss and passing a 10.00/10 pylint check.

Let me know if further changes are required!

@amyssnippet
Contributor

#128 might fix the CI failures and make them easier to debug.

@debanganghosh08
Author

#128 might fix the CI failures and make them easier to debug.

That's a good approach for moving the current CI/CD to a modular DAG architecture; it should improve DX.

@amyssnippet
Contributor

@debanganghosh08, now that the new CI pipeline and dependency flow have been introduced, there will be CI failures from here on: the library you added in examples/requirements.txt will no longer be picked up. Please first pull the latest changes from upstream main, then delete the examples/requirements.txt file and add the dependencies to pyproject.toml; there is an optional-dependencies section with a slot for [examples], so please add them there.

Optional dependencies are now managed centrally in the root pyproject.toml file.

@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from b6d6d66 to d5a7943 on January 27, 2026 11:02
@debanganghosh08
Author

@debanganghosh08, now that the new CI pipeline and dependency flow have been introduced, there will be CI failures from here on: the library you added in examples/requirements.txt will no longer be picked up. Please first pull the latest changes from upstream main, then delete the examples/requirements.txt file and add the dependencies to pyproject.toml; there is an optional-dependencies section with a slot for [examples], so please add them there.

Optional dependencies are now managed centrally in the root pyproject.toml file.

Thanks for the heads-up and the clear guidance on the new dependency flow, @amyssnippet! I've just pushed an update aligning with the new modular CI. I pulled the latest upstream changes, migrated flax to the [project.optional-dependencies] section in pyproject.toml, and cleaned up the temporary requirements file. Everything should be in sync now!
