@@ -47,7 +47,8 @@ class KerasModelExporter:
47
47
def __init__ (self ,
48
48
* ,
49
49
output_names : Optional [Any ] = None ,
50
- subdirectory : Optional [str ] = None ):
50
+ subdirectory : Optional [str ] = None ,
51
+ include_preprocessing : bool = True ):
51
52
"""Captures the args shared across `save(...)` calls.
52
53
53
54
Args:
@@ -62,9 +63,11 @@ def __init__(self,
62
63
keys of a model output can also be renamed.
63
64
subdirectory: An optional subdirectory, if set: models are exported to
64
65
`os.path.join(export_dir, subdirectory).`
66
+ include_preprocessing: Whether to include any `preprocess_model.`
65
67
"""
66
68
self ._output_names = output_names
67
69
self ._subdirectory = subdirectory
70
+ self ._include_preprocessing = include_preprocessing
68
71
69
72
def save (self ,
70
73
preprocess_model : Optional [tf .keras .Model ],
@@ -82,7 +85,7 @@ def save(self,
82
85
model: A `tf.keras.Model` to save.
83
86
export_dir: A destination directory for the model.
84
87
"""
85
- if preprocess_model is not None :
88
+ if preprocess_model is not None and self . _include_preprocessing :
86
89
model = model_utils .chain_first_output (preprocess_model , model )
87
90
if self ._output_names is not None :
88
91
output = _rename_output (model .output , self ._output_names )
@@ -99,18 +102,21 @@ def __init__(self,
99
102
submodule_name : str ,
100
103
* ,
101
104
output_names : Optional [Any ] = None ,
102
- subdirectory : Optional [str ] = None ):
105
+ subdirectory : Optional [str ] = None ,
106
+ include_preprocessing : bool = False ):
103
107
"""Captures the args shared across `save(...)` calls.
104
108
105
109
Args:
106
110
submodule_name: The name of the submodule to export.
107
111
output_names: The names for output Tensor(s), see: `KerasModelExporter.`
108
112
subdirectory: An optional subdirectory, if set: submodules are exported
109
113
to `os.path.join(export_dir, subdirectory).`
114
+ include_preprocessing: Whether to include any `preprocess_model.`
110
115
"""
111
116
self ._output_names = output_names
112
117
self ._subdirectory = subdirectory
113
118
self ._submodule_name = submodule_name
119
+ self ._include_preprocessing = include_preprocessing
114
120
115
121
def save (self ,
116
122
preprocess_model : tf .keras .Model ,
@@ -144,6 +150,7 @@ def save(self,
144
150
145
151
exporter = KerasModelExporter (
146
152
output_names = self ._output_names ,
147
- subdirectory = self ._subdirectory )
153
+ subdirectory = self ._subdirectory ,
154
+ include_preprocessing = self ._include_preprocessing )
148
155
149
156
exporter .save (preprocess_model , submodel , export_dir )
0 commit comments