Commit c92610e

Implement and test a registry function for tf.keras.layers.LayerNormalization.
PiperOrigin-RevId: 561423397
1 parent 372c934 · commit c92610e

File tree: 7 files changed, +314 −9 lines

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD

+22 −0

@@ -48,3 +48,25 @@ py_test(
         "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
     ],
 )
+
+py_library(
+    name = "layer_normalization",
+    srcs = ["layer_normalization.py"],
+    srcs_version = "PY3",
+    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+)
+
+py_test(
+    name = "layer_normalization_test",
+    srcs = ["layer_normalization_test.py"],
+    python_version = "PY3",
+    shard_count = 8,
+    srcs_version = "PY3",
+    deps = [
+        ":dense",
+        ":layer_normalization",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+    ],
+)
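Note: the new `py_test` target sets `shard_count = 8`, which tells the Bazel test runner to split the test's cases across eight shards that run in parallel; this keeps the wall-clock time of the large `parameterized.product` sweep in `layer_normalization_test.py` manageable.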

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py

+3 −3

@@ -13,7 +13,7 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.Dense`."""
 
-from typing import Any, Dict, Optional, Text, Tuple
+from typing import Any, Mapping, Tuple, Union
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -22,9 +22,9 @@
 def dense_layer_computation(
     layer_instance: tf.keras.layers.Dense,
     input_args: Tuple[Any, ...],
-    input_kwargs: Dict[Text, Any],
+    input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Optional[tf.Tensor] = None,
+    num_microbatches: Union[tf.Tensor, None] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Dense`.
Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py

+3 −3

@@ -13,17 +13,17 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.Embedding`."""
 
-from typing import Any, Dict, Optional, Text, Tuple
+from typing import Any, Mapping, Tuple, Union
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 
 
 def embedding_layer_computation(
     layer_instance: tf.keras.layers.Embedding,
     input_args: Tuple[Any, ...],
-    input_kwargs: Dict[Text, Any],
+    input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Optional[tf.Tensor] = None,
+    num_microbatches: Union[tf.Tensor, None] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Embedding`.
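The `dense.py` and `embedding.py` edits are the same typing cleanup: `Text` is a Python 2 compatibility alias for `str`, `Mapping[str, Any]` accepts read-only mapping views in addition to plain dicts, and `Union[tf.Tensor, None]` is semantically identical to the old `Optional[tf.Tensor]`. A minimal sketch of what the looser `input_kwargs` annotation now admits — the layer, batch, and standalone call pattern below are illustrative, not part of this commit:

```python
import types

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense

layer = tf.keras.layers.Dense(4)
x_batch = tf.random.normal([8, 3])

with tf.GradientTape() as tape:
  # Any Mapping now satisfies the annotation, e.g. an immutable proxy.
  base_vars, outputs, sqr_norm_fn = dense.dense_layer_computation(
      layer_instance=layer,
      input_args=(x_batch,),
      input_kwargs=types.MappingProxyType({}),
      tape=tape,
      num_microbatches=None,  # Union[tf.Tensor, None]
  )
```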
Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py

+89 −0

@@ -0,0 +1,89 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast clipping function for `tf.keras.layers.LayerNormalization`."""
+
+from typing import Any, Mapping, Tuple, Union
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+
+
+# ==============================================================================
+# Supported Keras layers
+# ==============================================================================
+def _sqr_norm_fn(grads):
+  stacked_grads = tf.stack(grads, axis=-1)
+  reduction_axes = tf.range(1, tf.rank(stacked_grads))
+  return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
+
+
+def layer_normalization_computation(
+    layer_instance: tf.keras.layers.LayerNormalization,
+    input_args: Tuple[Any, ...],
+    input_kwargs: Mapping[str, Any],
+    tape: tf.GradientTape,
+    num_microbatches: Union[tf.Tensor, None] = None,
+) -> type_aliases.RegistryFunctionOutput:
+  """Registry function for `tf.keras.layers.LayerNormalization`.
+
+  This function computes the actual per-example gradients and their norms
+  directly, instead of employing a chain-rule trick. This is done using some
+  slick reshaping calls.
+
+  Args:
+    layer_instance: A `tf.keras.layers.LayerNormalization` instance.
+    input_args: See `dense_layer_computation()` in `dense.py`.
+    input_kwargs: See `dense_layer_computation()` in `dense.py`.
+    tape: See `dense_layer_computation()` in `dense.py`.
+    num_microbatches: See `dense_layer_computation()` in `dense.py`.
+
+  Returns:
+    See `dense_layer_computation()` in `dense.py`.
+  """
+  del input_kwargs  # Unused in layer normalization calls.
+  if num_microbatches is not None:
+    raise NotImplementedError("Microbatching is not currently supported.")
+
+  # To make sure the watched variables (beta, gamma) generate per-example
+  # gradients, we expand the trainable variables from shape [S] to
+  # [batch_size, S] via duplication, then broadcast them against the shape
+  # of `inputs`.
+  inputs = input_args[0]
+  base_vars = []
+  batch_size = tf.shape(inputs)[0]
+
+  def process_variable(var):
+    """Expands a univariate `var`; appends the expanded tensor to `base_vars`."""
+    expanded_var = tf.tile(
+        tf.expand_dims(var, axis=0), [batch_size] + [1] * len(var.shape)
+    )
+    tape.watch(expanded_var)
+    base_vars.append(expanded_var)
+    broadcast_shape = [1] * len(inputs.shape)
+    broadcast_shape[0] = batch_size
+    for d in layer_instance.axis:
+      broadcast_shape[d] = tf.shape(inputs)[d]
+    final_var = tf.reshape(expanded_var, broadcast_shape)
+    return final_var
+
+  orig_gamma = layer_instance.gamma
+  orig_beta = layer_instance.beta
+  layer_instance.gamma = process_variable(orig_gamma)
+  layer_instance.beta = process_variable(orig_beta)
+
+  # Do the computation, ensure that the output conforms to the unexpanded
+  # computation, and restore the state of the original instance.
+  outputs = layer_instance.call(inputs)
+  layer_instance.gamma = orig_gamma
+  layer_instance.beta = orig_beta
+
+  return base_vars, outputs, _sqr_norm_fn
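For orientation, here is a minimal sketch (not part of the commit) of how this registry function is meant to be driven. It tiles `gamma` and `beta` per example, runs the layer under the caller's tape, and the returned `_sqr_norm_fn` then turns per-example gradients of the tiled variables into per-example squared norms. The sum-of-outputs loss below is purely illustrative; in the library the function is invoked by `clip_grads`:

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization

layer = tf.keras.layers.LayerNormalization(axis=-1)
x_batch = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4])
layer.build(x_batch.shape)  # Creates `gamma` and `beta` before the call.

with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = (
      layer_normalization.layer_normalization_computation(
          layer_instance=layer,
          input_args=(x_batch,),
          input_kwargs={},
          tape=tape,
      )
  )
  # Reduce each example's output to a scalar (illustrative loss).
  per_example_loss = tf.reduce_sum(outputs, axis=[1, 2])

# Each entry of `grads` has shape [batch_size, 4]: row i holds the gradient
# contributed by example i alone, because the tiled variables are per example.
grads = tape.gradient(per_example_loss, base_vars)
sqr_norms = sqr_norm_fn(grads)  # Shape [2]: per-example squared grad norms.
```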
Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py

+159 −0

@@ -0,0 +1,159 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization
+
+
+# ==============================================================================
+# Helper functions.
+# ==============================================================================
+def get_layer_norm_layer_generators():
+  return {
+      'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
+  }
+
+
+def get_layer_norm_model_generators():
+  return {
+      # TODO(b/274483956): Test more complex models once we can support
+      # `nD` inputs for `tf.keras.layers.Dense`.
+      'func1': common_test_utils.make_one_layer_functional_model,
+  }
+
+
+def get_layer_norm_parameter_tuples():
+  """Consists of (input_dims, parameter_axes)."""
+  return [
+      # Rank-2
+      ([3], -1),
+      ([3], [1]),
+      # Rank-3
+      ([3, 4], -1),
+      ([3, 4], [1]),
+      ([3, 4], [2]),
+      ([3, 4], [1, 2]),
+      # Rank-4
+      ([3, 4, 5], -1),
+      ([3, 4, 5], [1]),
+      ([3, 4, 5], [2]),
+      ([3, 4, 5], [3]),
+      ([3, 4, 5], [1, 2]),
+      ([3, 4, 5], [1, 3]),
+      ([3, 4, 5], [2, 3]),
+      ([3, 4, 5], [1, 2, 3]),
+  ]
+
+
+def get_layer_norm_registries():
+  ln_registry = layer_registry.LayerRegistry()
+  ln_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
+  ln_registry.insert(
+      tf.keras.layers.LayerNormalization,
+      layer_normalization.layer_normalization_computation,
+  )
+  return {
+      'layer_norm_only': ln_registry,
+  }
+
+
+# ==============================================================================
+# Main tests.
+# ==============================================================================
+class GradNormTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = tf.distribute.get_strategy()
+
+  @parameterized.product(
+      model_name=list(get_layer_norm_model_generators().keys()),
+      layer_name=list(get_layer_norm_layer_generators().keys()),
+      parameter_tuple=get_layer_norm_parameter_tuples(),
+      layer_registry_name=list(get_layer_norm_registries().keys()),
+      is_eager=[True, False],
+  )
+  def test_gradient_norms_on_various_models(
+      self,
+      model_name,
+      layer_name,
+      parameter_tuple,
+      layer_registry_name,
+      is_eager,
+  ):
+    # Parse inputs to generate test data.
+    input_dims, parameter_axes = parameter_tuple
+
+    def curried_generator(a, b):
+      del a, b  # Unused by the generator.
+      layer_norm_generator = get_layer_norm_layer_generators()[layer_name]
+      return layer_norm_generator(parameter_axes)
+
+    # Load shared assets to all devices.
+    with self.strategy.scope():
+      dummy_output_dim = 1
+      model = common_test_utils.get_model_from_generator(
+          model_generator=get_layer_norm_model_generators()[model_name],
+          layer_generator=curried_generator,
+          input_dims=input_dims,
+          output_dims=[dummy_output_dim],
+          is_eager=is_eager,
+      )
+
+    # Define the main testing ops. These may later be compiled to a Graph op.
+    def test_op(x_batch):
+      return common_test_utils.get_computed_and_true_norms_from_model(
+          model=model,
+          per_example_loss_fn=None,
+          num_microbatches=None,
+          x_batch=[x_batch, x_batch] if model_name == 'tower2' else x_batch,
+          weight_batch=None,
+          registry=get_layer_norm_registries()[layer_registry_name],
+      )
+
+    # TPUs can only run `tf.function`-decorated functions.
+    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
+    if using_tpu:
+      test_op = tf.function(test_op, jit_compile=True, autograph=False)
+
+    # TPUs use lower precision than CPUs, so we relax our criterion (see
+    # `dense_test.py` for additional discussion).
+    rtol = 1e-2 if using_tpu else 1e-3
+    atol = 1e-1 if using_tpu else 1e-2
+
+    # Each batched input is a reshape of a `tf.range()` call.
+    batch_size = 2
+    example_size = np.prod(input_dims)
+    example_values = tf.range(batch_size * example_size, dtype=tf.float32)
+    x_batch = tf.reshape(example_values, [batch_size] + input_dims)
+    batch_size = x_batch.shape[0]
+
+    # Set up the device ops and run the test.
+    computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
+    # TPUs return replica contexts, which must be unwrapped.
+    if using_tpu:
+      common_test_utils.assert_replica_values_are_close(self, computed_norms)
+      common_test_utils.assert_replica_values_are_close(self, true_norms)
+      computed_norms = computed_norms.values[0]
+      true_norms = true_norms.values[0]
+    self.assertEqual(tf.shape(computed_norms)[0], batch_size)
+    self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
+
+
+if __name__ == '__main__':
+  tf.test.main()
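The `true_norms` baseline above comes from `common_test_utils`, whose internals are not part of this diff. Conceptually it is the brute-force computation that fast clipping must reproduce: run each example through the model on its own tape and take the norm of the full gradient. A rough sketch of that idea, assuming a built Keras `model` and a hypothetical sum-of-squares per-example loss:

```python
import tensorflow as tf


def naive_per_example_grad_norms(model, x_batch):
  """Brute-force reference: one GradientTape per example."""
  norms = []
  for i in range(x_batch.shape[0]):
    with tf.GradientTape() as tape:
      y = model(x_batch[i : i + 1])
      loss = tf.reduce_sum(tf.square(y))  # Hypothetical per-example loss.
    grads = tape.gradient(loss, model.trainable_variables)
    flat = tf.concat([tf.reshape(g, [-1]) for g in grads], axis=0)
    norms.append(tf.norm(flat))
  return tf.stack(norms)
```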
Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py

+29 −0

@@ -0,0 +1,29 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test
+
+
+class GradNormTpuTest(layer_normalization_test.GradNormTest):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = ctu.create_tpu_strategy()
+    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+
+
+if __name__ == '__main__':
+  tf.test.main()

Diff for: tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py

+9 −3

@@ -13,12 +13,12 @@
 # limitations under the License.
 """A collection of type aliases used throughout the clipping library."""
 
-from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
 import tensorflow as tf
 
 
 # Tensorflow aliases.
-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
+PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]]
 
 InputTensors = PackedTensors
 
@@ -34,7 +34,13 @@
 RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
 
 RegistryFunction = Callable[
-    [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
+    [
+        Any,
+        Tuple[Any, ...],
+        Mapping[str, Any],
+        tf.GradientTape,
+        Union[tf.Tensor, None],
+    ],
     RegistryFunctionOutput,
 ]
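The widened `RegistryFunction` alias now spells out all five positional arguments, including the trailing `num_microbatches` tensor, so custom registry functions are checked against the full calling convention. A sketch of a conforming function (hypothetical, not in the library):

```python
from typing import Any, Mapping, Tuple, Union

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


def my_layer_computation(  # Hypothetical; shows the expected signature only.
    layer_instance: Any,
    input_args: Tuple[Any, ...],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
  raise NotImplementedError("Sketch only.")


my_fn: type_aliases.RegistryFunction = my_layer_computation
```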
