
Commit b80dd12

Add is_legacy_optimizer to optimizer config to keep saving/loading consistent. (#16856)

PiperOrigin-RevId: 463928027
Co-authored-by: Chen Qian <[email protected]>

Parent: 8cdcea7

4 files changed (+15 lines, −2 lines)
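For context on what the flag buys: `keras.optimizers.deserialize` has to choose between the legacy optimizer classes and the experimental ones, and before this change that choice came only from the caller's `use_legacy_optimizer` kwarg, so an experimental optimizer saved via `get_config` could silently come back as its legacy counterpart. Writing `is_legacy_optimizer` into the config makes the saved object self-describing. A minimal usage sketch, assuming a TF release contemporary with this commit (roughly TF 2.10), where the experimental classes live under `tf.keras.optimizers.experimental`:

```python
import tensorflow as tf

# Serialize an experimental-namespace optimizer; with this commit its config
# carries "is_legacy_optimizer": False.
opt = tf.keras.optimizers.experimental.Adam(learning_rate=1e-3)
config = tf.keras.optimizers.serialize(opt)

# deserialize() now reads the flag out of the config instead of relying on
# the caller's use_legacy_optimizer default (True), so the restored object
# is the experimental class again rather than the legacy one.
restored = tf.keras.optimizers.deserialize(config)
print(type(restored))  # experimental Adam, not the legacy tf.keras.optimizers.Adam
```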

keras/mixed_precision/loss_scale_optimizer.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -929,7 +929,9 @@ def from_config(cls, config, custom_objects=None):
         )
         config["inner_optimizer"] = config.pop("optimizer")
         inner_optimizer = optimizers.deserialize(
-            config["inner_optimizer"], custom_objects=custom_objects
+            config["inner_optimizer"],
+            custom_objects=custom_objects,
+            use_legacy_optimizer=True,
         )
         del config["inner_optimizer"]
         return cls(inner_optimizer, **config)
@@ -1366,7 +1368,9 @@ def get_config(self):
     def from_config(cls, config, custom_objects=None):
         config = config.copy()  # Make a copy, since we mutate config
         inner_optimizer = optimizers.deserialize(
-            config["inner_optimizer"], custom_objects=custom_objects
+            config["inner_optimizer"],
+            custom_objects=custom_objects,
+            use_legacy_optimizer=False,
        )
         del config["inner_optimizer"]
         return cls(inner_optimizer, **config)
```
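The two hunks are a matched pair of call sites: the first is in the legacy `LossScaleOptimizer.from_config`, whose inner optimizer is a legacy class, so it asks `deserialize` for the legacy path explicitly; the second appears to be the experimental-optimizer counterpart in the same file and asks for the new path. In both cases the kwarg is only a fallback, since a non-empty inner config's own `is_legacy_optimizer` takes precedence (see the `deserialize` hunk below). A hedged round-trip sketch, assuming the public `tf.keras.mixed_precision.LossScaleOptimizer` of that era and its `inner_optimizer` property:

```python
import tensorflow as tf

# Wrap a legacy optimizer (tf.keras.optimizers.SGD was still the legacy
# class in this era) in a loss-scale wrapper.
inner = tf.keras.optimizers.SGD(learning_rate=0.1)
lso = tf.keras.mixed_precision.LossScaleOptimizer(inner)

config = lso.get_config()
# from_config now passes use_legacy_optimizer=True explicitly, so even an
# inner config written before the flag existed deserializes back to the
# legacy class rather than whatever the module-level default picks.
restored = tf.keras.mixed_precision.LossScaleOptimizer.from_config(config)
print(type(restored.inner_optimizer).__name__)  # SGD (legacy)
```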

keras/optimizers/__init__.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -124,6 +124,12 @@ def deserialize(config, custom_objects=None, **kwargs):
         )

     use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", True)
+    if len(config["config"]) > 0:
+        # If the optimizer config is not empty, then we use the value of
+        # `is_legacy_optimizer` to override `use_legacy_optimizer`. If
+        # `is_legacy_optimizer` does not exist in config, it means we are
+        # using the legacy optimizer.
+        use_legacy_optimizer = config["config"].get("is_legacy_optimizer", True)
     if (
         tf.__internal__.tf2.enabled()
         and tf.executing_eagerly()
```
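The precedence rule is compact enough to misread, so here is a dependency-free sketch of just that logic (the function name is a stand-in, not the Keras implementation):

```python
def resolve_use_legacy(config, use_legacy_optimizer=True):
    """Mirror the precedence added above: a non-empty optimizer config wins.

    - Empty config: fall back to the caller's kwarg (default True).
    - Non-empty config with the flag: the flag decides.
    - Non-empty config without the flag: it predates the flag, so treat it
      as a legacy optimizer.
    """
    if len(config["config"]) > 0:
        use_legacy_optimizer = config["config"].get("is_legacy_optimizer", True)
    return use_legacy_optimizer


assert resolve_use_legacy({"config": {}}) is True
assert resolve_use_legacy({"config": {}}, use_legacy_optimizer=False) is False
assert resolve_use_legacy({"config": {"is_legacy_optimizer": False}}) is False
assert resolve_use_legacy({"config": {"learning_rate": 0.01}}) is True
```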

keras/optimizers/optimizer_experimental/optimizer.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -97,6 +97,7 @@ def _create_iteration_variable(self):
         )

     def _process_kwargs(self, kwargs):
+        kwargs.pop("is_legacy_optimizer", None)
         legacy_kwargs = {
             "lr",
             "decay",
@@ -619,6 +620,7 @@ def get_config(self):
             "ema_momentum": self.ema_momentum,
             "ema_overwrite_frequency": self.ema_overwrite_frequency,
             "jit_compile": self.jit_compile,
+            "is_legacy_optimizer": False,
         }
         return config
```
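These two hunks are also a pair: `get_config` writes the marker, and `_process_kwargs` drops it again before keyword-argument validation, since the constructor has no `is_legacy_optimizer` parameter. A simplified stand-in class (not the real Keras optimizer) showing why the pop is needed for a `cls(**config)`-style `from_config` to round-trip:

```python
class OptimizerSketch:
    def __init__(self, learning_rate=0.001, **kwargs):
        self._process_kwargs(kwargs)
        if kwargs:
            # Without the pop below, the marker written by get_config would
            # land here and fail the round trip.
            raise TypeError(f"Unexpected keyword arguments: {sorted(kwargs)}")
        self.learning_rate = learning_rate

    def _process_kwargs(self, kwargs):
        # Mirrors the commit: silently drop the marker key so configs
        # written by get_config pass constructor validation.
        kwargs.pop("is_legacy_optimizer", None)

    def get_config(self):
        return {"learning_rate": self.learning_rate, "is_legacy_optimizer": False}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


opt = OptimizerSketch.from_config(OptimizerSketch().get_config())
assert opt.learning_rate == 0.001
```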

keras/optimizers/optimizer_experimental/optimizer_test.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -297,6 +297,7 @@ def testGetAndFromConfig(self):
             "use_ema": True,
             "ema_momentum": 0.5,
             "ema_overwrite_frequency": 50,
+            "is_legacy_optimizer": False,
         }
         self.assertDictContainsSubset(expected_config, config)
         restored_optimizer = adam_new.Adam.from_config(config)
```
