Skip to content

Commit

Permalink
Handle compatibility issue & use cloudpickle
Browse files Browse the repository at this point in the history
  • Loading branch information
mthiboust committed Jun 18, 2024
1 parent f4865cc commit e73c4cc
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 47 deletions.
9 changes: 9 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import pytest # noqa: E402

from keras.src.backend import backend # noqa: E402
from keras.src.saving.object_registration import ( # noqa: E402
get_custom_objects,
)


def pytest_configure(config):
Expand All @@ -32,3 +35,9 @@ def pytest_collection_modifyitems(config, items):
for item in items:
if "requires_trainable_backend" in item.keywords:
item.add_marker(requires_trainable_backend)


# Ensure each test is run in isolation regarding the custom objects dict
@pytest.fixture(autouse=True)
def reset_custom_objects_global_dictionary(request):
    """Start every test with an empty custom-objects registry.

    The registry returned by `get_custom_objects()` is a process-global
    dict; clearing it before each test keeps registrations from one test
    from leaking into the next.
    """
    registry = get_custom_objects()
    registry.clear()
76 changes: 50 additions & 26 deletions keras/src/models/model_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pickle
import sys

import cloudpickle
import numpy as np
import pytest
from absl.testing import parameterized
Expand All @@ -14,26 +16,6 @@
from keras.src.saving.object_registration import register_keras_serializable


@pytest.fixture
def my_custom_dense():
    """Provide a freshly registered custom layer class for a test.

    The class is registered under ``MyLayers>CustomDense`` each time the
    fixture runs, so serialization round-trips can resolve it by name.
    """

    @register_keras_serializable(package="MyLayers", name="CustomDense")
    class CustomDense(layers.Layer):
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            # Wrap a plain Dense layer; `units` is kept for get_config().
            self.units = units
            self.dense = layers.Dense(units)

        def call(self, x):
            return self.dense(x)

        def get_config(self):
            base = super().get_config()
            base.update({"units": self.units})
            return base

    return CustomDense


def _get_model():
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
Expand Down Expand Up @@ -89,11 +71,15 @@ def _get_model_multi_outputs_dict():
return model


def _get_model_custom_layer():
    """Build a single-input functional model that uses a custom layer.

    The custom layer is defined and registered inline rather than obtained
    from the ``my_custom_dense`` pytest fixture: fixture functions must not
    be called directly (pytest raises on a direct call), so this helper has
    to own its layer definition.
    """

    @register_keras_serializable(package="MyLayers", name="CustomDense")
    class CustomDense(layers.Layer):
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.units = units
            self.dense = layers.Dense(units)

        def call(self, x):
            return self.dense(x)

        def get_config(self):
            config = super().get_config()
            config.update({"units": self.units})
            return config

    x = Input(shape=(3,), name="input_a")
    output_a = CustomDense(10, name="output_a")(x)
    model = Model(x, output_a)
    return model
@pytest.fixture
def fake_main_module(request, monkeypatch):
    """Temporarily make this test module act as ``__main__``.

    cloudpickle serializes functions and classes by value only when they
    are defined in ``__main__``, so tests that rely on that behavior swap
    the real entry in ``sys.modules``.  Using ``monkeypatch.setitem``
    (instead of a hand-rolled ``request.addfinalizer``) guarantees the
    original ``__main__`` entry is restored at teardown; the previous
    version requested ``monkeypatch`` without ever using it.
    """
    monkeypatch.setitem(sys.modules, "__main__", sys.modules[__name__])


@pytest.mark.requires_trainable_backend
Expand Down Expand Up @@ -155,7 +141,6 @@ def call(self, x):
("single_list_output_2", _get_model_single_output_list),
("single_list_output_3", _get_model_single_output_list),
("single_list_output_4", _get_model_single_output_list),
("custom_layer", _get_model_custom_layer),
)
def test_functional_pickling(self, model_fn):
model = model_fn()
Expand All @@ -170,6 +155,45 @@ def test_functional_pickling(self, model_fn):

self.assertAllClose(np.array(pred_reloaded), np.array(pred))

# Fake the __main__ module because cloudpickle only serializes
# functions & classes if they are defined in the __main__ module.
@pytest.mark.usefixtures("fake_main_module")
def test_functional_pickling_custom_layer(self):
    """Round-trip a model with a custom layer through cloudpickle.

    Verifies that the pickled payload carries the custom object along
    with the model, so loading works even when the loading environment
    does not define the custom class.
    """

    @register_keras_serializable()
    class CustomDense(layers.Layer):
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            # Keep `units` so get_config() can reproduce the layer.
            self.units = units
            self.dense = layers.Dense(units)

        def call(self, x):
            return self.dense(x)

        def get_config(self):
            config = super().get_config()
            config.update({"units": self.units})
            return config

    x = Input(shape=(3,), name="input_a")
    output_a = CustomDense(10, name="output_a")(x)
    model = Model(x, output_a)

    self.assertIsInstance(model, Functional)
    model.compile()
    x = np.random.rand(8, 3)

    dumped_pickle = cloudpickle.dumps(model)

    # Verify that we can load the dumped pickle even if the custom object
    # is not available in the loading environment.
    # NOTE(review): `del` only removes the local binding; the class stays
    # in the custom-objects registry until the conftest autouse fixture
    # clears it — presumably sufficient for this test, confirm.
    del CustomDense
    reloaded_pickle = cloudpickle.loads(dumped_pickle)

    # Predictions from the reloaded model must match the original's.
    pred_reloaded = reloaded_pickle.predict(x)
    pred = model.predict(x)

    self.assertAllClose(np.array(pred_reloaded), np.array(pred))

@parameterized.named_parameters(
("single_output_1", _get_model_single_output, None),
("single_output_2", _get_model_single_output, "list"),
Expand Down
46 changes: 25 additions & 21 deletions keras/src/saving/keras_saveable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import io
import pickle

from keras.src.saving.object_registration import get_custom_objects

Expand All @@ -17,28 +16,24 @@ def _obj_type(self):
)

@classmethod
def _unpickle_model(cls, data):
    """Rebuild a model from the state dict produced by `__reduce__`.

    Args:
        data: dict with a required ``"model_buf"`` entry (an H5 payload
            in a BytesIO) and an optional ``"custom_objects_buf"`` entry
            (a pickled custom-objects dict, written by cloudpickle; a
            plain ``pickle.load`` can read cloudpickle output).

    Returns:
        The deserialized model.
    """
    import keras.src.saving.saving_lib as saving_lib

    # pickle is not safe regardless of what you do.

    if "custom_objects_buf" in data:
        import pickle

        custom_objects = pickle.load(data["custom_objects_buf"])
    else:
        # The pickling side had no cloudpickle available, so no custom
        # objects were captured.
        custom_objects = None

    return saving_lib._load_model_from_fileobj(
        data["model_buf"],
        custom_objects=custom_objects,
        compile=True,
        safe_mode=False,
    )

def __reduce__(self):
    """Customize pickling (used by `pickle.dumps()` and friends).

    Returns a ``(callable, args)`` pair: the model is serialized to an
    in-memory H5 file via the keras saving library, and — when
    cloudpickle is installed — the custom-objects registry is serialized
    alongside it so unpickling can resolve custom classes.
    """
    import keras.src.saving.saving_lib as saving_lib

    data = {}

    # Serialize the model itself through the keras H5 saving path.
    model_buf = io.BytesIO()
    saving_lib._save_model_to_fileobj(self, model_buf, "h5")
    data["model_buf"] = model_buf

    try:
        # cloudpickle is optional; unlike plain pickle it can serialize
        # classes/functions defined in __main__ by value.
        import cloudpickle
    except ImportError:
        # Without cloudpickle, skip the registry; _unpickle_model falls
        # back to custom_objects=None.
        pass
    else:
        custom_objects_buf = io.BytesIO()
        cloudpickle.dump(get_custom_objects(), custom_objects_buf)
        custom_objects_buf.seek(0)
        data["custom_objects_buf"] = custom_objects_buf

    return (
        self._unpickle_model,
        (data,),
    )

0 comments on commit e73c4cc

Please sign in to comment.