
Commit 0f5acf8

Add additional tests and checks on the passed loss function.
PiperOrigin-RevId: 532225904
1 parent: 8fdac5f

File tree: 2 files changed (+94, -1)

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

+15
@@ -128,6 +128,16 @@ def compute_gradient_norms(
     loss_config['reduction'] = tf.keras.losses.Reduction.NONE
     per_example_loss_fn = input_model.loss.from_config(loss_config)
     losses = per_example_loss_fn(y_batch, model_outputs)
+    if losses.shape is None:
+      raise NotImplementedError(
+          "The unreduced (or per-example) loss's shape cannot be `None`"
+      )
+    if len(losses.shape) != 1:
+      raise NotImplementedError(
+          'The unreduced (or per-example) loss needs to have a shape of length '
+          'one, but received an unreduced loss of shape length %s'
+          % len(losses.shape)
+      )
     if num_microbatches is not None:
       losses = tf.reduce_mean(
           lr.add_microbatch_axis(losses, num_microbatches), axis=1
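Why rank one matters here: rebuilding the compiled loss with `Reduction.NONE` should yield one scalar loss per example, i.e. a tensor of shape `[batch_size]`, and the downstream clip-weight arithmetic indexes losses along that single axis. A minimal standalone sketch (plain Keras, not taken from this commit) of the shape the new checks accept:

import tensorflow as tf

# Rebuilding a loss with reduction NONE, as compute_gradient_norms does,
# yields one loss value per example: shape (batch_size,), i.e. rank 1.
y_true = tf.constant([[1.0], [2.0], [3.0]])
y_pred = tf.constant([[1.5], [2.5], [2.0]])
per_example_loss_fn = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE
)
losses = per_example_loss_fn(y_true, y_pred)
print(losses.shape)  # (3,) -- rank 1, so both new checks pass.
assert len(losses.shape) == 1

A loss whose unreduced output has an unknown shape or extra axes would now raise `NotImplementedError` instead of silently producing wrong per-example norms.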
@@ -239,6 +249,11 @@ def compute_clipped_gradients_and_outputs(
       `input_model`, weighted by the loss weights generated by a specific
       `compute_clip_weights()` call.
   """
+  if input_model.loss.reduction == 'none':
+    raise NotImplementedError(
+        'Fast gradient clipping does not support '
+        'models with unreduced loss functions.'
+    )
   if clipping_loss is None:
     clipping_loss = input_model.compiled_loss
   gradient_norms = compute_gradient_norms(
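From the caller's side, the new guard turns a misconfigured model into an immediate error rather than a later failure. A hedged usage sketch (assumes tensorflow_privacy is importable; the toy model here is illustrative, not from the commit, while the positional call mirrors the test added below):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# A model compiled with an unreduced loss is now rejected up front, because
# Reduction.NONE serializes to the string 'none' checked by the guard.
model = tf.keras.Sequential([tf.keras.layers.Dense(3, input_shape=(2,))])
model.compile(
    loss=tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE
    )
)
x_batch = tf.ones([4, 2])
y_batch = tf.ones([4, 3])
try:
  clip_grads.compute_clipped_gradients_and_outputs(
      model,
      x_batch,
      y_batch,
      1.0,  # l2_norm_clip
      layer_registry.make_default_layer_registry(),
  )
except NotImplementedError:
  print('unreduced loss functions are not supported')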

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py

+79 -1
@@ -426,7 +426,7 @@ def test_gradient_norms_on_various_models(
 
 class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
 
-  # TODO(wkong): Test sparse input tensors when the GitHub CI environment
+  # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
   # supports them for embeddings.
   @parameterized.product(
       x_batch=[
@@ -541,5 +541,83 @@ def test_gradient_norms_on_various_models(
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
 
 
+class ClipGradsComputeClippedGradsAndOutputsTest(
+    tf.test.TestCase, parameterized.TestCase
+):
+
+  def setUp(self):
+    super().setUp()
+    dense_generator = lambda a, b: tf.keras.layers.Dense(b)
+    self._input_dim = 2
+    self._output_dim = 3
+    self._model = make_two_layer_sequential_model(
+        dense_generator, self._input_dim, self._output_dim
+    )
+
+  @parameterized.product(
+      batch_size=[1, 2, 10],
+      l2_norm_clip=[0.1, 1.0, 10],
+      is_eager=[True, False],
+      reduction=['auto', 'sum', 'sum_over_batch_size', 'none'],
+  )
+  def test_clipped_gradients_on_different_losses(
+      self, batch_size, l2_norm_clip, is_eager, reduction
+  ):
+    loss_fn = tf.keras.losses.MeanSquaredError(reduction=reduction)
+    self._model.compile(loss=loss_fn, run_eagerly=is_eager)
+    x_batch = tf.reshape(
+        tf.range(batch_size * self._input_dim, dtype=tf.float32),
+        [batch_size, -1],
+    )
+    y_batch = tf.reshape(
+        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
+    )
+    # Stop early for efficiency.
+    if reduction == 'none':
+      self.assertRaises(
+          NotImplementedError,
+          # function tested
+          clip_grads.compute_clipped_gradients_and_outputs,
+          # function args
+          self._model,
+          x_batch,
+          y_batch,
+          l2_norm_clip,
+          layer_registry.make_default_layer_registry(),
+      )
+      return
+    # NOTE: losses from this point are scalar losses.
+    with tf.GradientTape() as tape:
+      y_pred = self._model(x_batch)
+      loss_value = loss_fn(y_pred, y_batch)
+    true_grads = tape.gradient(loss_value, self._model.trainable_variables)
+    clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
+        self._model,
+        x_batch,
+        y_batch,
+        l2_norm_clip,
+        layer_registry.make_default_layer_registry(),
+    )
+
+    # Computes the L2 norm manually.
+    def compute_l2_norm(t):
+      sqr_sum_fn = lambda x: tf.reduce_sum(tf.square(x))
+      return tf.sqrt(tf.add_n(tf.nest.map_structure(sqr_sum_fn, t)))
+
+    true_norm = compute_l2_norm(true_grads)
+    computed_norm = compute_l2_norm(clipped_grads)
+    norm_bound = (
+        l2_norm_clip * batch_size if reduction == 'sum' else l2_norm_clip
+    )
+    if true_norm >= norm_bound:
+      # All of the per-example gradient norms should be less than the L2 norm
+      # clip value. Hence, by the triangle inequality, the gradient norm of the
+      # summed loss (averaged loss) should be less than the clip value times
+      # the batch size (just the clip value).
+      self.assertLessEqual(computed_norm, norm_bound)
+    else:
+      self.assertAlmostEqual(computed_norm, true_norm)
+
+
 if __name__ == '__main__':
   tf.test.main()
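The bound the test asserts at the end follows directly from per-example clipping: each clipped per-example gradient has norm at most l2_norm_clip, so by the triangle inequality their sum has norm at most batch_size * l2_norm_clip and their mean has norm at most l2_norm_clip. A small numeric sketch of that bound (NumPy, illustrative only; the gradient values are invented):

import numpy as np

rng = np.random.default_rng(0)
l2_norm_clip, batch_size, dim = 1.0, 10, 5

# Fabricated per-example gradients, each clipped to norm <= l2_norm_clip.
grads = rng.normal(size=(batch_size, dim))
norms = np.linalg.norm(grads, axis=1, keepdims=True)
clipped = grads * np.minimum(1.0, l2_norm_clip / norms)

# Triangle inequality: the summed gradient is bounded by batch_size * clip
# (the 'sum' reduction case), and the averaged gradient by the clip alone
# (the 'auto' / 'sum_over_batch_size' cases).
assert np.linalg.norm(clipped.sum(axis=0)) <= batch_size * l2_norm_clip + 1e-9
assert np.linalg.norm(clipped.mean(axis=0)) <= l2_norm_clip + 1e-9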
