Skip to content

Commit 6ab5ef4

Browse files
Merge pull request #94 from staticpayload:fix/keras-noise-multiplier-getter
PiperOrigin-RevId: 859306453
2 parents 1e3834c + 8b3edaf commit 6ab5ef4

File tree

2 files changed

+103
-6
lines changed

2 files changed

+103
-6
lines changed

jax_privacy/keras_api.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,36 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""API for adding DP-SGD to a Keras model."""
16+
# pylint: disable=todo-style
17+
"""API for adding DP-SGD to a Keras model.
18+
19+
Example Usage:
20+
21+
.. code-block:: python
22+
23+
import os
24+
os.environ["KERAS_BACKEND"] = "jax"
25+
import keras
26+
from jax_privacy import keras_api
27+
28+
model = keras.Sequential([
29+
keras.Input(shape=(1,)),
30+
keras.layers.Dense(1),
31+
])
32+
params = keras_api.DPKerasConfig(
33+
epsilon=1.0,
34+
delta=1e-5,
35+
clipping_norm=1.0,
36+
batch_size=8,
37+
gradient_accumulation_steps=1,
38+
train_steps=10,
39+
train_size=80,
40+
noise_multiplier=1.0,
41+
)
42+
private_model = keras_api.make_private(model, params)
43+
private_model.get_noise_multiplier()
44+
45+
"""
1746

1847
import dataclasses
1948
import functools
@@ -241,6 +270,7 @@ def make_private(model: keras.Model, params: DPKerasConfig) -> keras.Model:
241270
# that updates the metrics variables for DP-SGD training.
242271

243272
_add_dp_sgd_attributes(model, params)
273+
model.get_noise_multiplier = types.MethodType(get_noise_multiplier, model)
244274
model.fit = types.MethodType(
245275
_create_fit_fn_with_validation(model.fit, params), model
246276
)
@@ -256,6 +286,28 @@ def make_private(model: keras.Model, params: DPKerasConfig) -> keras.Model:
256286
return model
257287

258288

289+
def get_noise_multiplier(model: keras.Model) -> float:
290+
"""Returns the noise multiplier used for DP-SGD training.
291+
292+
If the noise multiplier is not set in DPKerasConfig, this will calibrate it
293+
once and cache the value on the model.
294+
295+
Args:
296+
model: A Keras model previously wrapped with make_private().
297+
298+
Returns:
299+
The configured or calibrated noise multiplier.
300+
"""
301+
if not hasattr(model, '_dp_params'):
302+
raise ValueError(
303+
'Model does not appear to be a DP-SGD Keras model. '
304+
'Call make_private() first.'
305+
)
306+
return _resolve_noise_multiplier(
307+
model._dp_params, model # pylint: disable=protected-access
308+
)
309+
310+
259311
def _validate_model(model: keras.Model) -> None:
260312
if not isinstance(model, keras.Model):
261313
raise ValueError(f'Model {model} is not a Keras model.')
@@ -285,6 +337,7 @@ def _validate_optimizer(model: keras.Model, params: DPKerasConfig) -> None:
285337
def _add_dp_sgd_attributes(model: keras.Model, params: DPKerasConfig) -> None:
286338
"""Adds DP-SGD training attributes to the Keras model."""
287339
model._dp_params = params # pylint: disable=protected-access
340+
model._dp_noise_multiplier = params.noise_multiplier # pylint: disable=protected-access
288341
seed = _get_random_int64() if params.seed is None else params.seed
289342
model.add_weight(
290343
name='_rng',
@@ -488,6 +541,7 @@ def _dp_train_step(
488541
self._dp_params, # pylint: disable=protected-access
489542
state,
490543
data,
544+
model=self,
491545
)
492546
(
493547
unscaled_loss,
@@ -533,11 +587,38 @@ def _dp_train_step(
533587
LossFn = typing.Callable[..., tuple[chex.Numeric, _AuxType]]
534588

535589

590+
def _resolve_noise_multiplier(
591+
dp_params: DPKerasConfig, model: keras.Model | None = None
592+
) -> float:
593+
"""Returns a cached noise multiplier or calibrates it once.
594+
595+
Args:
596+
dp_params: DP configuration to read or calibrate the noise multiplier
597+
from.
598+
model: Optional Keras model used to cache/reuse the calibrated value.
599+
600+
Returns:
601+
The configured or calibrated noise multiplier.
602+
"""
603+
if dp_params.noise_multiplier is not None:
604+
return dp_params.noise_multiplier
605+
if model is not None:
606+
cached = getattr(model, '_dp_noise_multiplier', None)
607+
if cached is not None:
608+
return cached
609+
calibrated = dp_params.update_with_calibrated_noise_multiplier()
610+
noise_multiplier = calibrated.noise_multiplier
611+
if model is not None:
612+
model._dp_noise_multiplier = noise_multiplier # pylint: disable=protected-access
613+
return noise_multiplier
614+
615+
536616
def _noised_clipped_grads(
537617
compute_loss_and_updates_fn: LossFn,
538618
dp_params: DPKerasConfig,
539619
state: _StateType,
540620
data: _KerasInputsDataType,
621+
model: keras.Model | None = None,
541622
) -> tuple[tuple[chex.Numeric, _AuxType], chex.ArrayTree]:
542623
"""Computes noised and clipped gradients.
543624
@@ -548,6 +629,7 @@ def _noised_clipped_grads(
548629
state: The state of the model.
549630
data: The data for the model: triple of x, y (can be None), sample_weight
550631
(can be None).
632+
model: Optional Keras model used to cache the calibrated noise multiplier.
551633
552634
Returns:
553635
(loss, aux), grads
@@ -584,11 +666,7 @@ def _noised_clipped_grads(
584666
optimizer_variables,
585667
)
586668

587-
noise_multiplier = (
588-
dp_params.noise_multiplier
589-
if dp_params.noise_multiplier is not None
590-
else dp_params.update_with_calibrated_noise_multiplier().noise_multiplier
591-
)
669+
noise_multiplier = _resolve_noise_multiplier(dp_params, model)
592670
l2_sensitivity = clipped_grad_fn.l2_norm_bound
593671
accumulation_factor = np.sqrt(dp_params.gradient_accumulation_steps)
594672
stddev = noise_multiplier * l2_sensitivity / accumulation_factor

tests/keras_api_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,25 @@ def test_add_dp_sgd_attributes(self):
143143

144144
self.assertTrue(hasattr(model, "_dp_params"))
145145
self.assertEqual(model._dp_params, params)
146+
self.assertEqual(model._dp_noise_multiplier, params.noise_multiplier)
147+
148+
def test_get_noise_multiplier_uses_config_value(self):
149+
model = keras.Sequential([keras.layers.Dense(10, input_shape=(784,))])
150+
params = dataclasses.replace(self._get_params(), noise_multiplier=2.5)
151+
private_model = keras_api.make_private(model, params)
152+
153+
self.assertEqual(private_model.get_noise_multiplier(), 2.5)
154+
155+
def test_get_noise_multiplier_calibrates_once(self):
156+
model = keras.Sequential([keras.layers.Dense(10, input_shape=(784,))])
157+
params = self._get_params()
158+
private_model = keras_api.make_private(model, params)
159+
160+
noise_multiplier = private_model.get_noise_multiplier()
161+
self.assertIsNotNone(noise_multiplier)
162+
self.assertGreater(noise_multiplier, 0.0)
163+
self.assertEqual(private_model._dp_noise_multiplier, noise_multiplier)
164+
self.assertEqual(private_model.get_noise_multiplier(), noise_multiplier)
146165

147166
@parameterized.named_parameters(
148167
("no_rescale_no_clip", 100.0, 1, False, [-10.0, -20.0]),

0 commit comments

Comments
 (0)