Commit 345fe5b

walidk authored and tensorflower-gardener committed
Support fast clipping in DPAM.
PiperOrigin-RevId: 533589057
1 parent 60d237b commit 345fe5b

5 files changed: +283 -211 lines

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py (+48 -35)
@@ -69,9 +69,10 @@ def registry_generator_fn(layer_instance, args, kwargs):
 
 def compute_gradient_norms(
     input_model: tf.keras.Model,
+    layer_registry: lr.LayerRegistry,
     x_batch: InputTensor,
     y_batch: tf.Tensor,
-    layer_registry: lr.LayerRegistry,
+    weight_batch: Optional[tf.Tensor] = None,
     per_example_loss_fn: Optional[LossFn] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
     trainable_vars: Optional[List[tf.Variable]] = None,
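
With this reordering, `layer_registry` moves up to the second positional slot and the new optional `weight_batch` follows `y_batch`. A minimal call sketch under assumed setup (the toy model, the `make_default_layer_registry()` factory, and the batch shapes are illustrative, not part of this diff):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

# Assumed toy setup: a one-layer dense model with a compiled scalar loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())
registry = lr.make_default_layer_registry()

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])
w = tf.ones([8])  # optional per-example weights (new in this commit)

# layer_registry is now the second positional argument; weight_batch is optional.
norms = clip_grads.compute_gradient_norms(model, registry, x, y, weight_batch=w)
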
@@ -84,15 +85,16 @@ def compute_gradient_norms(
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from. The
       loss of the model *must* be a scalar loss.
+    layer_registry: A `LayerRegistry` instance containing functions that help
+      compute gradient norms quickly. See
+      `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
+      more details.
     x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
       number of examples in `x_batch`.
-    layer_registry: A `LayerRegistry` instance containing functions that help
-      compute gradient norms quickly. See
-      `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
-      more details.
+    weight_batch: Optional batch of weights, passed to the loss function.
     per_example_loss_fn: takes as input predictions, labels and weights, and
       outputs a vector of per-example losses. If None, derived from
       `input_model.loss` by disabling its reduction.
@@ -108,8 +110,9 @@ def compute_gradient_norms(
       variables are included.
 
   Returns:
-    A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
-    per-example loss function.
+    A scalar vector, whose i-th entry is the norm of the gradient of the i-th
+    example loss (when num_microbatches is None) or the norm of the gradient of
+    the i-th microbatch loss (defined as a mean over the microbatch).
   """
   tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
   registry_generator_fn = get_registry_generator_fn(
@@ -127,7 +130,7 @@ def compute_gradient_norms(
     loss_config = input_model.loss.get_config()
     loss_config['reduction'] = tf.keras.losses.Reduction.NONE
     per_example_loss_fn = input_model.loss.from_config(loss_config)
-  losses = per_example_loss_fn(y_batch, model_outputs)
+  losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
   if losses.shape is None:
     raise NotImplementedError(
         "The unreduced (or per-example) loss's shape cannot be `None`"
@@ -140,7 +143,7 @@ def compute_gradient_norms(
     )
   if num_microbatches is not None:
     losses = tf.reduce_mean(
-        lr.add_microbatch_axis(losses, num_microbatches), axis=1
+        lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
     )
   summed_loss = tf.reduce_sum(losses)
   # Unwrap the generator outputs so that the next loop avoids duplicating
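
The renamed helper lives in `layer_registry`, not in this hunk; from its usage here it regroups the per-example losses into a `[num_microbatches, microbatch_size]` shape (and presumably passes tensors through unchanged when `num_microbatches` is None, hence the `maybe_`). The intended shape arithmetic, sketched with a plain reshape:

import tensorflow as tf

losses = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # per-example losses, batch_size = 6
num_microbatches = 3

# Regroup into [num_microbatches, microbatch_size] and average within each
# microbatch, leaving one loss value per microbatch (shape [3]).
microbatch_losses = tf.reduce_mean(
    tf.reshape(losses, [num_microbatches, -1]), axis=1
)
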
@@ -165,8 +168,10 @@ def compute_gradient_norms(
       vars_list,
       unconnected_gradients=tf.UnconnectedGradients.ZERO,
   )
+  if not grads_list:
+    raise ValueError('Empty gradient list.')
   sqr_norm_list = []
-  for grads, f in zip(grads_list, sqr_norm_fns_list):
+  for grads, f in zip(grads_list, sqr_norm_fns_list, strict=True):
     sqr_norm_list.append(f(grads))
   sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
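
Note that `zip(..., strict=True)` requires Python 3.10+; it turns a silent truncation into an error if the gradients and the squared-norm functions ever fall out of sync. A minimal illustration (the mismatched lists are contrived):

grads_list = [1.0, 2.0, 3.0]
sqr_norm_fns_list = [abs, abs]  # one entry too few, for illustration

try:
  list(zip(grads_list, sqr_norm_fns_list, strict=True))
except ValueError as err:
  print(err)  # zip() argument 2 is shorter than argument 1
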
@@ -199,10 +204,11 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
 
 def compute_clipped_gradients_and_outputs(
     input_model: tf.keras.Model,
-    x_batch: InputTensor,
-    y_batch: tf.Tensor,
     l2_norm_clip: float,
     layer_registry: lr.LayerRegistry,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    weight_batch: Optional[tf.Tensor] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
     clipping_loss: Optional[LossFn] = None,
 ) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
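
`compute_clipped_gradients_and_outputs` gets the same treatment: the clipping parameters now precede the data batches, and `weight_batch` slots in after `y_batch`. A hedged call sketch (model, registry, and data are assumed, as in the earlier sketch):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

# Assumed setup, mirroring the earlier sketch.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())
registry = lr.make_default_layer_registry()
x, y, w = tf.random.normal([8, 4]), tf.random.normal([8, 1]), tf.ones([8])

# Clipping arguments come first, then the batches; weight_batch is optional.
clipped_grads, y_pred, clipping_loss_value = (
    clip_grads.compute_clipped_gradients_and_outputs(
        model, l2_norm_clip=1.0, layer_registry=registry,
        x_batch=x, y_batch=y, weight_batch=w,
    )
)
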
@@ -218,11 +224,6 @@ def compute_clipped_gradients_and_outputs(
 
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
-    x_batch: An `InputTensor` representing a batch of inputs to the model. The
-      first axis must be the batch dimension.
-    y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
-      must be the batch dimension. The number of examples should match the
-      number of examples in `x_batch`.
     l2_norm_clip: A `float` indicating the norm to which per-example gradients
       will be clipped. That is, all gradients of the per-example loss functions
       will have norm at most `l2_norm_clip`.
@@ -232,6 +233,15 @@ def compute_clipped_gradients_and_outputs(
       `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
       squared norms of a layer's pre-activation tensor, and `vars` are relevant
       trainable weights (see `layer_registry_factories.py` for examples).
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
+      first axis must be the batch dimension.
+    y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
+      must be the batch dimension. The number of examples should match the
+      number of examples in `x_batch`.
+    weight_batch: Optional vector of weights, passed to the loss function. Must
+      be of size [batch_size]. In case of microbatching, this will be reshaped
+      to [num_microbatches, batch_size/num_microbatches] before passing it to
+      the loss.
     num_microbatches: An optional number or scalar `tf.Tensor` for the number of
       microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -243,11 +253,10 @@ def compute_clipped_gradients_and_outputs(
       the value of the clipped loss does not reflect the true loss.
 
   Returns:
-    A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
-    clipped gradient of the loss function, the second is the result of
-    applying `input_model` to `x_batch`, and the third is loss value of
-    `input_model`, weighted by the loss weights generated by a specific
-    `compute_clip_weights()` call.
+    clipped_grad: the clipped gradient of the loss function
+    y_pred: the result of applying `input_model` to `x_batch`
+    clipping_loss_value: the loss value weighted in such a way that its gradient
+      is `clipped_grad`.
   """
   if input_model.loss.reduction == 'none':
     raise NotImplementedError(
@@ -258,13 +267,25 @@ def compute_clipped_gradients_and_outputs(
     clipping_loss = input_model.compiled_loss
   gradient_norms = compute_gradient_norms(
       input_model,
+      layer_registry,
       x_batch,
       y_batch,
-      layer_registry,
+      weight_batch,
       num_microbatches=num_microbatches,
       trainable_vars=input_model.trainable_variables,
   )
-  loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
+  clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
+  if weight_batch is not None:
+    if num_microbatches is None:
+      clip_weights = clip_weights * weight_batch  # shape [num_microbatches]
+    else:
+      # In this case, weight_batch is of shape [batch_size], we first reshape to
+      # [num_microbatches, microbatch_size] then multiply by the clip_weights
+      # (which is of shape [num_microbatches])
+      weight_batch = lr.maybe_add_microbatch_axis(
+          weight_batch, num_microbatches
+      )
+      clip_weights = clip_weights[:, tf.newaxis] * weight_batch
   with tf.GradientTape() as tape:
     # WARNING: When num_microbatches is not None, we need to be sure that
     # `compute_loss` always computes the mean over the microbatches
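
The shape bookkeeping in the new branch: `compute_clip_weights` yields one weight per microbatch, while `weight_batch` carries one weight per example, so the example weights are regrouped before multiplying. The same arithmetic with a plain reshape (all values illustrative):

import tensorflow as tf

num_microbatches = 3
clip_weights = tf.constant([0.5, 1.0, 0.25])                 # shape [num_microbatches]
weight_batch = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])   # shape [batch_size]

# Regroup example weights into [num_microbatches, microbatch_size], then scale
# each row by its microbatch's clip weight via broadcasting.
weight_mb = tf.reshape(weight_batch, [num_microbatches, -1])  # shape [3, 2]
combined = clip_weights[:, tf.newaxis] * weight_mb            # shape [3, 2]
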
@@ -274,17 +295,9 @@ def compute_clipped_gradients_and_outputs(
     # is not defined in the contract so may not hold, especially for
     # custom losses.
     y_pred = input_model(x_batch, training=True)
-    loss_y_batch = (
-        y_batch
-        if num_microbatches is None
-        else lr.add_microbatch_axis(y_batch, num_microbatches)
-    )
-    loss_y_pred = (
-        y_pred
-        if num_microbatches is None
-        else lr.add_microbatch_axis(y_pred, num_microbatches)
-    )
-    clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
+    mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
+    mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
+    clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
   clipped_grads = tape.gradient(
       clipping_loss_value,
       input_model.trainable_variables,
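
The reason passing `clip_weights` into `clipping_loss` produces clipped gradients is the usual reweighting identity: the gradient of a weighted sum of per-example losses equals the same weighted sum of per-example gradients. A toy sketch of that identity (the model and data below are assumptions, not taken from this file):

import tensorflow as tf

theta = tf.Variable([1.0, -2.0])
x = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = tf.constant([0.0, 0.0, 0.0])
clip_weights = tf.constant([1.0, 0.5, 0.25])  # e.g. l2_norm_clip / max(norm, l2_norm_clip)

with tf.GradientTape() as tape:
  per_example_losses = tf.square(tf.linalg.matvec(x, theta) - y)
  weighted_loss = tf.reduce_sum(clip_weights * per_example_losses)

# Equals sum_i clip_weights[i] * grad(loss_i), i.e. each per-example gradient
# is scaled down before aggregation.
grad = tape.gradient(weighted_loss, theta)
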
