@@ -577,3 +577,152 @@ def compare_nested_dicts(orig_dict, loaded_dict):
577577 original_state_tree ["metrics_variables" ],
578578 loaded_state_tree ["metrics_variables" ],
579579 )
580+
def _flatten_nested_dict(self, nested_dict):
    """Flatten a nested dictionary into a flat dictionary with path keys.

    Keys of nested dictionaries are joined with "/" so that each leaf value
    ends up under a single path-like key, e.g. `{"a": {"b": 1}}` becomes
    `{"a/b": 1}`.

    Args:
        nested_dict: Dictionary whose values may themselves be dictionaries
            (to any depth).

    Returns:
        A new flat dict mapping "/"-joined key paths to the leaf values.
        Empty nested dictionaries contribute no entries.
    """
    # This is a plain helper, not a collected test, so it deliberately
    # carries no pytest mark: a mark here would never be consulted.

    def _iter_leaves(current_dict, prefix=""):
        # Depth-first walk that yields (path, leaf) pairs in dict order.
        for key, value in current_dict.items():
            path = f"{prefix}{key}"
            if isinstance(value, dict):
                yield from _iter_leaves(value, f"{path}/")
            else:
                yield path, value

    return dict(_iter_leaves(nested_dict))
595+
@pytest.mark.requires_trainable_backend
def test_model_load_method(self):
    """Check that `Model.load()` can load Orbax checkpoints.

    Delegates to `_test_model_load_with_saving_mode` once per saving mode
    so that both synchronous and background saving are exercised.
    """
    # Synchronous saving first, then asynchronous saving.
    for save_on_background in (False, True):
        self._test_model_load_with_saving_mode(
            save_on_background=save_on_background
        )
602+
def _test_model_load_with_saving_mode(self, save_on_background):
    """Helper method to test Model.load() with different saving modes.

    Trains a model for three epochs with an `OrbaxCheckpoint` callback,
    loads the checkpoint directory into a freshly created model via
    `load()`, and asserts that the loaded weights, optimizer variables and
    metrics variables match those of the trained model.

    Args:
        save_on_background: Forwarded to `OrbaxCheckpoint`. When truthy,
            the helper calls `wait_until_finished()` on the callback
            before loading, so pending saves are complete.
    """

    def _to_numpy(value):
        # Normalize backend-specific values to something NumPy can compare.
        if hasattr(value, "detach"):  # PyTorch tensor
            return value.detach().cpu().numpy()
        if hasattr(value, "numpy"):  # TF variable
            return value.numpy()
        return value  # numpy array

    def _assert_state_section_matches(section, label):
        # Compare one section of the state tree (same keys, close values)
        # between the trained model and the model restored from disk.
        trained_flat = self._flatten_nested_dict(trained_state[section])
        loaded_flat = self._flatten_nested_dict(loaded_state[section])
        self.assertEqual(
            set(trained_flat.keys()),
            set(loaded_flat.keys()),
            f"{label} variable keys should match",
        )
        for key, trained_val in trained_flat.items():
            self.assertTrue(
                np.allclose(
                    _to_numpy(trained_val), _to_numpy(loaded_flat[key])
                ),
                f"{label} variable {key} should match",
            )

    model = self._create_test_model()
    x, y = self._create_dummy_data()

    # Separate directory per mode so the two runs never share checkpoints.
    mode = "async" if save_on_background else "sync"
    checkpoint_dir = os.path.join(
        self.get_temp_dir(), f"test_model_load_{mode}"
    )
    callback = OrbaxCheckpoint(
        directory=checkpoint_dir,
        save_freq="epoch",
        save_on_background=save_on_background,
    )

    # Train for a few epochs to create checkpoints
    model.fit(x, y, epochs=3, callbacks=[callback], verbose=0)

    # Wait for async operations to complete if using async saving
    if save_on_background:
        callback.wait_until_finished()

    # Get the state of the trained model
    trained_state = model.get_state_tree()

    # Create a new model with same architecture
    new_model = self._create_test_model()
    original_weights = new_model.get_weights()

    # Test loading the latest checkpoint
    new_model.load(checkpoint_dir)
    loaded_weights = new_model.get_weights()
    loaded_state = new_model.get_state_tree()

    # Weights should be different after loading
    # (from random init to trained)
    weights_changed = any(
        not np.allclose(orig, loaded)
        for orig, loaded in zip(original_weights, loaded_weights)
    )
    self.assertTrue(
        weights_changed, "Weights should change after loading checkpoint"
    )

    # Verify that loaded weights match the trained model's weights.
    # Check the count first: zip() would silently ignore extra entries.
    trained_weights = model.get_weights()
    self.assertEqual(
        len(trained_weights),
        len(loaded_weights),
        "Loaded weights should match trained model's weights",
    )
    for trained_w, loaded_w in zip(trained_weights, loaded_weights):
        self.assertTrue(
            np.allclose(trained_w, loaded_w),
            "Loaded weights should match trained model's weights",
        )

    # Verify that optimizer state was loaded
    _assert_state_section_matches("optimizer_variables", "Optimizer")

    # Verify that metrics state was loaded
    _assert_state_section_matches("metrics_variables", "Metrics")
0 commit comments