Skip to content

Commit bc9060c

Browse files
Added Load method for orbax
1 parent f0a48a6 commit bc9060c

File tree

3 files changed

+349
-46
lines changed

3 files changed

+349
-46
lines changed

keras/src/callbacks/orbax_checkpoint_test.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,3 +577,152 @@ def compare_nested_dicts(orig_dict, loaded_dict):
577577
original_state_tree["metrics_variables"],
578578
loaded_state_tree["metrics_variables"],
579579
)
580+
581+
@pytest.mark.requires_trainable_backend
582+
def _flatten_nested_dict(self, nested_dict):
583+
"""Flatten a nested dictionary into a flat dictionary with path keys."""
584+
flat_dict = {}
585+
586+
def _flatten(current_dict, prefix=""):
587+
for key, value in current_dict.items():
588+
if isinstance(value, dict):
589+
_flatten(value, f"{prefix}{key}/")
590+
else:
591+
flat_dict[f"{prefix}{key}"] = value
592+
593+
_flatten(nested_dict)
594+
return flat_dict
595+
596+
@pytest.mark.requires_trainable_backend
def test_model_load_method(self):
    """Test the Model.load() method for loading Orbax checkpoints."""
    # Exercise both synchronous and asynchronous checkpoint saving.
    for background in (False, True):
        self._test_model_load_with_saving_mode(save_on_background=background)
602+
603+
def _test_model_load_with_saving_mode(self, save_on_background):
    """Helper to exercise Model.load() under one Orbax saving mode.

    Trains a model with an OrbaxCheckpoint callback, loads the latest
    checkpoint into a freshly initialized model, and asserts that the
    model weights, optimizer variables, and metrics variables all
    round-trip through the checkpoint.

    Args:
        save_on_background: Whether the checkpoint callback saves
            asynchronously; when True we must wait for pending saves to
            finish before reading the checkpoint back.
    """

    def to_numpy(value):
        # Normalize backend-specific tensor types to numpy for comparison.
        if hasattr(value, "detach"):  # PyTorch tensor
            return value.detach().cpu().numpy()
        if hasattr(value, "numpy"):  # TF variable/tensor
            return value.numpy()
        return value  # already a numpy array

    def assert_state_trees_match(trained_tree, loaded_tree, label):
        # Flatten both nested state dicts and compare entry-by-entry.
        trained_flat = self._flatten_nested_dict(trained_tree)
        loaded_flat = self._flatten_nested_dict(loaded_tree)
        self.assertEqual(
            set(trained_flat.keys()),
            set(loaded_flat.keys()),
            f"{label} variable keys should match",
        )
        for key in trained_flat:
            self.assertTrue(
                np.allclose(
                    to_numpy(trained_flat[key]), to_numpy(loaded_flat[key])
                ),
                f"{label} variable {key} should match",
            )

    model = self._create_test_model()
    x, y = self._create_dummy_data()

    checkpoint_dir = os.path.join(
        self.get_temp_dir(),
        f"test_model_load_{'async' if save_on_background else 'sync'}",
    )
    callback = OrbaxCheckpoint(
        directory=checkpoint_dir,
        save_freq="epoch",
        save_on_background=save_on_background,
    )

    # Train for a few epochs to create checkpoints.
    model.fit(x, y, epochs=3, callbacks=[callback], verbose=0)

    # Wait for async operations to complete if using async saving.
    if save_on_background:
        callback.wait_until_finished()

    # Get the state of the trained model.
    trained_state = model.get_state_tree()

    # Create a new model with the same architecture.
    new_model = self._create_test_model()
    original_weights = new_model.get_weights()

    # Test loading the latest checkpoint.
    new_model.load(checkpoint_dir)
    loaded_weights = new_model.get_weights()
    loaded_state = new_model.get_state_tree()

    # Weights should differ after loading (random init -> trained).
    weights_changed = any(
        not np.allclose(orig, loaded)
        for orig, loaded in zip(original_weights, loaded_weights)
    )
    self.assertTrue(
        weights_changed, "Weights should change after loading checkpoint"
    )

    # Verify that loaded weights match the trained model's weights.
    trained_weights = model.get_weights()
    for trained_w, loaded_w in zip(trained_weights, loaded_weights):
        self.assertTrue(
            np.allclose(trained_w, loaded_w),
            "Loaded weights should match trained model's weights",
        )

    # Verify that optimizer and metrics state were loaded.
    assert_state_trees_match(
        trained_state["optimizer_variables"],
        loaded_state["optimizer_variables"],
        "Optimizer",
    )
    assert_state_trees_match(
        trained_state["metrics_variables"],
        loaded_state["metrics_variables"],
        "Metrics",
    )

0 commit comments

Comments
 (0)