Skip to content

Commit 4f72d68

Browse files
hertschuhtensorflower-gardener
authored andcommitted
Add support for autocast keyword argument in Layer.add_weight.
This feature was already supported with the `experimental_autocast` argument. This change simply adds an alias for the same argument to have the same API in Keras 2 and Keras 3. PiperOrigin-RevId: 662965658
1 parent f93f30e commit 4f72d68

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

Diff for: tf_keras/engine/base_layer.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,8 @@ def add_weight(
578578
Accepted values are constants defined in the class
579579
`tf.VariableAggregation`.
580580
**kwargs: Additional keyword arguments. Accepted values are `getter`,
581-
`collections`, `experimental_autocast` and `caching_device`.
581+
`collections`, `autocast`, `experimental_autocast` and
582+
`caching_device`.
582583
583584
Returns:
584585
The variable created.
@@ -594,6 +595,7 @@ def add_weight(
594595
# Validate optional keyword arguments.
595596
for kwarg in kwargs:
596597
if kwarg not in [
598+
"autocast",
597599
"collections",
598600
"experimental_autocast",
599601
"caching_device",
@@ -603,9 +605,13 @@ def add_weight(
603605
]:
604606
raise TypeError("Unknown keyword argument:", kwarg)
605607
collections_arg = kwargs.pop("collections", None)
606-
# 'experimental_autocast' can be set to False by the caller to indicate
607-
# an AutoCastVariable should never be created.
608-
autocast = kwargs.pop("experimental_autocast", True)
608+
# 'autocast' or 'experimental_autocast' can be set to False by the
609+
# caller to indicate an AutoCastVariable should never be created.
610+
autocast = kwargs.pop("autocast", None)
611+
if autocast is None:
612+
autocast = kwargs.pop("experimental_autocast", None)
613+
if autocast is None:
614+
autocast = True
609615
# See the docstring for tf.Variable about the details for
610616
# caching_device.
611617
caching_device = kwargs.pop("caching_device", None)

Diff for: tf_keras/engine/base_layer_v1.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ def add_weight(
352352
Accepted values are constants defined in the class
353353
`tf.VariableAggregation`.
354354
**kwargs: Additional keyword arguments. Accepted values are `getter`,
355-
`collections`, `experimental_autocast` and `caching_device`.
355+
`collections`, `autocast`, `experimental_autocast` and
356+
`caching_device`.
356357
357358
Returns:
358359
The created variable. Usually either a `Variable` or
@@ -371,6 +372,7 @@ def add_weight(
371372
# Validate optional keyword arguments.
372373
for kwarg in kwargs:
373374
if kwarg not in [
375+
"autocast",
374376
"getter",
375377
"collections",
376378
"experimental_autocast",
@@ -380,9 +382,13 @@ def add_weight(
380382
has_custom_getter = "getter" in kwargs
381383
getter = kwargs.pop("getter", base_layer_utils.make_variable)
382384
collections_arg = kwargs.pop("collections", None)
383-
# 'experimental_autocast' can be set to False by the caller to indicate
384-
# an AutoCastVariable should never be created.
385-
autocast = kwargs.pop("experimental_autocast", True)
385+
# 'autocast' or 'experimental_autocast' can be set to False by the
386+
# caller to indicate an AutoCastVariable should never be created.
387+
autocast = kwargs.pop("autocast", None)
388+
if autocast is None:
389+
autocast = kwargs.pop("experimental_autocast", None)
390+
if autocast is None:
391+
autocast = True
386392
# See the docstring for tf.Variable about the details for
387393
# caching_device.
388394
caching_device = kwargs.pop("caching_device", None)

Diff for: tf_keras/mixed_precision/test_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def build(self, _):
214214
(),
215215
initializer="ones",
216216
dtype=dtype,
217-
experimental_autocast=False,
217+
autocast=False,
218218
regularizer=self._regularizer,
219219
)
220220
self.built = True

0 commit comments

Comments
 (0)