
Commit 6512e44

Add registry function and tests for tf.keras.layers.EinsumDense.
PiperOrigin-RevId: 511864469
1 parent d7cd3f8 commit 6512e44

6 files changed: +652 −89 lines changed

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/BUILD (+8)
@@ -6,12 +6,20 @@ py_library(
     name = "gradient_clipping_utils",
     srcs = ["gradient_clipping_utils.py"],
     srcs_version = "PY3",
+    deps = [":layer_registry"],
+)
+
+py_library(
+    name = "einsum_utils",
+    srcs = ["einsum_utils.py"],
+    srcs_version = "PY3",
 )
 
 py_library(
     name = "layer_registry",
     srcs = ["layer_registry.py"],
     srcs_version = "PY3",
+    deps = [":einsum_utils"],
 )
 
 py_library(
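Note: the new deps edge from the layer_registry target to :einsum_utils suggests that layer_registry.py now imports the einsum helper roughly as below; the import itself is an assumption, since layer_registry.py is not among the hunks shown here.

from tensorflow_privacy.privacy.fast_gradient_clipping import einsum_utils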

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py (+24 −6)
@@ -21,11 +21,20 @@
 `compute_gradient_norms()` function).
 """
 
+from typing import Union, Iterable, Text, TypeAlias
+
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+
+InputTensor: TypeAlias = Union[
+    tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]
+]
 
 
-def get_registry_generator_fn(tape, layer_registry):
+def get_registry_generator_fn(
+    tape: tf.GradientTape, layer_registry: lr.LayerRegistry
+):
   """Creates the generator function for `compute_gradient_norms()`."""
   if layer_registry is None:
     # Needed for backwards compatibility.
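The new InputTensor alias admits three input forms. A minimal illustration of each (the shapes and key names below are made up for the example):

import tensorflow as tf

single_input = tf.ones([32, 16])                      # tf.Tensor
multi_input = [tf.ones([32, 16]), tf.ones([32, 8])]   # Iterable[tf.Tensor]
named_input = {"tokens": tf.ones([32, 16])}           # dict[Text, tf.Tensor]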
@@ -53,7 +62,12 @@ def registry_generator_fn(layer_instance, args, kwargs):
   return registry_generator_fn
 
 
-def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
+def compute_gradient_norms(
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    layer_registry: lr.LayerRegistry,
+):
   """Computes the per-example loss gradient norms for given data.
 
   Applies a variant of the approach given in
@@ -62,7 +76,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from. The
       loss of the model *must* be a scalar loss.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
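Together with the InputTensor change above, per-example gradient norms can now be computed for multi-input models by passing a dict keyed by input name. A minimal sketch, assuming the companion layer_registry module exposes a make_default_layer_registry() factory (that factory is not part of this diff, and the model, shapes, and loss are made up):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

# Hypothetical two-input functional model with a scalar loss.
inp_a = tf.keras.Input(shape=(16,), name="a")
inp_b = tf.keras.Input(shape=(8,), name="b")
hidden = tf.keras.layers.Concatenate()([inp_a, inp_b])
model = tf.keras.Model(inputs=[inp_a, inp_b], outputs=tf.keras.layers.Dense(1)(hidden))
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = {"a": tf.random.normal([4, 16]), "b": tf.random.normal([4, 8])}  # dict form of InputTensor
y_batch = tf.random.normal([4, 1])

registry = lr.make_default_layer_registry()  # assumed factory covering the built-in layer types
norms = clip_grads.compute_gradient_norms(model, x_batch, y_batch, registry)
print(norms.shape)  # one gradient norm per example, i.e. (4,)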
@@ -106,7 +120,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
 
 
-def compute_clip_weights(l2_norm_clip, gradient_norms):
+def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
   """Computes the per-example loss/clip weights for clipping.
 
   When the sum of the per-example losses is replaced a weighted sum, where
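The body of compute_clip_weights is not part of this diff; for orientation, per-example clip weights of this kind are commonly of the form C / max(C, ||g_i||), so that re-weighting the per-example losses has the same effect as clipping each per-example gradient to norm at most C. The snippet below is an illustrative formulation under that assumption, not the library's exact code:

import tensorflow as tf

def example_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor) -> tf.Tensor:
  # Each example whose gradient norm exceeds the clip threshold is scaled
  # down to it; examples already within the threshold keep weight 1.
  return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)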
@@ -132,7 +146,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
 
 
 def compute_pred_and_clipped_gradients(
-    input_model, x_batch, y_batch, l2_norm_clip, layer_registry
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    l2_norm_clip: float,
+    layer_registry: lr.LayerRegistry,
 ):
   """Computes the per-example predictions and per-example clipped loss gradient.
 
@@ -147,7 +165,7 @@ def compute_pred_and_clipped_gradients(
 
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
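The EinsumDense registry function and tests referenced in the commit message live in einsum_utils.py, layer_registry.py, and the test files, none of which appear in the hunks above. Below is a rough end-to-end sketch of how the retyped compute_pred_and_clipped_gradients entry point would be exercised on a model containing a tf.keras.layers.EinsumDense layer; the registry construction (make_default_layer_registry) and its coverage of EinsumDense are assumptions, and the model and clip norm are made up:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

# Hypothetical model containing the layer type this commit targets.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.EinsumDense("ab,bc->ac", output_shape=8, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = tf.random.normal([4, 16])
y_batch = tf.random.normal([4, 1])

registry = lr.make_default_layer_registry()  # assumed to handle EinsumDense after this change
result = clip_grads.compute_pred_and_clipped_gradients(
    model, x_batch, y_batch, l2_norm_clip=1.0, layer_registry=registry
)  # per the docstring: the per-example predictions and the clipped loss gradients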
