Fix Issue #85: Support for cryptographically secure randomness #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Conversation
ryan112358 left a comment
Thank you for the PR, very interesting approach! I did initially like the API change you made to noise_addition, but realized there are some tricky subtleties that could be problematic. For now, I think you should be able to submit the example after resolving a few comments. Overall looks great!
jax_privacy/noise_addition.py
Outdated
raise ValueError(f'Expected 2D matrix, found {noising_matrix.shape=}.')


-def privatize(sum_of_clipped_grads, noise_state):
+def privatize(sum_of_clipped_grads, noise_state, *, noise=None):
Also, we need to be careful about potential misconfigurations by the user: I think this argument should be called iid_normal.
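A hedged sketch of the rename being suggested (hypothetical signature only; the noise argument was ultimately dropped from the library API, see below):

```python
# Hypothetical sketch: making the i.i.d. assumption explicit in the name.
def privatize(sum_of_clipped_grads, noise_state, *, iid_normal=None):
  """Privatizes a gradient sum.

  Args:
    iid_normal: optional PyTree of externally generated i.i.d. Gaussian noise;
      if None, noise is drawn from the internal JAX PRNG.
  """
  ...
```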
jax_privacy/noise_addition.py
Outdated
    sampler=functools.partial(_gaussian_linear_combination, matrix_row),
    dtype=dtype
)
if noise is None:
The semantics of noise are very different here than they are for _streaming_matrix_privatizer. Here, noise is correlated according to the given noising matrix, while in the other it's i.i.d. Also, this approach to noise generation crucially relies on the ability to regenerate old noise, which may be incompatible with hardware RNGs.
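To illustrate the distinction (this is not the library's implementation): with a noising matrix C, the noise released at round t is a linear combination of the per-round i.i.d. draws z_0..z_t, so older draws must be reproducible from a seed, which a hardware RNG cannot do.

```python
# Illustrative sketch, not jax_privacy code: correlated noise at round t mixes
# per-round i.i.d. Gaussians z_j, weighted by row t of a noising matrix C.
import jax

def correlated_noise(key, noising_matrix, t, shape, stddev=1.0):
  def z(j):
    # Re-derives the j-th draw from the base key, so "old" noise can be
    # regenerated on demand -- something a hardware TRNG cannot replay.
    return stddev * jax.random.normal(jax.random.fold_in(key, j), shape)
  return sum(noising_matrix[t, j] * z(j) for j in range(t + 1))
```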
examples/secure_noise_example.py
Outdated
step_time = end_time - start_time
total_time += step_time
print(f" Step time: {step_time:.4f}s")
Can you update this script with a command-line flag controlling whether to use secure generation? And can you verify that the step time and final utility are comparable whether or not it is used?
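A sketch of what such a flag could look like, assuming the example is an absl binary (it already defines main(_)); the flag name use_secure_rng is illustrative:

```python
# Sketch only: gate the noise source behind a command-line flag.
from absl import app, flags

_USE_SECURE_RNG = flags.DEFINE_bool(
    'use_secure_rng', False,
    'If true, generate noise on the CPU with a cryptographically secure RNG; '
    'otherwise use the standard JAX PRNG-based privatizer.')

def main(_):
  if _USE_SECURE_RNG.value:
    ...  # CPU-side secure noise path
  else:
    ...  # default JAX PRNG path

if __name__ == '__main__':
  app.run(main)
```

Reporting the measured step time and final loss for both settings of the flag would make the requested comparison easy to reproduce.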
examples/secure_noise_example.py
Outdated
)

# Block to ensure the step is finished before measuring time
jax.block_until_ready(params)
This block will likely slow down training; can you remove it and just report the final time at the end of training? If it is a long-running program, you can use tqdm on the loop to monitor progress.
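A sketch of the suggested timing pattern (train_step, num_steps, params, opt_state, and batch are illustrative names from a generic training loop): synchronize once at the end instead of per step, and use tqdm for progress.

```python
# Sketch: keep dispatch asynchronous inside the loop; block only once at the end.
import time
import jax
from tqdm import tqdm

start_time = time.perf_counter()
for step in tqdm(range(num_steps)):
  params, opt_state = train_step(params, opt_state, batch)
jax.block_until_ready(params)  # single synchronization point
total_time = time.perf_counter() - start_time
print(f"Total: {total_time:.2f}s ({total_time / num_steps:.4f}s per step)")
```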
examples/secure_noise_example.py
Outdated
def generate_secure_noise(params_tree, stddev, generator):
  """Generates a PyTree of Gaussian noise on the CPU."""
  numpy_generator = np.random.default_rng(generator.bit_generator)
Can you add a comment that it is crucial that this be called outside of a jitted context, otherwise it may produce identical noise every time it is called (since generator would have to be a static arg)
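A hedged sketch of the requested warning added to the function shown above; the body below the numpy_generator line is an assumption, not the PR's actual code.

```python
import jax
import numpy as np

def generate_secure_noise(params_tree, stddev, generator):
  """Generates a PyTree of Gaussian noise on the CPU.

  WARNING: must be called outside of any jitted context. Under jit, `generator`
  would have to be a static argument, so the traced computation would replay
  identical noise on every call instead of drawing fresh samples.
  """
  numpy_generator = np.random.default_rng(generator.bit_generator)
  return jax.tree.map(
      lambda x: numpy_generator.normal(scale=stddev, size=x.shape).astype(x.dtype),
      params_tree,
  )
```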
@@ -0,0 +1,150 @@
# coding=utf-8
Does this file pass pytype and pylint checks?
Hi @ryan112358, thank you for the insightful feedback, especially regarding the semantic differences between IID and correlated noise. I've refactored the PR to move all the custom logic into the example script to maintain the library's architectural purity.

Summary of the changes I made:

API Purity: Reverted all changes to jax_privacy/noise_addition.py; the library API remains untouched. I believe this standalone approach resolves the subtleties you mentioned regarding hardware RNG compatibility and noise correlation.

Looking forward to your review!
def main(_):
  params = toy_model_params()
  # This privatizer will be used to generate noise if use_secure_rng is False
  privatizer = noise_addition.gaussian_privatizer(
Let's bypass the privatizer API altogether in this example and just directly add the noise using jax.tree.map(jnp.add, grads, noise).
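For reference, the suggested pattern is a single tree map over matching PyTrees; grads and noise here stand in for the gradient and noise PyTrees in the example:

```python
import jax
import jax.numpy as jnp

# Sketch: add externally generated noise leaf-by-leaf, no privatizer involved.
noisy_grads = jax.tree.map(jnp.add, grads, noise)
```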
I have made the changes and also added a comment with a detailed technical summary.
examples/secure_noise_example.py
Outdated
# Dummy batch
batch = None
# We need to define grads_treedef once outside the loop
grads_treedef = jax.eval_shape(lambda p: jax.grad(loss_fn)(p, batch), params)
You should be able to use jax.tree.structure(params) here
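In other words (sketch; params is the model parameter PyTree), the tree structure can be read directly from the parameters instead of tracing the gradient function:

```python
import jax

# Sketch: gradients share the parameters' tree structure, so no tracing needed.
grads_treedef = jax.tree.structure(params)
```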
examples/secure_noise_example.py
Outdated
def generate_secure_noise(stddev, grads_treedef):
  """Generates i.i.d. Gaussian noise on the CPU using NumPy."""
  return jax.tree.map(
      lambda x: np.random.normal(scale=stddev, size=x.shape).astype(x.dtype),
Let's use randomgen here with an appropriate RNG.
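A sketch of what that could look like, assuming the randomgen package is installed; AESCounter is one of its cryptographically based bit generators, and the signature here is adapted for illustration:

```python
import jax
import numpy as np
from randomgen import AESCounter

# Seeded from OS entropy by default; backed by AES in counter mode.
secure_rng = np.random.Generator(AESCounter())

def generate_secure_noise(stddev, params_tree):
  """Generates i.i.d. Gaussian noise on the CPU from a CSPRNG."""
  return jax.tree.map(
      lambda x: secure_rng.normal(scale=stddev, size=x.shape).astype(x.dtype),
      params_tree,
  )
```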
examples/secure_noise_example.py
Outdated
if secure_noise is not None:
  # Manually add the CPU-generated secure noise
  iid_normal = secure_noise
  noisy_grads = jax.tree.map(
Can use jnp.add instead of lambda here
…n_step using jnp.add
Hello @ryan112358, thanks again for the approval and for those final suggestions. They were really helpful for cleaning up the implementation and making it feel more like a native JAX workflow. I've just pushed an update that addresses all your points:

✅ Architectural Decoupling: I've completely decoupled the training step from the privatizer API in this example. Instead of calling privatizer.update and then conditionally overwriting the output, I refactored train_step to handle the noise addition directly via jax.tree_util.tree_map(jnp.add, grads, secure_noise). This is much cleaner and, as you noted, avoids any semantic confusion between IID and correlated noise mechanisms.

✅ Secure Entropy with randomgen: I've integrated randomgen.AESCounter as the bit generator for the CPU-side noise generation. To get fresh entropy at every step without triggering the JIT static-arg issue, I now initialize a PyTree of persistent Generator objects outside the loop and map over them. This keeps the generator state alive across iterations while sourcing high-quality randomness from the AES counter.

✅ Structural Simplification: I took your suggestion to simplify the tree-structure logic. I now map directly over the params PyTree to initialize the keys and generators, which removed the need for jax.eval_shape and makes the setup phase much more readable.

✅ Asynchronous Performance & Benchmarking: I removed the jax.block_until_ready call from within the training loop. This allows JAX to overlap the CPU-side noise generation for the next step with the accelerator-side gradient computation for the current step.

✅ Performance Comparison: I ran a comparison to confirm there is no significant overhead. The secure path performs well (it is even slightly faster in this toy setup), confirming that sourcing randomness from the CPU asynchronously is a viable approach.

✅ Documentation & Linting: The code has been linted, and generate_secure_noise is clearly documented as a non-JIT function to prevent future misconfigurations.

Looking forward to the merge!
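To make the description above concrete, here is a minimal sketch of such a hybrid CPU/device loop. Every name in it (make_generators, train_step, the toy loss) is illustrative rather than copied from the PR, and it assumes the randomgen package is available:

```python
import jax
import jax.numpy as jnp
import numpy as np
from randomgen import AESCounter

def loss_fn(params, batch):
  # Toy quadratic loss, just to make the sketch self-contained.
  return sum(jnp.sum(p ** 2) for p in jax.tree.leaves(params))

def make_generators(params):
  # One persistent AES-CTR-backed Generator per leaf, created once outside
  # the training loop so generator state advances across iterations.
  return jax.tree.map(lambda _: np.random.Generator(AESCounter()), params)

def generate_secure_noise(params, generators, stddev):
  # Runs on the CPU, outside of jit, so each call draws fresh noise.
  return jax.tree.map(
      lambda p, g: g.normal(scale=stddev, size=p.shape).astype(p.dtype),
      params, generators)

@jax.jit
def train_step(params, batch, noise, learning_rate=0.1):
  grads = jax.grad(loss_fn)(params, batch)
  noisy_grads = jax.tree.map(jnp.add, grads, noise)  # direct noise addition
  return jax.tree.map(lambda p, g: p - learning_rate * g, params, noisy_grads)

params = {'w': jnp.zeros((4, 4)), 'b': jnp.zeros((4,))}
generators = make_generators(params)
for _ in range(10):
  noise = generate_secure_noise(params, generators, stddev=1.0)
  params = train_step(params, batch=None, noise=noise)
```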
This PR introduces support for cryptographically secure randomness in both batch selection and noise addition, as requested in Issue #85.
Key Changes:
✅ Core API Extension: Modified privatizers in noise_addition.py to accept an optional noise PyTree. This allows for external noise injection from secure CPU sources, bypassing standard JAX PRNGs when required.
✅ Secure Example Binary: Created examples/secure_noise_example.py, which demonstrates:
  ✅ Using randomgen (AES-based) for secure batch-selection indices (see the sketch at the end of this description).
  ✅ Asynchronous CPU-based noise generation using JAX's asynchronous dispatch.
  ✅ Injection of CPU-generated Gaussian noise into the device-bound training step.
Validation:
✅ Verified compatibility with randomgen.generator.ExtendedGenerator.
✅ Confirmed that the "Secure Noise" example executes successfully with a hybrid CPU-device loop.
Fixes #85
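As referenced in the batch-selection bullet above, a hedged sketch of what CSPRNG-backed index sampling could look like (the names are assumptions, not the example's actual code):

```python
import numpy as np
from randomgen import AESCounter

secure_rng = np.random.Generator(AESCounter())

def sample_batch_indices(num_examples, batch_size):
  # Uniform sampling without replacement from an AES-CTR-backed generator.
  return secure_rng.choice(num_examples, size=batch_size, replace=False)
```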