diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 7326b31..3653bd4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -21,6 +21,8 @@ jobs:
       - uses: "actions/setup-python@v6"
         with:
           python-version: "${{ matrix.python-version }}"
+      - name: Install example requirements
+        run: pip install -r examples/requirements.txt
       - name: Run CI tests
         run: bash test.sh
         shell: bash
diff --git a/examples/dp2_delayed_preconditioner_example.py b/examples/dp2_delayed_preconditioner_example.py
new file mode 100644
index 0000000..a2c2a10
--- /dev/null
+++ b/examples/dp2_delayed_preconditioner_example.py
@@ -0,0 +1,191 @@
+# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Example of a private adaptive optimizer with a Delayed Preconditioner (DP^2).
+
+This example is based on the paper:
+  Li et al. (2023), "Differentially Private Adaptive Optimization with Delayed
+  Preconditioners" (ICLR 2023).
+
+Implementation Notes:
+  This implementation demonstrates the "alternating-phase protocol" of DP^2.
+  The core idea is to introduce a trade-off between the staleness of the
+  preconditioner and noise reduction. The preconditioner (second-moment
+  estimate) is kept "stale" for a fixed number of steps (`delay_s`). During
+  this period, noised gradients are accumulated.
+
+  When the delay period is over, the preconditioner is updated using the
+  *average* of the accumulated gradients. By averaging over multiple steps,
+  the signal from the gradients is amplified relative to the noise, leading
+  to a more stable and reliable preconditioner update. This helps to prevent
+  the noise from overwhelming the adaptive learning rate calculation, a common
+  problem in standard private adaptive optimizers.
+"""
+
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+from jax_privacy.clipping import clipped_grad
+from jax_privacy import noise_addition
+import optax
+import numpy as np
+
+
+# 1. Model and Data Setup
+class SimpleModel(nn.Module):
+  """A simple linear model."""
+
+  @nn.compact
+  def __call__(self, x):
+    return nn.Dense(features=1)(x)
+
+
+# Dummy data
+inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
+labels = jnp.array([[0.5], [1.5]])
+
+model = SimpleModel()
+key = jax.random.PRNGKey(42)
+params = model.init(key, inputs)['params']
+epsilon_s1 = 1e-8
+
+# 2. Optimizer and DP^2 State
+optimizer = optax.sgd(learning_rate=0.1)
+opt_state = optimizer.init(params)
+
+# DP^2 parameters
+delay_s = 5
+preconditioner_v = jax.tree_util.tree_map(jnp.ones_like, params)
+gradient_accumulator = jax.tree_util.tree_map(jnp.zeros_like, params)
+
+# Privatizer for noise addition
+stddev = 1.0 * 0.1  # l2_clip_norm * noise_multiplier
+privatizer = noise_addition.gaussian_privatizer(
+    stddev=stddev,
+    prng_key=jax.random.PRNGKey(0),
+)
+noise_state = privatizer.init(params)
+
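+# Why accumulating over `delay_s` steps helps (illustrative note, assuming the
+# privatizer draws fresh, independent Gaussian noise at every step): each
+# noised gradient carries per-coordinate noise of standard deviation `stddev`,
+# so the mean over `delay_s` steps carries noise of standard deviation
+# `stddev / sqrt(delay_s)` -- roughly a 2.2x reduction for delay_s = 5 -- which
+# is the signal-to-noise gain the delayed preconditioner update relies on.
+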
+# 3. Core Logic: DP^2 Update Step
+def loss_fn(p, x, y):
+  """MSE loss."""
+  return jnp.mean((model.apply({'params': p}, x) - y)**2)
+
+
+def update_step(
+    step_idx, params, opt_state, preconditioner_v, gradient_accumulator,
+    noise_state
+):
+  """Performs one update step with delayed preconditioning."""
+
+  # Scale gradients using the (potentially stale) preconditioner
+  s_t = jax.tree_util.tree_map(
+      lambda var: 1 / (jnp.sqrt(var) + epsilon_s1), preconditioner_v
+  )
+
+  def pre_clipping_transform(grads):
+    return jax.tree_util.tree_map(lambda g, s: g * s, grads, s_t)
+
+  # Compute and clip gradients
+  priv_grad_fn = clipped_grad(
+      loss_fn,
+      l2_clip_norm=1.0,
+      batch_argnums=(1, 2),
+      pre_clipping_transform=pre_clipping_transform,
+  )
+  clipped_grads = priv_grad_fn(params, inputs, labels)
+
+  # Add noise
+  noised_grads, new_noise_state = privatizer.update(
+      clipped_grads, noise_state
+  )
+
+  # Rescale noised gradients back to original space
+  noised_grads_rescaled = jax.tree_util.tree_map(
+      lambda g, s: g / s, noised_grads, s_t
+  )
+
+  # Apply updates to parameters
+  updates, new_opt_state = optimizer.update(
+      noised_grads_rescaled, opt_state, params
+  )
+  new_params = optax.apply_updates(params, updates)
+
+  # --- DP^2 Preconditioner Update Logic ---
+  # Accumulate the noised (but not rescaled) gradients
+  new_gradient_accumulator = jax.tree_util.tree_map(
+      lambda acc, g: acc + g, gradient_accumulator, noised_grads
+  )
+
+  def update_preconditioner(operand):
+    """Update preconditioner with averaged gradients and reset accumulator."""
+    p_v, grad_acc = operand
+    # By delaying the update, we ensure that the second-moment estimate is
+    # derived from a higher signal-to-noise ratio average, preventing the
+    # noise from dominating the adaptive learning rates.
+    avg_grad = jax.tree_util.tree_map(lambda x: x / delay_s, grad_acc)
+    new_p_v = jax.tree_util.tree_map(
+        lambda v, g: 0.9 * v + 0.1 * jnp.square(g), p_v, avg_grad
+    )
+    new_grad_acc = jax.tree_util.tree_map(jnp.zeros_like, grad_acc)
+    return new_p_v, new_grad_acc
+
+  def keep_stale(operand):
+    """Keep preconditioner stale and continue accumulating."""
+    return operand
+
+  # "Snap" update: conditionally update the preconditioner
+  new_preconditioner_v, new_gradient_accumulator = jax.lax.cond(
+      (step_idx + 1) % delay_s == 0,
+      update_preconditioner,
+      keep_stale,
+      (preconditioner_v, new_gradient_accumulator),
+  )
+
+  return (
+      new_params,
+      new_opt_state,
+      new_preconditioner_v,
+      new_gradient_accumulator,
+      new_noise_state,
+  )
+
+
+# 4. Training Loop & Verification
+v_initial = preconditioner_v
+snapshots = {}
+num_steps = 20
+
+for i in range(num_steps):
+  (
+      params,
+      opt_state,
+      preconditioner_v,
+      gradient_accumulator,
+      noise_state,
+  ) = update_step(
+      i, params, opt_state, preconditioner_v, gradient_accumulator, noise_state
+  )
+  if i == 3:  # After the 4th step: the last step before the first update.
+    snapshots['v_step4'] = preconditioner_v
+  if i == 4:  # After the 5th step: the preconditioner update fires here.
+    snapshots['v_step5'] = preconditioner_v
+
+
+# 5. Correctness Verification
+def flatten_pytree(pt):
+  return np.concatenate([np.ravel(x) for x in jax.tree_util.tree_leaves(pt)])
+
+
+v_initial_flat = flatten_pytree(v_initial)
+v_step4_flat = flatten_pytree(snapshots['v_step4'])
+v_step5_flat = flatten_pytree(snapshots['v_step5'])
+
+# Assert that the preconditioner was stale until the update step
+np.testing.assert_allclose(v_initial_flat, v_step4_flat, rtol=1e-6)
+assert not np.allclose(
+    v_step4_flat, v_step5_flat
+), "Preconditioner did not update at step 5."
+
+print(
+    "[Verification] DP2 Protocol confirmed: Preconditioner remained stale for"
+    " steps 1-4 and updated successfully at step 5."
+)
diff --git a/examples/dp_sgd_transformer_nnx.py b/examples/dp_sgd_transformer_nnx.py
new file mode 100644
index 0000000..7268763
--- /dev/null
+++ b/examples/dp_sgd_transformer_nnx.py
@@ -0,0 +1,368 @@
+# Copyright 2025 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Example of training a Transformer with Differential Privacy using Flax NNX.
+
+This script demonstrates how to integrate Flax NNX (the new functional API for
+Flax) with JAX Privacy's gradient clipping and noise addition mechanisms. It
+implements a simple character-level Transformer language model trained on the
+Tiny Shakespeare dataset using DP-SGD.
+
+Key Features:
+  - Uses `flax.nnx` for model definition.
+  - Implements the "Exhaustive Split" pattern to separate parameters from
+    static graph definition.
+  - Handles rank normalization for `jax_privacy.clipping.clipped_grad`
+    compatibility.
+  - Demonstrates correct value access using `clipped_grad` with auxiliary
+    outputs.
+"""
+
+import functools
+from typing import Dict, Tuple, Any
+import urllib.request
+
+from flax import nnx
+import jax
+import jax.extend.backend
+import jax.numpy as jnp
+from jax_privacy import noise_addition
+from jax_privacy.clipping import clipped_grad
+import numpy as np
+import optax
+
+
+# Hyperparameters
+BATCH_SIZE = 4
+CONTEXT_LENGTH = 8
+EMBED_SIZE = 16
+NUM_HEADS = 4
+NUM_LAYERS = 2
+LEARNING_RATE = 1e-3
+NUM_STEPS = 10
+CLIP_NORM = 1.0
+NOISE_MULTIPLIER = 1.0
+
+
+# Data loading and preparation
+def download_data(url: str) -> str:
+  """Downloads data from a URL.
+
+  Args:
+    url: The URL to download data from.
+
+  Returns:
+    The content of the downloaded file as a string.
+  """
+  with urllib.request.urlopen(url, timeout=10) as response:
+    return response.read().decode('utf-8')
+
+
+def create_tokenizer(text: str) -> Dict[str, int]:
+  """Creates a simple character-level tokenizer.
+
+  Args:
+    text: The input text to build the vocabulary from.
+
+  Returns:
+    stoi: A dictionary mapping characters to integer indices.
+  """
+  chars = sorted(list(set(text)))
+  stoi = {ch: i for i, ch in enumerate(chars)}
+  return stoi
+
+
+def encode(text: str, stoi: Dict[str, int]) -> np.ndarray:
+  """Encodes text using the tokenizer.
+
+  Args:
+    text: The text to encode.
+    stoi: The string-to-index mapping.
+
+  Returns:
+    A numpy array of integer indices.
+  """
+  return np.array([stoi[ch] for ch in text], dtype=np.int32)
+
+
+def get_batch(
+    data: np.ndarray, batch_size: int, context_length: int
+) -> Tuple[np.ndarray, np.ndarray]:
+  """Gets a random batch of data.
+
+  Args:
+    data: The entire dataset encoded as integers.
+    batch_size: The number of examples in the batch.
+    context_length: The length of each sequence.
+
+  Returns:
+    A tuple (x, y) where:
+      - x: Input sequences of shape (batch_size, context_length).
+      - y: Target sequences of shape (batch_size, context_length).
+  """
+  ix = np.random.randint(len(data) - context_length, size=(batch_size,))
+  x = np.stack([data[i : i + context_length] for i in ix])
+  y = np.stack([data[i + 1 : i + context_length + 1] for i in ix])
+  return x, y
+
+
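+# Note (illustrative): `get_batch` draws uniform random crops purely for
+# simplicity. The formal DP-SGD analysis usually assumes Poisson subsampling
+# of examples; see the TODO in `main` about
+# `jax_privacy.batch_selection.CyclicPoissonSampling`.
+
+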
+class TransformerBlock(nnx.Module):
+  """A single Transformer block."""
+
+  def __init__(
+      self,
+      embed_size: int,
+      num_heads: int,
+      *,
+      rngs: nnx.Rngs,
+  ):
+    """Initializes the TransformerBlock.
+
+    Args:
+      embed_size: The dimensionality of the embedding.
+      num_heads: The number of attention heads.
+      rngs: The random number generators.
+    """
+    self.attention = nnx.MultiHeadAttention(
+        num_heads=num_heads,
+        in_features=embed_size,
+        qkv_features=embed_size,
+        out_features=embed_size,
+        rngs=rngs,
+    )
+    self.ln1 = nnx.LayerNorm(num_features=embed_size, rngs=rngs)
+    self.ln2 = nnx.LayerNorm(num_features=embed_size, rngs=rngs)
+    self.ffw = nnx.Sequential(
+        nnx.Linear(
+            in_features=embed_size,
+            out_features=4 * embed_size,
+            rngs=rngs
+        ),
+        nnx.Linear(
+            in_features=4 * embed_size,
+            out_features=embed_size,
+            rngs=rngs
+        ),
+    )
+
+  def __call__(self, x: jax.Array) -> jax.Array:
+    """Applies the TransformerBlock to the input.
+
+    Args:
+      x: Input array of shape (seq_len, embed_size).
+
+    Returns:
+      Output array of the same shape as input.
+    """
+    # nnx.MultiHeadAttention does not expose an `is_causal` flag in this
+    # version, so we build an explicit causal mask with `make_causal_mask`.
+    mask = nnx.make_causal_mask(x[..., 0])
+    x = x + self.attention(x, mask=mask, decode=False)
+    x = self.ln1(x)
+    x = x + self.ffw(x)
+    x = self.ln2(x)
+    return x
+
+
+class TransformerLM(nnx.Module):
+  """A Transformer language model."""
+
+  def __init__(
+      self,
+      vocab_size: int,
+      *,
+      embed_size: int,
+      context_length: int,
+      num_heads: int,
+      num_layers: int,
+      rngs: nnx.Rngs,
+  ):
+    """Initializes the TransformerLM.
+
+    Args:
+      vocab_size: The size of the vocabulary.
+      embed_size: The dimensionality of the embedding.
+      context_length: The max length of the context window.
+      num_heads: The number of attention heads.
+      num_layers: The number of transformer blocks.
+      rngs: The random number generators.
+    """
+    self.token_embedding = nnx.Embed(
+        num_embeddings=vocab_size, features=embed_size, rngs=rngs
+    )
+    self.pos_embedding = nnx.Embed(
+        num_embeddings=context_length, features=embed_size, rngs=rngs
+    )
+    self.blocks = nnx.List([
+        TransformerBlock(embed_size, num_heads, rngs=rngs)
+        for _ in range(num_layers)
+    ])
+    self.ln_f = nnx.LayerNorm(num_features=embed_size, rngs=rngs)
+    self.head = nnx.Linear(
+        in_features=embed_size, out_features=vocab_size, rngs=rngs
+    )
+
+  def __call__(self, x: jax.Array) -> jax.Array:
+    """Applies the model to the input.
+
+    Args:
+      x: Input array of token indices with shape (batch_size, seq_len).
+
+    Returns:
+      Logits array of shape (batch_size, seq_len, vocab_size).
+    """
+    pos = jnp.arange(0, x.shape[1])
+    tok_emb = self.token_embedding(x)
+    pos_emb = self.pos_embedding(pos)
+    x = tok_emb + pos_emb
+    for block in self.blocks:
+      x = block(x)
+    x = self.ln_f(x)
+    logits = self.head(x)
+    return logits
+
+
+def pure_loss_fn(
+    params: nnx.State,
+    x: jax.Array,
+    y: jax.Array,
+    graphdef: nnx.GraphDef,
+    other: nnx.State,
+) -> jax.Array:
+  """A pure functional loss function for DP-SGD.
+
+  This function re-merges the NNX model state, applies the model, and computes
+  the cross-entropy loss.
+  It is designed to work with `clipped_grad`, which requires a functional
+  interface.
+
+  Args:
+    params: The trainable parameters of the model.
+    x: Input batch (single example or microbatch).
+    y: Target batch (single example or microbatch).
+    graphdef: The static graph definition of the NNX model.
+    other: Non-trainable state (e.g., RNG counts).
+
+  Returns:
+    The scalar loss value.
+  """
+  model = nnx.merge(graphdef, params, other)
+
+  # Standard call without rank normalization
+  logits = model(x)
+
+  return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
+
+
+def main():
+  """Main training loop."""
+  device = jax.extend.backend.get_backend().platform
+  print(f"Starting DP-SGD Training on {device.upper()}...")
+
+  # Data
+  text = download_data(
+      "https://raw.githubusercontent.com/karpathy/char-rnn/master/"
+      "data/tinyshakespeare/input.txt"
+  )
+  stoi = create_tokenizer(text)
+  vocab_size = len(stoi)
+  data = encode(text, stoi)
+  print(f"Dataset has {len(data)} tokens, {vocab_size} vocab size.")
+
+  # Model and optimizer
+  rngs = nnx.Rngs(0)
+  model = TransformerLM(
+      vocab_size=vocab_size,
+      embed_size=EMBED_SIZE,
+      context_length=CONTEXT_LENGTH,
+      num_heads=NUM_HEADS,
+      num_layers=NUM_LAYERS,
+      rngs=rngs,
+  )
+  optimizer = optax.adam(LEARNING_RATE)
+
+  # CRITICAL: Exhaustive split to separate trainable params from the static
+  # graph definition and other state.
+  graphdef, params, other = nnx.split(model, nnx.Param, ...)
+  opt_state = optimizer.init(params)
+
+  # Configure DP Gradient Clipping
+  grad_fn = clipped_grad(
+      functools.partial(pure_loss_fn, graphdef=graphdef, other=other),
+      l2_clip_norm=CLIP_NORM,
+      batch_argnums=(1, 2),  # x and y are batched
+      keep_batch_dim=True,  # Return per-example gradients
+      return_values=True,  # Return loss values for logging
+      # Note: We do not pass prng_argnum here because 'other' (arg 4) contains
+      # RNG state, which is handled as a standard argument by NNX.
+  )
+
+  privatizer = noise_addition.gaussian_privatizer(
+      stddev=grad_fn.sensitivity() * NOISE_MULTIPLIER,
+  )
+  noise_state = privatizer.init(params)
+
+  @jax.jit
+  def train_step(
+      params: nnx.State,
+      opt_state: optax.OptState,
+      batch: Tuple[jax.Array, jax.Array],
+      *,
+      noise_state: Any,
+  ) -> Tuple[nnx.State, optax.OptState, Any, jax.Array]:
+    """Performs a single training step with DP-SGD.
+
+    Args:
+      params: Current model parameters.
+      opt_state: Current optimizer state.
+      batch: A tuple (x, y) of input and target data.
+      noise_state: Current state of the noise mechanism.
+
+    Returns:
+      Updated params, opt_state, noise_state, and the mean loss for the batch.
+    """
+    x, y = batch
+
+    # Compute clipped gradients and per-example loss values
+    grads, loss = grad_fn(params, x, y)
+
+    # Add Privacy Noise
+    noisy_grads, noise_state = privatizer.update(grads, noise_state)
+
+    # Apply updates using Optax
+    updates, opt_state = optimizer.update(noisy_grads, opt_state)
+    params = optax.apply_updates(params, updates)
+
+    # loss is an Aux object containing 'values'
+    return params, opt_state, noise_state, loss.values.mean()
+
+  # Training loop
+  # TODO: Use jax_privacy.batch_selection.CyclicPoissonSampling
+  print(f"Training for {NUM_STEPS} steps...")
+  for step in range(NUM_STEPS):
+    batch = get_batch(data, BATCH_SIZE, CONTEXT_LENGTH)
+
+    params, opt_state, noise_state, loss = train_step(
+        params, opt_state, batch, noise_state=noise_state
+    )
+
+    print(f"Step {step + 1}/{NUM_STEPS}, Loss: {loss:.4f}")
+
+  print("Training Complete.")
+
+
+if __name__ == "__main__":
+  main()
diff --git a/examples/requirements.txt b/examples/requirements.txt
new file mode 100644
index 0000000..884bded
--- /dev/null
+++ b/examples/requirements.txt
@@ -0,0 +1 @@
+flax
diff --git a/examples/user_level_transformer_example.py b/examples/user_level_transformer_example.py
new file mode 100644
index 0000000..117e98a
--- /dev/null
+++ b/examples/user_level_transformer_example.py
@@ -0,0 +1,182 @@
+# coding=utf-8
+# Copyright 2025 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Example of user-level DP-SGD for fine-tuning a transformer model.
+
+This example implements user-level differentially private stochastic gradient
+descent (DP-SGD) for a transformer model, using the `jax_privacy` library's
+native components.
+
+The implementation demonstrates how to use `UserSelectionStrategy` to handle
+unbalanced user datasets efficiently, where each user contributes a different
+number of examples.
+
+For more details on the user-level DP-SGD algorithm for large language models,
+refer to the paper:
+Charles et al. (2024), "Fine-Tuning Large Language Models with User-Level
+Differential Privacy" (https://arxiv.org/abs/2404.06713).
+"""
+
+from absl import app
+from absl import flags
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from jax_privacy.batch_selection import CyclicPoissonSampling
+from jax_privacy.batch_selection import UserSelectionStrategy
+from jax_privacy.clipping import clipped_grad
+
+
+_USERS_PER_BATCH = flags.DEFINE_integer(
+    'users_per_batch', 4, 'Number of users to select in each batch.'
+)
+_EXAMPLES_PER_USER = flags.DEFINE_integer(
+    'examples_per_user', 2, 'Number of examples to select for each user.'
+)
+_STEPS = flags.DEFINE_integer('steps', 10, 'Number of training steps.')
+_L2_CLIP_NORM = flags.DEFINE_float(
+    'l2_clip_norm', 1.0, 'L2 clipping norm for gradients.'
+)
+_LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.')
+
+
+class TransformerDecoder(nn.Module):
+  """A minimal Transformer Decoder."""
+
+  vocab_size: int
+  embed_dim: int
+  num_heads: int
+  ff_dim: int
+
+  @nn.compact
+  def __call__(self, x, train: bool):
+    x = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim)(x)
+    x = nn.SelfAttention(
+        num_heads=self.num_heads, qkv_features=self.embed_dim
+    )(x)
+    x = nn.Dense(self.ff_dim)(x)
+    x = nn.relu(x)
+    x = nn.Dense(self.vocab_size)(x)
+    return x
+
+
+def get_synthetic_data(
+    num_users: int,
+    num_examples_per_user: list[int],
+    seq_len: int,
+    vocab_size: int,
+):
+  """Generates synthetic data for the transformer model."""
+  data = []
+  labels = []
+  user_ids = []
+  for i in range(num_users):
+    user_data = np.random.randint(
+        0, vocab_size, size=(num_examples_per_user[i], seq_len)
+    )
+    user_labels = np.random.randint(
+        0, vocab_size, size=(num_examples_per_user[i], seq_len)
+    )
+    data.append(user_data)
+    labels.append(user_labels)
+    user_ids.extend([i] * num_examples_per_user[i])
+  return np.concatenate(data), np.concatenate(labels), np.array(user_ids)
+
+
+def main(_):
+  # 1. Model & Data
+  vocab_size = 1000
+  embed_dim = 64
+  num_heads = 4
+  ff_dim = 256
+  seq_len = 32
+  num_users = 20
+  num_examples_per_user = np.random.randint(1, 10, size=num_users)
+
+  data, labels, user_ids = get_synthetic_data(
+      num_users=num_users,
+      num_examples_per_user=num_examples_per_user,
+      seq_len=seq_len,
+      vocab_size=vocab_size,
+  )
+
+  model = TransformerDecoder(
+      vocab_size=vocab_size,
+      embed_dim=embed_dim,
+      num_heads=num_heads,
+      ff_dim=ff_dim,
+  )
+  params = model.init(
+      jax.random.key(0), jnp.zeros((1, seq_len), dtype=jnp.int32), train=False
+  )['params']
+  optimizer = optax.adam(_LEARNING_RATE.value)
+  opt_state = optimizer.init(params)
+
+  # 2. Batch Selection
+  sampling_prob = _USERS_PER_BATCH.value / num_users
+  base_strategy = CyclicPoissonSampling(
+      sampling_prob=sampling_prob, iterations=_STEPS.value
+  )
+  user_strategy = UserSelectionStrategy(
+      base_strategy=base_strategy,
+      examples_per_user_per_batch=_EXAMPLES_PER_USER.value,
+  )
+
+  # 3. Training Step & Clipping
+  def loss_fn(params, batch_data, batch_labels):
+    logits = model.apply({'params': params}, batch_data, train=True)
+    one_hot_labels = jax.nn.one_hot(batch_labels, num_classes=vocab_size)
+    return jnp.mean(
+        optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
+    )
+
+  # `clipped_grad` wraps the loss function to compute per-user clipped
+  # gradients.
+  # With `keep_batch_dim=False`, the loss function receives a batch of examples
+  # for a single user. The gradient is computed over this batch, and the
+  # resulting gradient (which is an average over the user's examples) is
+  # clipped. This aligns with the core requirement of user-level DP.
+  grad_fn = clipped_grad(
+      loss_fn,
+      l2_clip_norm=_L2_CLIP_NORM.value,
+      batch_argnums=(1, 2),
+      keep_batch_dim=False,
+  )
+
+  @jax.jit
+  def train_step(params, opt_state, batch_data, batch_labels):
+    grads = grad_fn(params, batch_data, batch_labels)
+    updates, opt_state = optimizer.update(grads, opt_state, params)
+    params = optax.apply_updates(params, updates)
+    return params, opt_state
+
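+  # Note (illustrative): this example only demonstrates user-level batch
+  # selection and per-user gradient clipping. A complete DP guarantee
+  # additionally requires adding calibrated Gaussian noise to the clipped
+  # gradients before the optimizer update, as done via `noise_addition` in the
+  # other examples in this directory.
+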
+  # 4. Training Loop
+  batch_iterator = user_strategy.batch_iterator(user_ids, rng=0)
+  for step, user_batch_indices in enumerate(batch_iterator):
+    if user_batch_indices.size == 0:
+      print(f"Step {step}: Skipping empty batch.")
+      continue
+
+    batch_data = data[user_batch_indices]
+    batch_labels = labels[user_batch_indices]
+    params, opt_state = train_step(params, opt_state, batch_data, batch_labels)
+    print(f"Step {step}: Completed.")
+
+  print("Training finished successfully.")
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/pyproject.toml b/pyproject.toml
index 6fb9224..31bc8b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,9 @@ dev = [
     "keras",
     "tensorflow",
 ]
+examples = [
+    "flax",
+]
 
 [project.urls]
 Homepage = "https://github.com/google-deepmind/jax_privacy"