Skip to content

Commit 80bad7a

Browse files
authored
Fix TextVectorization tf-idf mode deserialization failure (#22330)
* fix(saving): fix TextVectorization tf-idf mode deserialization
* test(saving): skip tf-idf save/load test on non-tensorflow backends
1 parent 6671dff commit 80bad7a

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

keras/src/layers/preprocessing/index_lookup.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -861,7 +861,11 @@ def save_own_variables(self, store):
861861

862862
def load_own_variables(self, store):
863863
if self.output_mode == "tf_idf":
864-
self.idf_weights.assign(store["idf_weights"])
864+
idf_weights = store["idf_weights"]
865+
if hasattr(self, "idf_weights"):
866+
self.idf_weights.assign(idf_weights)
867+
else:
868+
self.idf_weights = tf.Variable(idf_weights, trainable=False)
865869
self.idf_weights_const = self.idf_weights.value()
866870

867871
def save_assets(self, dir_path):
@@ -889,7 +893,8 @@ def load_assets(self, dir_path):
889893
else:
890894
values = [int(line) for line in lines]
891895
if self.output_mode == "tf_idf":
892-
self.set_vocabulary(values, idf_weights=False)
896+
idf_weights = self.idf_weights_const.numpy()
897+
self.set_vocabulary(values, idf_weights=idf_weights)
893898
else:
894899
self.set_vocabulary(values)
895900

keras/src/layers/preprocessing/text_vectorization_test.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,24 @@ def test_save_load_with_ngrams_flow(self):
8585
model = saving.load_model(temp_filepath)
8686
self.assertAllClose(output, model(input_data))
8787

88+
@pytest.mark.skipif(
89+
backend.backend() != "tensorflow", reason="Requires string input dtype"
90+
)
91+
def test_save_load_tf_idf_mode(self):
92+
input_data = np.array(["foo bar", "bar baz", "baz bada boom"])
93+
model = Sequential(
94+
[
95+
layers.Input(dtype="string", shape=()),
96+
layers.TextVectorization(max_tokens=100, output_mode="tf_idf"),
97+
]
98+
)
99+
model.layers[0].adapt(input_data)
100+
output = model(input_data)
101+
temp_filepath = os.path.join(self.get_temp_dir(), "model.keras")
102+
model.save(temp_filepath)
103+
loaded_model = saving.load_model(temp_filepath)
104+
self.assertAllClose(output, loaded_model(input_data))
105+
88106
def test_tf_data_compatibility(self):
89107
max_tokens = 5000
90108
max_len = 4

0 commit comments

Comments
 (0)