Skip to content

Commit d7a9659

Browse files
dzelletensorflower-gardener
authored andcommitted
Add include_preprocessing to model exporters.
PiperOrigin-RevId: 488371678
1 parent 1a6f7ad commit d7a9659

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

Diff for: tensorflow_gnn/runner/utils/model_export.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ class KerasModelExporter:
4747
def __init__(self,
4848
*,
4949
output_names: Optional[Any] = None,
50-
subdirectory: Optional[str] = None):
50+
subdirectory: Optional[str] = None,
51+
include_preprocessing: bool = True):
5152
"""Captures the args shared across `save(...)` calls.
5253
5354
Args:
@@ -62,9 +63,11 @@ def __init__(self,
6263
keys of a model output can also be renamed.
6364
subdirectory: An optional subdirectory, if set: models are exported to
6465
`os.path.join(export_dir, subdirectory).`
66+
include_preprocessing: Whether to include any `preprocess_model.`
6567
"""
6668
self._output_names = output_names
6769
self._subdirectory = subdirectory
70+
self._include_preprocessing = include_preprocessing
6871

6972
def save(self,
7073
preprocess_model: Optional[tf.keras.Model],
@@ -82,7 +85,7 @@ def save(self,
8285
model: A `tf.keras.Model` to save.
8386
export_dir: A destination directory for the model.
8487
"""
85-
if preprocess_model is not None:
88+
if preprocess_model is not None and self._include_preprocessing:
8689
model = model_utils.chain_first_output(preprocess_model, model)
8790
if self._output_names is not None:
8891
output = _rename_output(model.output, self._output_names)
@@ -99,18 +102,21 @@ def __init__(self,
99102
submodule_name: str,
100103
*,
101104
output_names: Optional[Any] = None,
102-
subdirectory: Optional[str] = None):
105+
subdirectory: Optional[str] = None,
106+
include_preprocessing: bool = False):
103107
"""Captures the args shared across `save(...)` calls.
104108
105109
Args:
106110
submodule_name: The name of the submodule to export.
107111
output_names: The names for output Tensor(s), see: `KerasModelExporter.`
108112
subdirectory: An optional subdirectory, if set: submodules are exported
109113
to `os.path.join(export_dir, subdirectory).`
114+
include_preprocessing: Whether to include any `preprocess_model.`
110115
"""
111116
self._output_names = output_names
112117
self._subdirectory = subdirectory
113118
self._submodule_name = submodule_name
119+
self._include_preprocessing = include_preprocessing
114120

115121
def save(self,
116122
preprocess_model: tf.keras.Model,
@@ -144,6 +150,7 @@ def save(self,
144150

145151
exporter = KerasModelExporter(
146152
output_names=self._output_names,
147-
subdirectory=self._subdirectory)
153+
subdirectory=self._subdirectory,
154+
include_preprocessing=self._include_preprocessing)
148155

149156
exporter.save(preprocess_model, submodel, export_dir)

0 commit comments

Comments
 (0)