Skip to content

Feat : Add DP-SGD Transformer example using Flax NNX API | Issue #120#126

Open
debanganghosh08 wants to merge 23 commits into
google-deepmind:mainfrom
debanganghosh08:feat/nnx-transformer-dp-sgd
Open

Feat : Add DP-SGD Transformer example using Flax NNX API | Issue #120#126
debanganghosh08 wants to merge 23 commits into
google-deepmind:mainfrom
debanganghosh08:feat/nnx-transformer-dp-sgd

Conversation

@debanganghosh08
Copy link
Copy Markdown

This PR introduces a comprehensive example of training a Transformer model with Differential Privacy using the new Flax NNX API. While JAX Privacy provides robust support for Linen and Haiku, this addition provides a template for users moving toward the functional-object paradigm of NNX.

Key Technical Implementations:

✔️ Exhaustive State Partitioning: Utilizes nnx.split(model, nnx.Param, ...) to strictly separate trainable parameters from non-trainable state (RNG counts, etc.), ensuring the JAX tracer maintains leaf parity across functional boundaries.

✔️ Rank-Normalized Loss: Implements a rank-injection strategy within the pure loss function to account for vmap dimension-stripping. By forcing a singleton batch dimension during the forward pass, the model correctly generates 4D causal masks required by the attention mechanism.

✔️ Privacy-Safe State Reconstruction: Uses an internal nnx.merge pattern to ensure that mutations to RNG states during training remain local to the functional trace, preventing TraceContextError regressions.

Verification: The script was validated on the Tiny Shakespeare dataset for 20 steps, achieving stable convergence under DP constraints (Default: CLIP_NORM=1.0).

Screenshot of output attached 👇
training OP

@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from 7cbfbb1 to 944df7c Compare January 24, 2026 14:49
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law- agreed to in writing, software
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix typo

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
Returns:
The content of the downloaded file as a string.
"""
with urllib.request.urlopen(url) as response:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add timeout to prevent indefinite blocking

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good catch brother. i have now added a timeout and is definitely best practice to avoid hangs in CI/CD. I've updated download_data to include a 10-second timeout. I'm also moving the flax dependency into a proper requirements file as you suggested.

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
import urllib.request

from flax import nnx
from flax import nnx # pytype: disable=import-error
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line is unusual

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not, in the cicd checks there is no flax installing dependency to when the pytype check happens, the code fails. Hence, this line is important to pass all the cicd checks.
For a long term note, we can tell the @RamSaw or @ryan112358 to add flax installing for the cicd check for no further issue.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so try adding in the requirements txt which is located in the docs folder

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The requirements.txt in docs folder is intended to only contain requirements needed for documentation. The ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and updates .github/workflows/ci.yml to install these.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or you can add it to the "dev" requirements in pyproject.toml

from absl import app
from absl import flags
import flax.linen as nn
import flax.linen as nn # pytype: disable=import-error
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here too

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not, in the cicd checks there is no flax installing dependency to when the pytype check happens, the code fails. Hence, this line is important to pass all the cicd checks.
For a long term note, we can tell the @RamSaw or @ryan112358 to add flax installing for the cicd check for no further issue.

Copy link
Copy Markdown
Collaborator

@ryan112358 ryan112358 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great ,very clean - nice work! Left some comments

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
import urllib.request

from flax import nnx
from flax import nnx # pytype: disable=import-error
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The requirements.txt in docs folder is intended to only contain requirements needed for documentation. The ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and updates .github/workflows/ci.yml to install these.

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
x: Input batch (single example or microbatch).
y: Target batch (single example or microbatch).
graphdef: The static graph definition of the NNX model.
other: Non-trainable state (e.g., RNG counts).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What else other than the rng counts is captured here? Is it possible to call this argument prng and have it typed as a jax.Array, then somehow wire it through to flax? I ask because when you call clipped_grad, if the loss function contains a prng key it needs special handling.

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
Returns:
The scalar loss value.
"""
m = nnx.merge(graphdef, params, other)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Give this a descriptive name like model

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
l2_clip_norm=CLIP_NORM,
batch_argnums=(1, 2), # x and y are batched
keep_batch_dim=False, # Process per-example
return_values=True # Return loss values for logging
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might need to pass prng_argnum here as well to ensure the random key is handled appropriately. But it might require slight refactoring of your loss function

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
functools.partial(pure_loss_fn, graphdef=graphdef, other=other),
l2_clip_norm=CLIP_NORM,
batch_argnums=(1, 2), # x and y are batched
keep_batch_dim=False, # Process per-example
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually we want to keep this to the default (True), unless we're doing user-level DP. If you set this to True (or remove it), can you remove the line that adds an extra batch axis in pure_loss_fn?

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
grads, loss = grad_fn(params, x, y)

# Aggregate gradients (mean across batch)
mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grad_fn already aggregates gradients across the batch dimension, so I think this is a bug

# Aggregate gradients (mean across batch)
mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads)

# Add Privacy Noise
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave it up to your discretion, but I think these inline comments can be removed.

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
# Training loop
print(f"Training for {NUM_STEPS} steps...")
for step in range(NUM_STEPS):
batch = get_batch(data, BATCH_SIZE, CONTEXT_LENGTH)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In an ideal world this would use poisson sampling / jax_privacy.batch_selection. It's fine to leave a TODO for now and add it in a follow-up

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
)

privatizer = noise_addition.gaussian_privatizer(
stddev=CLIP_NORM,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The stddev should be grad_fn.sensitiivty() * noise_multiplier. can you add NOISE_MULTIPLIER to the list of constants above?

@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from 1d03537 to 9eac33d Compare January 26, 2026 11:35
@debanganghosh08
Copy link
Copy Markdown
Author

Hi @ryan112358 ,

I've pushed an update addressing all your feedback. Here is a summary of the changes I made:

  1. CI/CD Infrastructure: Moved the flax dependency to examples/requirements.txt and updated .github/workflows/ci.yml. This ensures all examples pass pytype without manual disable comments.

  2. NNX Causal Masking: Refactored TransformerBlock to use nnx.make_causal_mask(x[..., 0]).
    I explored the is_causal keyword, but as noted, it isn't currently supported in the nnx.MultiHeadAttention version we are using. This new approach handles the rank requirements cleanly.

  3. Gradient Aggregation Fix: Set keep_batch_dim=True in clipped_grad and removed the manual jnp.mean aggregation in the training step to prevent double-averaging.

  4. Privacy Parameters: Integrated the NOISE_MULTIPLIER constant and updated the privatizer to scale based on grad_fn.sensitivity().

  5. Refinement: I renamed internal variables for clarity (e.g., model instead of m), added a timeout to the data loader, and included a TODO for moving to Poisson sampling.

Verification: The script was verified for 10 steps locally, achieving a stable loss and passing a 10.00/10 pylint check.

Remind me if new changes are required!

@amyssnippet
Copy link
Copy Markdown
Contributor

#128 might fix the ci failures easy to debug

@debanganghosh08
Copy link
Copy Markdown
Author

#128 might fix the ci failures easy to debug

That's an Good approach for moving current CICD to modular DAG architecture. It is good for improving DX.

@amyssnippet
Copy link
Copy Markdown
Contributor

@debanganghosh08 , since now the new ci pipeline and new dependency flow has been introduced, so there will ci failures from now on. As you have added the one lib in examples/req...txt it will not considered from now on. Kindly first pull the lastest changes from upstream main, then delete the examples/req..txt file and add the deps to the pyproject.toml, you can see there is optional tab and a space for [examples], kindly add it there.

Now a central optional deps are managed at the root pyproject.toml file

@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from b6d6d66 to d5a7943 Compare January 27, 2026 11:02
@debanganghosh08
Copy link
Copy Markdown
Author

@debanganghosh08 , since now the new ci pipeline and new dependency flow has been introduced, so there will ci failures from now on. As you have added the one lib in examples/req...txt it will not considered from now on. Kindly first pull the lastest changes from upstream main, then delete the examples/req..txt file and add the deps to the pyproject.toml, you can see there is optional tab and a space for [examples], kindly add it there.

Now a central optional deps are managed at the root pyproject.toml file

Thanks for the heads-up and the clear guidance on the new dependency flow, @amyssnippet! I've just pushed an update aligning with the new modular CI. I pulled the latest upstream changes, migrated flax to the [project.optional-dependencies] section in pyproject.toml, and cleaned up the temporary requirements file. Everything should be in sync now!

Copy link
Copy Markdown
Contributor

@amyssnippet amyssnippet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i guess check the files changed tab, there are still some files visible, kindly fix them all, i already left comments

Comment thread .github/workflows/ci.yml Outdated
Comment on lines +24 to +25
- name: Install example requirements
run: pip install -r examples/requirements.txt
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this block of ci should not be here, it is unusual, it is not required

Comment thread pyproject.toml Outdated
Comment on lines +38 to +40
examples = [
"flax",
]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have already created arrays to manage all optional dependencies, check it here https://github.com/google-deepmind/jax_privacy/blob/main/pyproject.toml

i have made deps in the prev task with ci, make sure you pulled the changes properly. including this file

Comment thread examples/requirements.txt Outdated
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i guess its still available here, which is not required

@debanganghosh08
Copy link
Copy Markdown
Author

Hello @ryan112358 and @Neerajpathak07 ,

I’ve updated the implementation for both the NNX Transformer (#126) and ULS Transformer (#107) examples to align with the architectural suggestions provided. I performed a side-by-side experimental benchmark to evaluate the impact of moving from a manual loop to the library's internal execution_plan abstraction.

Key Refactors Implemented:

Standardized Orchestration: Switched to execution_plan.BandMFExecutionPlanConfig to wire the privatizer and clipped_grad. This ensures the noise addition and sampling strategies are mathematically synchronized with the library's core mechanisms.

ULS Integration: In the User-Level example, I successfully wrapped the plan.batch_selection_strategy within our UserSelectionStrategy, maintaining the required intra-user averaging while utilizing the standard batch_iterator.

Production Standards: Adopted the main(argv: Sequence[str]) entry point and migrated hyperparameters to centralized constants for better readability. Both files now achieve a 10.00/10 Pylint score.

Benchmarking Observation: During local 10-step runs, I noted a significant initialization overhead (~45s). This is due to the Toeplitz.optimize_banded_toeplitz step required by the BandMF strategy. While this increases the 'wall-clock' time for short CI checks, it is a fixed cost that will be fully amortized during production-scale training runs.

@Neerajpathak07, thanks for pointing out the BandMF configuration, it makes the examples much more idiomatic. @ryan112358, do you agree that the increased alignment with the core library's 'Plan' API is worth the trade-off in script simplicity for these examples?

@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch 3 times, most recently from bf44580 to 1b84638 Compare February 4, 2026 19:02
@github-actions
Copy link
Copy Markdown

This PR has been idle for 7 days. Please provide an update or review.

@github-actions github-actions Bot added the Stale label Feb 20, 2026
@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from c206071 to fbdc2b3 Compare March 15, 2026 05:25
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Mar 15, 2026

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

❌ Patch coverage is 75.00000% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@6cf0972). Learn more about missing BASE report.

Files with missing lines Patch % Lines
jax_privacy/batch_selection.py 75.00% 2 Missing ⚠️
❗ Your organization needs to install the Codecov GitHub app to enable full functionality.
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #126   +/-   ##
=======================================
  Coverage        ?   73.61%           
=======================================
  Files           ?       25           
  Lines           ?     3025           
  Branches        ?        0           
=======================================
  Hits            ?     2227           
  Misses          ?      798           
  Partials        ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@debanganghosh08
Copy link
Copy Markdown
Author

Hi @ryan112358,

I have finalized the NNX Transformer example and resolved all action items from the previous review cycle.

Technical Resolutions:

Secure RNG handling: Refactored pure_loss_fn to isolate the prng_key. By using prng_argnum=3 in clipped_grad, the JAX tracer now correctly splits the PRNG key per-example in the microbatch, preventing the "RNG broadcasting" leak.

Gradient Aggregation: Switched to keep_batch_dim=True and implemented manual jax.tree.map(jnp.sum) post-clipping. This ensures the gradients are correctly aggregated for the privatizer while respecting the per-example clipping constraints.

CI/CD Hardening: Resolved a Windows-specific crash in the unit tests caused by flax pulling in uvloop. I've added a sys_platform != 'win32' marker to the test dependencies in pyproject.toml to ensure cross-platform CI stability.

Production Standards:

  1. Aligned the execution_plan to use an explicit NOISE_MULTIPLIER for transparent privacy scaling.

  2. Verified a 10.0/10 Pylint score and full pyink compliance for both the NNX and User-Level examples.

  3. Convergence Verified: In a local run, the training loss successfully decreased from 4.68 to 4.61 over the first 20 steps.

Copy link
Copy Markdown
Contributor

@amyssnippet amyssnippet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rest lgtm

Comment thread pyproject.toml Outdated
"hypothesis",
"crc32c",
"keras",
# Add the environment marker to flax here:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for comment

@github-actions github-actions Bot removed the Stale label Mar 16, 2026
Comment thread examples/dp_sgd_transformer_nnx.py Outdated
import jax
import jax.extend.backend
import jax.numpy as jnp
from jax_privacy.clipping import clipped_grad
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer importing clipped_grad directly from jax_privacy (no .clipping)

- Uses `flax.nnx` for model definition.
- Implements the "Exhaustive Split" pattern to separate parameters from
static graph definition.
- Handles rank normalization for `jax_privacy.clipping.clipped_grad`
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer importing clipped_grad directly from jax_privacy (no .clipping)

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
stoi: A dictionary mapping characters to integer indices.
"""
chars = sorted(list(set(text)))
stoi = {ch: i for i, ch in enumerate(chars)}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's directly return the dict here, rather than creating stoi intermediate

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
"""Gets a random batch of data.

Args:
data: The entire dataset encoded as integers.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused by this function. Is data 1D or 2D?

For DP training, we need to have a solid notion of what the unit of DP is. For this dataset, I'm not 100% what a natural unit is, maybe a single sentence, or single paragraph?

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
)
stoi = create_tokenizer(text)
vocab_size = len(stoi)
data = encode(text, stoi)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure that data is 2D, and each token only appears in a single example. This is important to get valid DP guarantees. Can you split up the data either into constant sized chunks or sentence level chunks with truncation & padding?

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
context_length=CONTEXT_LENGTH,
num_heads=NUM_HEADS,
num_layers=NUM_LAYERS,
rngs=rngs,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's directly instantiate nnx.Rngs(0) here if it's unused below this line.

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
l2_clip_norm=CLIP_NORM,
batch_argnums=(1, 2), # x and y are batched
prng_argnum=3, # Explicitly vmap the PRNG key per batch example
keep_batch_dim=True, # Return per-example gradients
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This inline comments is not right. clipped_grad always returns per-example gradients, no matter what this is set as. But keep_batch_dim=True is the default, so I'd just delete this line all-together.

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
grads, loss = grad_fn(params, x, y, prng_key)

# Manually aggregate the per-example clipped gradients using sum
summed_grads = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), grads)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not right, grads already sums along the batch dimension. Maybe add an assert that the shape of grads matches the shape of params to verify this.

Comment thread examples/dp_sgd_transformer_nnx.py Outdated
return np.array([stoi[ch] for ch in text], dtype=np.int32)


def get_batch(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function appears unused and can be deleted

@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from 762c6b8 to a34049a Compare March 23, 2026 05:17
@debanganghosh08
Copy link
Copy Markdown
Author

Hi @ryan112358,

I have completed the final polish of the NNX Transformer example and resolved all action items from your latest review.

Technical Resolutions:

Privacy Unit Integrity: Refactored the data loader to strictly enforce non-overlapping 2D sequences of shape (num_sequences, CONTEXT_LENGTH). This ensures each token belongs to exactly one DP unit.

Gradient Logic & Correctness: Removed the redundant manual gradient summation and keep_batch_dim=True flag. I've added a shape assertion in train_step to verify that clipped_grad is correctly aggregating gradients to match parameter leaf shapes.

Privacy Hardening (Zero-Batch): The training loop now processes empty batches (if sampled) by procesing zero-vectors and adding privacy noise, ensuring no metadata leakage through batch-skipping.

Production Optimization: Integrated donate_argnums for efficient buffer reuse and achieved a perfect 10.0/10 Pylint score with full pyink compliance.

Cleanups: Deleted the unused get_batch utility and simplified the tokenizer return.

Verification Results: Verified convergence over 20 steps (Final Loss: 4.48).
Both the NNX and User-Level examples are now fully aligned with the library's latest modular CI.

Ready for final review!

@debanganghosh08
Copy link
Copy Markdown
Author

Hi @ryan112358,

I have finalized the PR by aligning user_level_transformer_example.py with the latest library standards and resolving the final CI linting failure.

Technical Resolutions (Final Polish):

API Migration: Refactored BandMFExecutionPlanConfig in the user-level example to use the .default() factory method and the config.plan interface. This resolves the E1123 pylint error caused by the recent constructor update in execution_plan.py.

Buffer Donation: Integrated donate_argnums=(0, 1, 4) in the train_step for both Transformer examples to ensure production-level memory efficiency.

Cross-Script Consistency: Both the NNX and User-Level examples now share identical configuration logic, imports, and linter compliance.

Verification: Achieved a 10.0/10 Pylint score and verified functional 2D non-overlapping data units across 20 steps of training.

Please provide a review.

@github-actions
Copy link
Copy Markdown

This PR has been idle for 7 days. Please provide an update or review.

@github-actions github-actions Bot added the Stale label Apr 13, 2026
…an.clipped_grad, and patch batch_selection for multi-dim support
@debanganghosh08 debanganghosh08 force-pushed the feat/nnx-transformer-dp-sgd branch from c501354 to 2206902 Compare April 25, 2026 16:00
@debanganghosh08
Copy link
Copy Markdown
Author

Hi @ryan112358,

I have successfully rebased the PR against the latest upstream/main and resolved the conflicts in pyproject.toml.

Post-Rebase Verification:

Logic Integrity: Verified that the core utility patch in jax_privacy/batch_selection.py for multi-dimensional padding is intact.

Performance Stability: Re-ran the JIT verification logs. The train_step continues to compile exactly once using the PADDING_MULTIPLE = 8 strategy, ensuring zero performance regression.

Production Standards: All scripts maintain a 10.0/10 Pylint score and full pyink compliance.

The PR is now clean and ready for your final walkthrough. Thank you for your patience during my university project break!

@github-actions github-actions Bot removed the Stale label Apr 26, 2026
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 3, 2026

This PR has been idle for 7 days. Please provide an update or review.

@github-actions github-actions Bot added the Stale label May 3, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants