1
1
import pickle
2
+ import sys
2
3
4
+ import cloudpickle
3
5
import numpy as np
4
6
import pytest
5
7
from absl .testing import parameterized
14
16
from keras .src .saving .object_registration import register_keras_serializable
15
17
16
18
17
- @pytest .fixture
18
- def my_custom_dense ():
19
- @register_keras_serializable (package = "MyLayers" , name = "CustomDense" )
20
- class CustomDense (layers .Layer ):
21
- def __init__ (self , units , ** kwargs ):
22
- super ().__init__ (** kwargs )
23
- self .units = units
24
- self .dense = layers .Dense (units )
25
-
26
- def call (self , x ):
27
- return self .dense (x )
28
-
29
- def get_config (self ):
30
- config = super ().get_config ()
31
- config .update ({"units" : self .units })
32
- return config
33
-
34
- return CustomDense
35
-
36
-
37
19
def _get_model ():
38
20
input_a = Input (shape = (3 ,), batch_size = 2 , name = "input_a" )
39
21
input_b = Input (shape = (3 ,), batch_size = 2 , name = "input_b" )
@@ -89,11 +71,15 @@ def _get_model_multi_outputs_dict():
89
71
return model
90
72
91
73
92
- def _get_model_custom_layer ():
93
- x = Input (shape = (3 ,), name = "input_a" )
94
- output_a = my_custom_dense ()(10 , name = "output_a" )(x )
95
- model = Model (x , output_a )
96
- return model
74
+ @pytest .fixture
75
+ def fake_main_module (request , monkeypatch ):
76
+ original_main = sys .modules ["__main__" ]
77
+
78
+ def restore_main_module ():
79
+ sys .modules ["__main__" ] = original_main
80
+
81
+ request .addfinalizer (restore_main_module )
82
+ sys .modules ["__main__" ] = sys .modules [__name__ ]
97
83
98
84
99
85
@pytest .mark .requires_trainable_backend
@@ -155,7 +141,6 @@ def call(self, x):
155
141
("single_list_output_2" , _get_model_single_output_list ),
156
142
("single_list_output_3" , _get_model_single_output_list ),
157
143
("single_list_output_4" , _get_model_single_output_list ),
158
- ("custom_layer" , _get_model_custom_layer ),
159
144
)
160
145
def test_functional_pickling (self , model_fn ):
161
146
model = model_fn ()
@@ -170,6 +155,45 @@ def test_functional_pickling(self, model_fn):
170
155
171
156
self .assertAllClose (np .array (pred_reloaded ), np .array (pred ))
172
157
158
+ # Fake the __main__ module because cloudpickle only serializes
159
+ # functions & classes if they are defined in the __main__ module.
160
+ @pytest .mark .usefixtures ("fake_main_module" )
161
+ def test_functional_pickling_custom_layer (self ):
162
+ @register_keras_serializable ()
163
+ class CustomDense (layers .Layer ):
164
+ def __init__ (self , units , ** kwargs ):
165
+ super ().__init__ (** kwargs )
166
+ self .units = units
167
+ self .dense = layers .Dense (units )
168
+
169
+ def call (self , x ):
170
+ return self .dense (x )
171
+
172
+ def get_config (self ):
173
+ config = super ().get_config ()
174
+ config .update ({"units" : self .units })
175
+ return config
176
+
177
+ x = Input (shape = (3 ,), name = "input_a" )
178
+ output_a = CustomDense (10 , name = "output_a" )(x )
179
+ model = Model (x , output_a )
180
+
181
+ self .assertIsInstance (model , Functional )
182
+ model .compile ()
183
+ x = np .random .rand (8 , 3 )
184
+
185
+ dumped_pickle = cloudpickle .dumps (model )
186
+
187
+ # Verify that we can load the dumped pickle even if the custom object
188
+ # is not available in the loading environment.
189
+ del CustomDense
190
+ reloaded_pickle = cloudpickle .loads (dumped_pickle )
191
+
192
+ pred_reloaded = reloaded_pickle .predict (x )
193
+ pred = model .predict (x )
194
+
195
+ self .assertAllClose (np .array (pred_reloaded ), np .array (pred ))
196
+
173
197
@parameterized .named_parameters (
174
198
("single_output_1" , _get_model_single_output , None ),
175
199
("single_output_2" , _get_model_single_output , "list" ),
0 commit comments