@@ -533,26 +533,97 @@ def test_metadata(self):
533
533
self .assertIn ("keras_version" , metadata )
534
534
self .assertIn ("date_saved" , metadata )
535
535
536
- def test_gfile_copy_called (self ):
537
- temp_filepath = Path (
538
- os .path .join (self .get_temp_dir (), "my_model.keras" )
536
+ def test_save_keras_gfile_copy_called (self ):
537
+ path = Path (os .path .join (self .get_temp_dir (), "my_model.keras" ))
538
+ model = keras .Sequential (
539
+ [
540
+ keras .Input (shape = (1 , 1 )),
541
+ keras .layers .Dense (4 ),
542
+ ]
539
543
)
540
- model = CompileOverridingModel ()
541
544
with mock .patch (
542
545
"re.match" , autospec = True
543
546
) as mock_re_match , mock .patch .object (
544
547
tf .io .gfile , "copy"
545
548
) as mock_gfile_copy :
546
549
# Check regex matching
547
550
mock_re_match .return_value = True
548
- model .save (temp_filepath , save_format = "keras_v3" )
551
+ model .save (path , save_format = "keras_v3" )
549
552
mock_re_match .assert_called ()
550
- self .assertIn (str (temp_filepath ), mock_re_match .call_args .args )
553
+ self .assertIn (str (path ), mock_re_match .call_args .args )
551
554
552
555
# Check gfile copied with filepath specified as destination
553
- self .assertEqual (
554
- str (temp_filepath ), str (mock_gfile_copy .call_args .args [1 ])
555
- )
556
+ mock_gfile_copy .assert_called ()
557
+ self .assertEqual (str (path ), str (mock_gfile_copy .call_args .args [1 ]))
558
+
559
+ def test_save_tf_gfile_copy_not_called (self ):
560
+ path = Path (os .path .join (self .get_temp_dir (), "my_model.keras" ))
561
+ model = keras .Sequential (
562
+ [
563
+ keras .Input (shape = (1 , 1 )),
564
+ keras .layers .Dense (4 ),
565
+ ]
566
+ )
567
+ with mock .patch (
568
+ "re.match" , autospec = True
569
+ ) as mock_re_match , mock .patch .object (
570
+ tf .io .gfile , "copy"
571
+ ) as mock_gfile_copy :
572
+ # Check regex matching
573
+ mock_re_match .return_value = True
574
+ model .save (path , save_format = "tf" )
575
+ mock_re_match .assert_called ()
576
+ self .assertIn (str (path ), mock_re_match .call_args .args )
577
+
578
+ # Check gfile.copy was not used.
579
+ mock_gfile_copy .assert_not_called ()
580
+
581
+ def test_save_weights_h5_gfile_copy_called (self ):
582
+ path = Path (os .path .join (self .get_temp_dir (), "my_model.weights.h5" ))
583
+ model = keras .Sequential (
584
+ [
585
+ keras .Input (shape = (1 , 1 )),
586
+ keras .layers .Dense (4 ),
587
+ ]
588
+ )
589
+ model (tf .constant ([[1.0 ]]))
590
+ with mock .patch (
591
+ "re.match" , autospec = True
592
+ ) as mock_re_match , mock .patch .object (
593
+ tf .io .gfile , "copy"
594
+ ) as mock_gfile_copy :
595
+ # Check regex matching
596
+ mock_re_match .return_value = True
597
+ model .save_weights (path )
598
+ mock_re_match .assert_called ()
599
+ self .assertIn (str (path ), mock_re_match .call_args .args )
600
+
601
+ # Check gfile copied with filepath specified as destination
602
+ mock_gfile_copy .assert_called ()
603
+ self .assertEqual (str (path ), str (mock_gfile_copy .call_args .args [1 ]))
604
+
605
+ def test_save_weights_tf_gfile_copy_not_called (self ):
606
+ path = Path (os .path .join (self .get_temp_dir (), "my_model.ckpt" ))
607
+ model = keras .Sequential (
608
+ [
609
+ keras .Input (shape = (1 , 1 )),
610
+ keras .layers .Dense (4 ),
611
+ ]
612
+ )
613
+ model (tf .constant ([[1.0 ]]))
614
+ with mock .patch (
615
+ "re.match" , autospec = True
616
+ ) as mock_re_match , mock .patch .object (
617
+ tf .io .gfile , "copy"
618
+ ) as mock_gfile_copy :
619
+ # Check regex matching
620
+ mock_re_match .return_value = True
621
+ model .save_weights (path )
622
+ mock_re_match .assert_called ()
623
+ self .assertIn (str (path ), mock_re_match .call_args .args )
624
+
625
+ # Check gfile.copy was not used.
626
+ mock_gfile_copy .assert_not_called ()
556
627
557
628
def test_load_model_api_endpoint (self ):
558
629
temp_filepath = Path (os .path .join (self .get_temp_dir (), "mymodel.keras" ))
0 commit comments