File tree Expand file tree Collapse file tree 2 files changed +25
-2
lines changed
keras/src/layers/preprocessing Expand file tree Collapse file tree 2 files changed +25
-2
lines changed Original file line number Diff line number Diff line change @@ -861,7 +861,11 @@ def save_own_variables(self, store):
861861
862862 def load_own_variables (self , store ):
863863 if self .output_mode == "tf_idf" :
864- self .idf_weights .assign (store ["idf_weights" ])
864+ idf_weights = store ["idf_weights" ]
865+ if hasattr (self , "idf_weights" ):
866+ self .idf_weights .assign (idf_weights )
867+ else :
868+ self .idf_weights = tf .Variable (idf_weights , trainable = False )
865869 self .idf_weights_const = self .idf_weights .value ()
866870
867871 def save_assets (self , dir_path ):
@@ -889,7 +893,8 @@ def load_assets(self, dir_path):
889893 else :
890894 values = [int (line ) for line in lines ]
891895 if self .output_mode == "tf_idf" :
892- self .set_vocabulary (values , idf_weights = False )
896+ idf_weights = self .idf_weights_const .numpy ()
897+ self .set_vocabulary (values , idf_weights = idf_weights )
893898 else :
894899 self .set_vocabulary (values )
895900
Original file line number Diff line number Diff line change @@ -85,6 +85,24 @@ def test_save_load_with_ngrams_flow(self):
8585 model = saving .load_model (temp_filepath )
8686 self .assertAllClose (output , model (input_data ))
8787
88+ @pytest .mark .skipif (
89+ backend .backend () != "tensorflow" , reason = "Requires string input dtype"
90+ )
91+ def test_save_load_tf_idf_mode (self ):
92+ input_data = np .array (["foo bar" , "bar baz" , "baz bada boom" ])
93+ model = Sequential (
94+ [
95+ layers .Input (dtype = "string" , shape = ()),
96+ layers .TextVectorization (max_tokens = 100 , output_mode = "tf_idf" ),
97+ ]
98+ )
99+ model .layers [0 ].adapt (input_data )
100+ output = model (input_data )
101+ temp_filepath = os .path .join (self .get_temp_dir (), "model.keras" )
102+ model .save (temp_filepath )
103+ loaded_model = saving .load_model (temp_filepath )
104+ self .assertAllClose (output , loaded_model (input_data ))
105+
88106 def test_tf_data_compatibility (self ):
89107 max_tokens = 5000
90108 max_len = 4
You can’t perform that action at this time.
0 commit comments