@@ -21,11 +21,20 @@
 `compute_gradient_norms()` function).
 """

+from typing import Union, Iterable, Text, TypeAlias
+
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+
+InputTensor: TypeAlias = Union[
+    tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]
+]


-def get_registry_generator_fn(tape, layer_registry):
+def get_registry_generator_fn(
+    tape: tf.GradientTape, layer_registry: lr.LayerRegistry
+):
   """Creates the generator function for `compute_gradient_norms()`."""
   if layer_registry is None:
     # Needed for backwards compatibility.
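
Aside (not part of the diff): a minimal sketch of the input formats the new `InputTensor` alias admits; the shapes and names below are illustrative only.

    import tensorflow as tf

    single = tf.ones([32, 10])                     # tf.Tensor
    multi = (tf.ones([32, 10]), tf.ones([32, 4]))  # Iterable[tf.Tensor]
    named = {"feature": tf.ones([32, 10])}         # dict[Text, tf.Tensor]
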
@@ -53,7 +62,12 @@ def registry_generator_fn(layer_instance, args, kwargs):
   return registry_generator_fn


-def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
+def compute_gradient_norms(
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    layer_registry: lr.LayerRegistry,
+):
   """Computes the per-example loss gradient norms for given data.

   Applies a variant of the approach given in
@@ -62,7 +76,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers. The
       loss of the model *must* be a scalar loss.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
@@ -106,7 +120,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))


-def compute_clip_weights(l2_norm_clip, gradient_norms):
+def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
   """Computes the per-example loss/clip weights for clipping.

   When the sum of the per-example losses is replaced by a weighted sum, where
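
Aside (the hunk elides the function body): a sketch of the standard clipping-weight formula consistent with the guarantee stated in this docstring; it is the canonical choice, not necessarily the exact implementation in this commit.

    import tensorflow as tf

    def clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor) -> tf.Tensor:
      # w_i = C / max(C, ||g_i||) = min(1, C / ||g_i||), so each weighted
      # per-example gradient has L2 norm at most C = l2_norm_clip.
      return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)
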
@@ -132,7 +146,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):


 def compute_pred_and_clipped_gradients(
-    input_model, x_batch, y_batch, l2_norm_clip, layer_registry
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    l2_norm_clip: float,
+    layer_registry: lr.LayerRegistry,
 ):
   """Computes the per-example predictions and per-example clipped loss gradient.

@@ -147,7 +165,7 @@ def compute_pred_and_clipped_gradients(

   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
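
Aside (usage, not part of the diff): a hedged end-to-end sketch of calling the newly typed `compute_gradient_norms`. The module name `clip_grads` and the constructor `layer_registry.make_default_layer_registry()` are assumptions about the surrounding library; the model and data are illustrative.

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

    # A toy model; the compiled loss must be a scalar loss (see the docstring).
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
    model.compile(loss=tf.keras.losses.MeanSquaredError())

    x_batch = tf.random.normal([8, 10])  # first axis is the batch dimension
    y_batch = tf.zeros([8, 1])

    # Per-example gradient norms; the result has shape [8].
    norms = clip_grads.compute_gradient_norms(
        model, x_batch, y_batch, lr.make_default_layer_registry()
    )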