Skip to content

Conversation

@debanganghosh08
Copy link

This PR adds a comprehensive, self-contained example of training a Transformer decoder using the JAX Privacy Core API (Issue #86).

Key Features:

Architecture: Implements a Transformer Decoder in Flax Linen for character-level language modeling.
Bare-Metal Core API: Explicitly uses jax_privacy.clipped_grad and noise_addition.gaussian_privatizer without high-level wrappers.
Configurability: Fully integrated with ABSL flags for epsilon, batch size, learning rate, and clipping norm.
Verification: Confirmed model convergence (Loss decreased from ~2.68 to ~1.87 in 50 steps).

Quality Checks:
Pylint Score: 9.91/10
Flake8: Passed

Fixes #86

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add simple end-to-end example of DP training of a transformer

2 participants