@@ -69,9 +69,10 @@ def registry_generator_fn(layer_instance, args, kwargs):
 
 def compute_gradient_norms(
     input_model: tf.keras.Model,
+    layer_registry: lr.LayerRegistry,
     x_batch: InputTensor,
     y_batch: tf.Tensor,
-    layer_registry: lr.LayerRegistry,
+    weight_batch: Optional[tf.Tensor] = None,
     per_example_loss_fn: Optional[LossFn] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
     trainable_vars: Optional[List[tf.Variable]] = None,
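Note the argument reorder: `layer_registry` now follows `input_model`, and the new `weight_batch` sits between `y_batch` and `per_example_loss_fn`, so positional call sites need updating. A minimal sketch of a call against the new signature, assuming the function lives in a `clip_grads` module and that `layer_registry.make_default_layer_registry()` is available (both names are assumptions, not confirmed by this diff):

import tensorflow as tf
# Assumed import paths; only the layer_registry module path is confirmed by the docstring below.
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())
x_batch = tf.random.normal([4, 3])
y_batch = tf.random.normal([4, 2])

norms = clip_grads.compute_gradient_norms(
    model,
    layer_registry.make_default_layer_registry(),  # registry is now the 2nd positional arg (assumed constructor name)
    x_batch,
    y_batch,
    weight_batch=None,  # new optional per-example weights
)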
@@ -84,15 +85,16 @@ def compute_gradient_norms(
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from. The
       loss of the model *must* be a scalar loss.
+    layer_registry: A `LayerRegistry` instance containing functions that help
+      compute gradient norms quickly. See
+      `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
+      more details.
     x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
       number of examples in `x_batch`.
-    layer_registry: A `LayerRegistry` instance containing functions that help
-      compute gradient norms quickly. See
-      `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
-      more details.
+    weight_batch: Optional batch of weights, passed to the loss function.
     per_example_loss_fn: takes as input predictions, labels and weights, and
       outputs a vector of per-example losses. If None, derived from
       `input_model.loss` by disabling its reduction.
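For how the new `weight_batch` interacts with a per-example loss derived from `input_model.loss`: a Keras loss object accepts a sample weight as its third call argument, and with reduction disabled it returns one weighted loss per example. A standalone illustration in plain TensorFlow (toy values, not library code):

import tensorflow as tf

loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
y_true = tf.constant([[1.0], [2.0], [3.0]])
y_pred = tf.constant([[1.5], [2.0], [0.0]])
weight_batch = tf.constant([1.0, 0.5, 2.0])

# One weighted loss per example: [0.25, 0.0, 9.0] scaled by [1.0, 0.5, 2.0].
print(loss_fn(y_true, y_pred, weight_batch))  # -> [0.25, 0.0, 18.0]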
@@ -108,8 +110,9 @@ def compute_gradient_norms(
       variables are included.
 
   Returns:
-    A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
-    per-example loss function.
+    A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
+    example loss (when num_microbatches is None) or the norm of the gradient of
+    the i-th microbatch loss (defined as a mean over the microbatch).
   """
   tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
   registry_generator_fn = get_registry_generator_fn(
@@ -127,7 +130,7 @@ def compute_gradient_norms(
     loss_config = input_model.loss.get_config()
     loss_config['reduction'] = tf.keras.losses.Reduction.NONE
     per_example_loss_fn = input_model.loss.from_config(loss_config)
-  losses = per_example_loss_fn(y_batch, model_outputs)
+  losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
   if losses.shape is None:
     raise NotImplementedError(
         "The unreduced (or per-example) loss's shape cannot be `None`"
@@ -140,7 +143,7 @@ def compute_gradient_norms(
     )
   if num_microbatches is not None:
     losses = tf.reduce_mean(
-        lr.add_microbatch_axis(losses, num_microbatches), axis=1
+        lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
     )
   summed_loss = tf.reduce_sum(losses)
   # Unwrap the generator outputs so that the next loop avoids duplicating
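For intuition about the microbatch branch above: with a batch of 8 per-example losses and num_microbatches=4, the mean over the added microbatch axis leaves 4 losses, one per microbatch, and the function then returns one gradient norm per microbatch. A plain-TensorFlow sketch with tf.reshape standing in for the library's maybe_add_microbatch_axis helper:

import tensorflow as tf

per_example_losses = tf.constant([1., 2., 3., 4., 5., 6., 7., 8.])  # shape [8]
num_microbatches = 4

# Reshape [batch_size] -> [num_microbatches, microbatch_size], then average
# within each microbatch, mirroring the tf.reduce_mean(..., axis=1) above.
microbatch_losses = tf.reduce_mean(
    tf.reshape(per_example_losses, [num_microbatches, -1]), axis=1
)
print(microbatch_losses)  # [1.5, 3.5, 5.5, 7.5]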
@@ -165,8 +168,10 @@ def compute_gradient_norms(
       vars_list,
       unconnected_gradients=tf.UnconnectedGradients.ZERO,
   )
+  if not grads_list:
+    raise ValueError('Empty gradient list.')
   sqr_norm_list = []
-  for grads, f in zip(grads_list, sqr_norm_fns_list):
+  for grads, f in zip(grads_list, sqr_norm_fns_list, strict=True):
     sqr_norm_list.append(f(grads))
   sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
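The closing lines of this hunk combine one vector of per-example (or per-microbatch) squared norms per registered layer into a single gradient norm per example. A small numeric sketch of that combination step, with made-up squared norms for two layers:

import tensorflow as tf

# Hypothetical per-example squared gradient norms from two layers (batch of 3).
sqr_norms_layer_a = tf.constant([1.0, 4.0, 9.0])
sqr_norms_layer_b = tf.constant([3.0, 0.0, 7.0])

# Mirrors tf.stack(..., axis=1) followed by the sqrt of the row sums.
sqr_norm_tsr = tf.stack([sqr_norms_layer_a, sqr_norms_layer_b], axis=1)  # shape [3, 2]
print(tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)))  # [2.0, 2.0, 4.0]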
@@ -199,10 +204,11 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
 
 def compute_clipped_gradients_and_outputs(
     input_model: tf.keras.Model,
-    x_batch: InputTensor,
-    y_batch: tf.Tensor,
     l2_norm_clip: float,
     layer_registry: lr.LayerRegistry,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    weight_batch: Optional[tf.Tensor] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
     clipping_loss: Optional[LossFn] = None,
 ) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
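As with compute_gradient_norms, positional callers of compute_clipped_gradients_and_outputs must now pass x_batch and y_batch after l2_norm_clip and layer_registry. A hedged sketch of an updated call, reusing the model, registry, and data from the earlier sketch and an arbitrary clip value:

clipped_grads, y_pred, clipping_loss_value = (
    clip_grads.compute_clipped_gradients_and_outputs(
        model,
        1.0,  # l2_norm_clip
        layer_registry.make_default_layer_registry(),  # assumed constructor name
        x_batch,
        y_batch,
        weight_batch=None,
    )
)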
@@ -218,11 +224,6 @@ def compute_clipped_gradients_and_outputs(
 
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
-    x_batch: An `InputTensor` representing a batch of inputs to the model. The
-      first axis must be the batch dimension.
-    y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
-      must be the batch dimension. The number of examples should match the
-      number of examples in `x_batch`.
     l2_norm_clip: A `float` indicating the norm to which per-example gradients
       will be clipped. That is, all gradients of the per-example loss functions
       will have norm at most `l2_norm_clip`.
@@ -232,6 +233,15 @@ def compute_clipped_gradients_and_outputs(
       `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
       squared norms of a layer's pre-activation tensor, and `vars` are relevant
       trainable weights (see `layer_registry_factories.py` for examples).
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
+      first axis must be the batch dimension.
+    y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
+      must be the batch dimension. The number of examples should match the
+      number of examples in `x_batch`.
+    weight_batch: Optional vector of weights, passed to the loss function. Must
+      be of size [batch_size]. In case of microbatching, this will be reshaped
+      to [num_microbatches, batch_size/num_microbatches] before passing it to
+      the loss.
     num_microbatches: An optional number or scalar `tf.Tensor` for the number of
       microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -243,11 +253,10 @@ def compute_clipped_gradients_and_outputs(
       the value of the clipped loss does not reflect the true loss.
 
   Returns:
-    A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
-    clipped gradient of the loss function, the second is the result of
-    applying `input_model` to `x_batch`, and the third is loss value of
-    `input_model`, weighted by the loss weights generated by a specific
-    `compute_clip_weights()` call.
+    clipped_grad: the clipped gradient of the loss function
+    y_pred: the result of applying `input_model` to `x_batch`
+    clipping_loss_value: the loss value weighted in such a way that its gradient
+      is `clipped_grad`.
   """
   if input_model.loss.reduction == 'none':
     raise NotImplementedError(
@@ -258,13 +267,25 @@ def compute_clipped_gradients_and_outputs(
     clipping_loss = input_model.compiled_loss
   gradient_norms = compute_gradient_norms(
       input_model,
+      layer_registry,
       x_batch,
       y_batch,
-      layer_registry,
+      weight_batch,
       num_microbatches=num_microbatches,
       trainable_vars=input_model.trainable_variables,
   )
-  loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
+  clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
+  if weight_batch is not None:
+    if num_microbatches is None:
+      clip_weights = clip_weights * weight_batch  # shape [batch_size]
+    else:
+      # In this case, weight_batch has shape [batch_size]; we first reshape it
+      # to [num_microbatches, microbatch_size] and then multiply by
+      # clip_weights (which has shape [num_microbatches]).
+      weight_batch = lr.maybe_add_microbatch_axis(
+          weight_batch, num_microbatches
+      )
+      clip_weights = clip_weights[:, tf.newaxis] * weight_batch
   with tf.GradientTape() as tape:
     # WARNING: When num_microbatches is not None, we need to be sure that
     # `compute_loss` always computes the mean over the microbatches
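The microbatch branch added here broadcasts a [num_microbatches] vector of clip weights against a [num_microbatches, microbatch_size] weight matrix. A plain-TensorFlow sketch of just that shape manipulation, with tf.reshape standing in for maybe_add_microbatch_axis:

import tensorflow as tf

num_microbatches = 3
weight_batch = tf.constant([1.0, 1.0, 0.5, 0.5, 2.0, 2.0])  # shape [batch_size] = [6]
clip_weights = tf.constant([0.1, 1.0, 0.25])                # shape [num_microbatches]

# [batch_size] -> [num_microbatches, microbatch_size], then broadcast the
# per-microbatch clip weight across the examples of each microbatch.
weight_mb = tf.reshape(weight_batch, [num_microbatches, -1])  # shape [3, 2]
print(clip_weights[:, tf.newaxis] * weight_mb)                # shape [3, 2]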
@@ -274,17 +295,9 @@ def compute_clipped_gradients_and_outputs(
     # is not defined in the contract so may not hold, especially for
     # custom losses.
     y_pred = input_model(x_batch, training=True)
-    loss_y_batch = (
-        y_batch
-        if num_microbatches is None
-        else lr.add_microbatch_axis(y_batch, num_microbatches)
-    )
-    loss_y_pred = (
-        y_pred
-        if num_microbatches is None
-        else lr.add_microbatch_axis(y_pred, num_microbatches)
-    )
-    clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
+    mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
+    mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
+    clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
   clipped_grads = tape.gradient(
       clipping_loss_value,
       input_model.trainable_variables,
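Why the gradient of clipping_loss_value is the clipped gradient: clip_weights are computed outside this tape, so the tape treats them as constants, and the gradient of a weighted sum of per-example losses is the same weighted sum of per-example gradients. A self-contained toy demonstration of that identity (not library code; the min(1, C/norm) weighting stands in for compute_clip_weights, and the per-example tapes are deliberately naive):

import tensorflow as tf

w = tf.Variable([1.0, -2.0])
x = tf.constant([[3.0, 0.0], [0.0, 5.0]])
l2_norm_clip = 1.0

def per_example_losses():
  return 0.5 * tf.square(tf.linalg.matvec(x, w))  # shape [2]

# Naive per-example gradient norms (one tape per example, for illustration only).
norms = []
for i in range(2):
  with tf.GradientTape() as tape:
    loss_i = per_example_losses()[i]
  norms.append(tf.norm(tape.gradient(loss_i, w)))
clip_weights = tf.minimum(1.0, l2_norm_clip / tf.stack(norms))  # constants for the tape below

with tf.GradientTape() as tape:
  clipping_loss_value = tf.reduce_sum(clip_weights * per_example_losses())
# Equals the sum of the per-example gradients, each rescaled to norm <= l2_norm_clip.
print(tape.gradient(clipping_loss_value, w))  # -> [1.0, -1.0]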