Skip to content

Commit e73c4cc

Browse files
committed
Handle compatibility issue & use cloudpickle
1 parent f4865cc commit e73c4cc

File tree

3 files changed

+84
-47
lines changed

3 files changed

+84
-47
lines changed

conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import pytest # noqa: E402
1616

1717
from keras.src.backend import backend # noqa: E402
18+
from keras.src.saving.object_registration import ( # noqa: E402
19+
get_custom_objects,
20+
)
1821

1922

2023
def pytest_configure(config):
@@ -32,3 +35,9 @@ def pytest_collection_modifyitems(config, items):
3235
for item in items:
3336
if "requires_trainable_backend" in item.keywords:
3437
item.add_marker(requires_trainable_backend)
38+
39+
40+
# Ensure each test is run in isolation regarding the custom objects dict
41+
@pytest.fixture(autouse=True)
42+
def reset_custom_objects_global_dictionary(request):
43+
get_custom_objects().clear()

keras/src/models/model_test.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pickle
2+
import sys
23

4+
import cloudpickle
35
import numpy as np
46
import pytest
57
from absl.testing import parameterized
@@ -14,26 +16,6 @@
1416
from keras.src.saving.object_registration import register_keras_serializable
1517

1618

17-
@pytest.fixture
18-
def my_custom_dense():
19-
@register_keras_serializable(package="MyLayers", name="CustomDense")
20-
class CustomDense(layers.Layer):
21-
def __init__(self, units, **kwargs):
22-
super().__init__(**kwargs)
23-
self.units = units
24-
self.dense = layers.Dense(units)
25-
26-
def call(self, x):
27-
return self.dense(x)
28-
29-
def get_config(self):
30-
config = super().get_config()
31-
config.update({"units": self.units})
32-
return config
33-
34-
return CustomDense
35-
36-
3719
def _get_model():
3820
input_a = Input(shape=(3,), batch_size=2, name="input_a")
3921
input_b = Input(shape=(3,), batch_size=2, name="input_b")
@@ -89,11 +71,15 @@ def _get_model_multi_outputs_dict():
8971
return model
9072

9173

92-
def _get_model_custom_layer():
93-
x = Input(shape=(3,), name="input_a")
94-
output_a = my_custom_dense()(10, name="output_a")(x)
95-
model = Model(x, output_a)
96-
return model
74+
@pytest.fixture
75+
def fake_main_module(request, monkeypatch):
76+
original_main = sys.modules["__main__"]
77+
78+
def restore_main_module():
79+
sys.modules["__main__"] = original_main
80+
81+
request.addfinalizer(restore_main_module)
82+
sys.modules["__main__"] = sys.modules[__name__]
9783

9884

9985
@pytest.mark.requires_trainable_backend
@@ -155,7 +141,6 @@ def call(self, x):
155141
("single_list_output_2", _get_model_single_output_list),
156142
("single_list_output_3", _get_model_single_output_list),
157143
("single_list_output_4", _get_model_single_output_list),
158-
("custom_layer", _get_model_custom_layer),
159144
)
160145
def test_functional_pickling(self, model_fn):
161146
model = model_fn()
@@ -170,6 +155,45 @@ def test_functional_pickling(self, model_fn):
170155

171156
self.assertAllClose(np.array(pred_reloaded), np.array(pred))
172157

158+
# Fake the __main__ module because cloudpickle only serializes
159+
# functions & classes if they are defined in the __main__ module.
160+
@pytest.mark.usefixtures("fake_main_module")
161+
def test_functional_pickling_custom_layer(self):
162+
@register_keras_serializable()
163+
class CustomDense(layers.Layer):
164+
def __init__(self, units, **kwargs):
165+
super().__init__(**kwargs)
166+
self.units = units
167+
self.dense = layers.Dense(units)
168+
169+
def call(self, x):
170+
return self.dense(x)
171+
172+
def get_config(self):
173+
config = super().get_config()
174+
config.update({"units": self.units})
175+
return config
176+
177+
x = Input(shape=(3,), name="input_a")
178+
output_a = CustomDense(10, name="output_a")(x)
179+
model = Model(x, output_a)
180+
181+
self.assertIsInstance(model, Functional)
182+
model.compile()
183+
x = np.random.rand(8, 3)
184+
185+
dumped_pickle = cloudpickle.dumps(model)
186+
187+
# Verify that we can load the dumped pickle even if the custom object
188+
# is not available in the loading environment.
189+
del CustomDense
190+
reloaded_pickle = cloudpickle.loads(dumped_pickle)
191+
192+
pred_reloaded = reloaded_pickle.predict(x)
193+
pred = model.predict(x)
194+
195+
self.assertAllClose(np.array(pred_reloaded), np.array(pred))
196+
173197
@parameterized.named_parameters(
174198
("single_output_1", _get_model_single_output, None),
175199
("single_output_2", _get_model_single_output, "list"),

keras/src/saving/keras_saveable.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import io
2-
import pickle
32

43
from keras.src.saving.object_registration import get_custom_objects
54

@@ -17,28 +16,24 @@ def _obj_type(self):
1716
)
1817

1918
@classmethod
20-
def _unpickle_model(cls, model_buf, *args):
19+
def _unpickle_model(cls, data):
2120
import keras.src.saving.saving_lib as saving_lib
2221

2322
# pickle is not safe regardless of what you do.
2423

25-
if len(args) == 0:
26-
return saving_lib._load_model_from_fileobj(
27-
model_buf,
28-
custom_objects=None,
29-
compile=True,
30-
safe_mode=False,
31-
)
24+
if "custom_objects_buf" in data.keys():
25+
import pickle
3226

27+
custom_objects = pickle.load(data["custom_objects_buf"])
3328
else:
34-
custom_objects_buf = args[0]
35-
custom_objects = pickle.load(custom_objects_buf)
36-
return saving_lib._load_model_from_fileobj(
37-
model_buf,
38-
custom_objects=custom_objects,
39-
compile=True,
40-
safe_mode=False,
41-
)
29+
custom_objects = None
30+
31+
return saving_lib._load_model_from_fileobj(
32+
data["model_buf"],
33+
custom_objects=custom_objects,
34+
compile=True,
35+
safe_mode=False,
36+
)
4237

4338
def __reduce__(self):
4439
"""__reduce__ is used to customize the behavior of `pickle.pickle()`.
@@ -48,14 +43,23 @@ def __reduce__(self):
4843
keras saving library."""
4944
import keras.src.saving.saving_lib as saving_lib
5045

46+
data = {}
47+
5148
model_buf = io.BytesIO()
5249
saving_lib._save_model_to_fileobj(self, model_buf, "h5")
50+
data["model_buf"] = model_buf
51+
52+
try:
53+
import cloudpickle
5354

54-
custom_objects_buf = io.BytesIO()
55-
pickle.dump(get_custom_objects(), custom_objects_buf)
56-
custom_objects_buf.seek(0)
55+
custom_objects_buf = io.BytesIO()
56+
cloudpickle.dump(get_custom_objects(), custom_objects_buf)
57+
custom_objects_buf.seek(0)
58+
data["custom_objects_buf"] = custom_objects_buf
59+
except ImportError:
60+
pass
5761

5862
return (
5963
self._unpickle_model,
60-
(model_buf, custom_objects_buf),
64+
(data,),
6165
)

0 commit comments

Comments
 (0)