Skip to content

Commit 66d05a2

Browse files
wwkong authored and tensorflower-gardener committed
Fix a gradient clipping bug for layer normalization layers with microbatch axes.
The previous code passed the unstacked gradients (a list) instead of the stacked gradients (a tensor) to the microbatcher, which led to unexpected behavior. This change passes the right argument and changes the original unit test to catch this bug.

PiperOrigin-RevId: 669413064
1 parent b396397 commit 66d05a2

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,11 @@ def sqr_norm_fn(grads):
8080
stacked_grads = tf.stack(grads, axis=-1)
8181
if num_microbatches is not None:
8282
stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
83-
grads, num_microbatches
83+
stacked_grads, num_microbatches
8484
)
85+
# We will need to sum over the new microbatch size axis (axis=1) in order
86+
# to account for microbatch aggregation.
87+
stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
8588
reduction_axes = tf.range(1, tf.rank(stacked_grads))
8689
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
8790

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def test_op(x_batch):
134134
atol = 1e-1 if self.using_tpu else 1e-2
135135

136136
# Each batched input is a reshape of a `tf.range()` call.
137-
batch_size = 2
137+
batch_size = 6
138138
example_size = np.prod(input_dims)
139139
example_values = tf.range(batch_size * example_size, dtype=tf.float32)
140140
x_batch = tf.reshape(example_values, [batch_size] + input_dims)
@@ -147,7 +147,9 @@ def test_op(x_batch):
147147
common_test_utils.assert_replica_values_are_close(self, true_norms)
148148
computed_norms = computed_norms.values[0]
149149
true_norms = true_norms.values[0]
150-
self.assertEqual(tf.shape(computed_norms)[0], batch_size)
150+
self.assertEqual(
151+
tf.shape(computed_norms)[0], num_microbatches or batch_size
152+
)
151153
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
152154

153155

0 commit comments

Comments (0)