1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- """API for adding DP-SGD to a Keras model."""
16+ # pylint: disable=todo-style
17+ """API for adding DP-SGD to a Keras model.
18+
19+ Example Usage:
20+
21+ .. code-block:: python
22+
23+ import os
24+ os.environ["KERAS_BACKEND"] = "jax"
25+ import keras
26+ from jax_privacy import keras_api
27+
28+ model = keras.Sequential([
29+ keras.Input(shape=(1,)),
30+ keras.layers.Dense(1),
31+ ])
32+ params = keras_api.DPKerasConfig(
33+ epsilon=1.0,
34+ delta=1e-5,
35+ clipping_norm=1.0,
36+ batch_size=8,
37+ gradient_accumulation_steps=1,
38+ train_steps=10,
39+ train_size=80,
40+ noise_multiplier=1.0,
41+ )
42+ private_model = keras_api.make_private(model, params)
43+ private_model.get_noise_multiplier()
44+
45+ """
1746
1847import dataclasses
1948import functools
@@ -241,6 +270,7 @@ def make_private(model: keras.Model, params: DPKerasConfig) -> keras.Model:
241270 # that updates the metrics variables for DP-SGD training.
242271
243272 _add_dp_sgd_attributes (model , params )
273+ model .get_noise_multiplier = types .MethodType (get_noise_multiplier , model )
244274 model .fit = types .MethodType (
245275 _create_fit_fn_with_validation (model .fit , params ), model
246276 )
@@ -256,6 +286,28 @@ def make_private(model: keras.Model, params: DPKerasConfig) -> keras.Model:
256286 return model
257287
258288
289+ def get_noise_multiplier (model : keras .Model ) -> float :
290+ """Returns the noise multiplier used for DP-SGD training.
291+
292+ If the noise multiplier is not set in DPKerasConfig, this will calibrate it
293+ once and cache the value on the model.
294+
295+ Args:
296+ model: A Keras model previously wrapped with make_private().
297+
298+ Returns:
299+ The configured or calibrated noise multiplier.
300+ """
301+ if not hasattr (model , '_dp_params' ):
302+ raise ValueError (
303+ 'Model does not appear to be a DP-SGD Keras model. '
304+ 'Call make_private() first.'
305+ )
306+ return _resolve_noise_multiplier (
307+ model ._dp_params , model # pylint: disable=protected-access
308+ )
309+
310+
259311def _validate_model (model : keras .Model ) -> None :
260312 if not isinstance (model , keras .Model ):
261313 raise ValueError (f'Model { model } is not a Keras model.' )
@@ -285,6 +337,7 @@ def _validate_optimizer(model: keras.Model, params: DPKerasConfig) -> None:
285337def _add_dp_sgd_attributes (model : keras .Model , params : DPKerasConfig ) -> None :
286338 """Adds DP-SGD training attributes to the Keras model."""
287339 model ._dp_params = params # pylint: disable=protected-access
340+ model ._dp_noise_multiplier = params .noise_multiplier # pylint: disable=protected-access
288341 seed = _get_random_int64 () if params .seed is None else params .seed
289342 model .add_weight (
290343 name = '_rng' ,
@@ -488,6 +541,7 @@ def _dp_train_step(
488541 self ._dp_params , # pylint: disable=protected-access
489542 state ,
490543 data ,
544+ model = self ,
491545 )
492546 (
493547 unscaled_loss ,
@@ -533,11 +587,38 @@ def _dp_train_step(
533587LossFn = typing .Callable [..., tuple [chex .Numeric , _AuxType ]]
534588
535589
590+ def _resolve_noise_multiplier (
591+ dp_params : DPKerasConfig , model : keras .Model | None = None
592+ ) -> float :
593+ """Returns a cached noise multiplier or calibrates it once.
594+
595+ Args:
596+ dp_params: DP configuration to read or calibrate the noise multiplier
597+ from.
598+ model: Optional Keras model used to cache/reuse the calibrated value.
599+
600+ Returns:
601+ The configured or calibrated noise multiplier.
602+ """
603+ if dp_params .noise_multiplier is not None :
604+ return dp_params .noise_multiplier
605+ if model is not None :
606+ cached = getattr (model , '_dp_noise_multiplier' , None )
607+ if cached is not None :
608+ return cached
609+ calibrated = dp_params .update_with_calibrated_noise_multiplier ()
610+ noise_multiplier = calibrated .noise_multiplier
611+ if model is not None :
612+ model ._dp_noise_multiplier = noise_multiplier # pylint: disable=protected-access
613+ return noise_multiplier
614+
615+
536616def _noised_clipped_grads (
537617 compute_loss_and_updates_fn : LossFn ,
538618 dp_params : DPKerasConfig ,
539619 state : _StateType ,
540620 data : _KerasInputsDataType ,
621+ model : keras .Model | None = None ,
541622) -> tuple [tuple [chex .Numeric , _AuxType ], chex .ArrayTree ]:
542623 """Computes noised and clipped gradients.
543624
@@ -548,6 +629,7 @@ def _noised_clipped_grads(
548629 state: The state of the model.
549630 data: The data for the model: triple of x, y (can be None), sample_weight
550631 (can be None).
632+ model: Optional Keras model used to cache the calibrated noise multiplier.
551633
552634 Returns:
553635 (loss, aux), grads
@@ -584,11 +666,7 @@ def _noised_clipped_grads(
584666 optimizer_variables ,
585667 )
586668
587- noise_multiplier = (
588- dp_params .noise_multiplier
589- if dp_params .noise_multiplier is not None
590- else dp_params .update_with_calibrated_noise_multiplier ().noise_multiplier
591- )
669+ noise_multiplier = _resolve_noise_multiplier (dp_params , model )
592670 l2_sensitivity = clipped_grad_fn .l2_norm_bound
593671 accumulation_factor = np .sqrt (dp_params .gradient_accumulation_steps )
594672 stddev = noise_multiplier * l2_sensitivity / accumulation_factor
0 commit comments