Skip to content

Commit 0ffafd7

Browse files
olegshaldybin (Orbax Authors)
authored and committed
Internal change
PiperOrigin-RevId: 875386048
1 parent db84474 commit 0ffafd7

File tree

5 files changed

+180
-19
lines changed

5 files changed

+180
-19
lines changed

export/orbax/export/data_processors/tf_data_processor_test.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Tests for TfDataProcessor."""
16-
15+
import os
1716
import orbax.experimental.model.core as obm
1817
from orbax.export.data_processors import tf_data_processor
1918
import tensorflow as tf
@@ -80,6 +79,7 @@ def test_prepare_succeeds(self):
8079

8180
self.assertIsNotNone(processor.concrete_function)
8281
self.assertIsNotNone(processor.obm_function)
82+
8383
self.assertEqual(
8484
processor.input_signature[0][0],
8585
obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f64, name='x'),
@@ -158,17 +158,40 @@ def test_suppress_x64_output(self):
158158
def test_convert_to_bfloat16(self):
159159
v = tf.Variable(0.5, dtype=tf.float32)
160160

161-
def func(x):
162-
return v + x
161+
def func(x, y, z):
162+
return v + x + y + tf.cast(z, tf.float32)
163163

164164
processor = tf_data_processor.TfDataProcessor(func, name='preprocessor')
165+
converter_options = converter_options_v2_pb2.ConverterOptionsV2(
166+
bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions(
167+
scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
168+
skip_safety_checks=True,
169+
)
170+
)
171+
165172
processor.prepare(
166-
available_tensor_specs=(tf.TensorSpec(shape=(2, 3), dtype=tf.float32)),
167-
bfloat16_options=converter_options_v2_pb2.ConverterOptionsV2(
168-
bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions(
169-
scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
170-
skip_safety_checks=True,
171-
)
173+
available_tensor_specs=(tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
174+
tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
175+
tf.constant(2, dtype=tf.int32)),
176+
bfloat16_options=converter_options,
177+
tf_trackable_resources=[v],
178+
)
179+
180+
self.assertEqual(
181+
processor.input_signature,
182+
(
183+
(
184+
obm.ShloTensorSpec(
185+
shape=(2, 3), dtype=obm.ShloDType.bf16, name='x'
186+
),
187+
obm.ShloTensorSpec(
188+
shape=(2, 3), dtype=obm.ShloDType.bf16, name='y'
189+
),
190+
obm.ShloTensorSpec(
191+
shape=(), dtype=obm.ShloDType.i32, name='z'
192+
),
193+
),
194+
{},
172195
),
173196
)
174197
self.assertEqual(
@@ -177,11 +200,22 @@ def func(x):
177200
shape=(2, 3), dtype=obm.ShloDType.bf16, name='output_0'
178201
),
179202
)
180-
self.assertLen(processor.concrete_function.variables, 1)
181-
self.assertEqual(
182-
processor.concrete_function.variables[0].dtype, tf.bfloat16
203+
204+
# Verify that the variables have been converted to bfloat16 too.
205+
model_dir = self.create_tempdir().full_path
206+
tf2obm.save_tf_functions(
207+
model_dir,
208+
{'preprocessor': processor.concrete_function},
209+
trackable_resources=[v],
210+
converter_options=converter_options,
183211
)
184212

213+
saved_model = tf.saved_model.load(os.path.join(model_dir, 'tf_saved_model'))
214+
restored_fn = saved_model.signatures['preprocessor']
215+
216+
self.assertLen(restored_fn.variables, 1)
217+
self.assertEqual(restored_fn.variables[0].dtype, tf.bfloat16)
218+
185219
def test_bfloat16_convert_error(self):
186220
processor = tf_data_processor.TfDataProcessor(
187221
lambda x: 0.5 + x, name='preprocessor'

export/orbax/export/export_testing_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414

1515
"""Testing utils for orbax.export."""
16+
1617
import os
1718
from typing import cast
19+
1820
import jax
1921
from jax import sharding
2022
from jax.experimental import mesh_utils
@@ -27,4 +29,5 @@
2729
from orbax.export import serving_config as osc
2830
import tensorflow as tf
2931

32+
3033
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

export/orbax/export/obm_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from collections.abc import Callable, Mapping, Sequence
1818
import copy
1919
import dataclasses
20-
import functools
2120
import itertools
2221
import os
2322
from typing import Any, cast
@@ -42,6 +41,7 @@
4241

4342
_obm_export_config = config.config
4443

44+
4545
class ObmExport(export_base.ExportBase):
4646
"""Defines the save and load methods for exporting a model using Orbax Model export."""
4747

export/orbax/export/obm_export_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from collections.abc import Mapping, Sequence
1616
import contextlib
17-
import importlib
1817
import os
1918
import pathlib
2019
from typing import Any, Callable

model/orbax/experimental/model/tf2obm/_src/converter.py

Lines changed: 129 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""Converts TF concrete functions to OBM functions (allowing TF resources)."""
1616

1717
from collections.abc import Mapping, Sequence
18+
import copy
1819
import os
20+
import tempfile
1921
from typing import Any, Dict, NamedTuple, Tuple
2022

2123
from jax import tree_util as jax_tree_util
@@ -26,6 +28,8 @@
2628

2729
from .learning.brain.contrib.tpu_modeling.inference_converter_v2 import converter_options_v2_pb2
2830
from .learning.brain.contrib.tpu_modeling.inference_converter_v2.python import converter
31+
from tensorflow.core.protobuf import meta_graph_pb2 # pylint: disable=g-direct-tensorflow-import
32+
from tensorflow.core.protobuf import saved_model_pb2 # pylint: disable=g-direct-tensorflow-import
2933

3034
TF_CONCRETE_FUNCTION_HANDLE_MIME_TYPE = (
3135
'application/protobuf;'
@@ -52,6 +56,10 @@ def _is_args_kwargs_pattern(tree: utils.TfSignature) -> bool:
5256
def convert_function(
5357
fn_name: str,
5458
fn: tf.types.experimental.ConcreteFunction,
59+
converter_options: (
60+
converter_options_v2_pb2.ConverterOptionsV2 | None
61+
) = None,
62+
trackable_resources: Any | None = None,
5563
) -> obm.SerializableFunction:
5664
"""Converts the TF concrete function to an OBM function.
5765
@@ -62,16 +70,36 @@ def convert_function(
6270
fn_name: The name to be used in the OBM manifest to refer to the TF
6371
function.
6472
fn: The TF concrete function.
73+
converter_options: The converter options to use for the TF SavedModel. If
74+
set, the TF SavedModel will be converted using Inference Converter V2 in
75+
order to get the correct types for the input and output signatures.
76+
trackable_resources: Trackable resources used by the function.
6577
6678
Returns:
6779
The OBM function referring to the original TF function in the TF SavedModel.
6880
"""
69-
input_signature = fn.structured_input_signature
70-
output_signature = get_output_signature(fn)
7181

7282
input_names, _, _ = _flat_input_signature(fn)
7383
output_names = _output_names(fn)
7484

85+
if converter_options is not None:
86+
converterted_signature_def = _get_converted_function_signature_def(
87+
fn_name, fn, trackable_resources, converter_options
88+
)
89+
input_signature = _copy_types_from_signature_def(
90+
fn.structured_input_signature,
91+
converterted_signature_def.inputs,
92+
input_names,
93+
)
94+
output_signature = _copy_types_from_signature_def(
95+
get_output_signature(fn),
96+
converterted_signature_def.outputs,
97+
output_names,
98+
)
99+
else:
100+
input_signature = fn.structured_input_signature
101+
output_signature = get_output_signature(fn)
102+
75103
unstructured_data = obm.manifest_pb2.UnstructuredData(
76104
inlined_bytes=tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(
77105
fn_name=fn_name,
@@ -406,12 +434,23 @@ def save_tf_functions(
406434

407435
target_path = os.path.join(model_dir, tf_saved_model_sub_dir)
408436
if converter_options is not None:
437+
# Inference Converter V2 modifies the converter_options in place, so we
438+
# need to deepcopy it to avoid modifying the original options and keep
439+
# them re-usable.
440+
converter_options_copy = copy.deepcopy(converter_options)
409441
pre_conversion_path = os.path.join(model_dir, 'tmp_tf_saved_model')
410-
tf.saved_model.save(tf_module, pre_conversion_path, signatures=wrapped_fns)
442+
tf.saved_model.save(
443+
tf_module,
444+
pre_conversion_path,
445+
signatures=wrapped_fns,
446+
# Function aliases are used by the Inference Converter V2 to
447+
# identify XLA functions.
448+
options=tf.saved_model.SaveOptions(function_aliases=wrapped_fns),
449+
)
411450
converter.ConvertSavedModel(
412451
pre_conversion_path,
413452
target_path,
414-
converter_options,
453+
converter_options_copy,
415454
)
416455
tf.io.gfile.rmtree(pre_conversion_path)
417456
else:
@@ -422,3 +461,89 @@ def save_tf_functions(
422461
tf_saved_model_as_obm_supplemental(tf_saved_model_sub_dir)
423462
)
424463
}
464+
465+
466+
def _copy_types_from_signature_def(
    original_signature: Any,
    signature_def_args: Mapping[str, meta_graph_pb2.TensorInfo],
    arg_names: Sequence[str],
) -> Any:
  """Returns the original signature with dtypes taken from a TF SignatureDef.

  Walks the pytree leaves of `original_signature` in order, pairing each leaf
  with the next name from `arg_names`, and rebuilds every `tf.TensorSpec` leaf
  with the dtype recorded in the SignatureDef under that name. Leaves that are
  not `tf.TensorSpec` are passed through unchanged.

  Args:
    original_signature: The original signature that needs new types.
    signature_def_args: The TF SignatureDef arguments to copy types from.
    arg_names: The argument names of the original TF function. They are used to
      infer the input order in the original signature.

  Returns:
    The original signature with types copied from the signature_def for the
    corresponding input names.

  Raises:
    ValueError: If any of the argument names is not found in the SignatureDef.
  """
  names_it = iter(arg_names)

  def _retype_leaf(leaf: Any) -> Any:
    # Advance the name iterator on every leaf so names and leaves stay in
    # lockstep, even for leaves that are not TensorSpecs.
    name = next(names_it)
    if name not in signature_def_args:
      raise ValueError(
          f'Argument name {name!r} not found in SignatureDef: '
          f'{signature_def_args.keys()!r}'
      )

    if not isinstance(leaf, tf.TensorSpec):
      return leaf

    return tf.TensorSpec(
        shape=leaf.shape,
        dtype=tf.as_dtype(signature_def_args[name].dtype),
        name=name,
    )

  return jax_tree_util.tree_map(_retype_leaf, original_signature)
510+
511+
512+
def _get_converted_function_signature_def(
    fn_name: str,
    fn: tf.types.experimental.ConcreteFunction,
    trackable_resources: Any,
    converter_options: converter_options_v2_pb2.ConverterOptionsV2,
) -> meta_graph_pb2.SignatureDef:
  """Saves the function, converts it, returns its SignatureDef.

  The function is exported to a temporary SavedModel, converted with the
  given options, and the converted model's SignatureDef for `fn_name` is
  read back from disk.

  Args:
    fn_name: The name of the function in the SavedModel.
    fn: The concrete function to save.
    trackable_resources: The trackable resources to save.
    converter_options: The converter options to use for the TF SavedModel.

  Returns:
    The SignatureDef of the converted function.
  """
  options = copy.deepcopy(converter_options)
  # Only the signature is needed here, so the checkpoint does not have to be
  # converted.
  options.bfloat16_optimization_options.experimental.convert_checkpoint = False

  with tempfile.TemporaryDirectory() as temp_dir:
    save_tf_functions(
        temp_dir,
        {fn_name: fn},
        trackable_resources=trackable_resources,
        converter_options=options,
    )

    saved_model_pb_path = os.path.join(
        temp_dir, OBM_TF_SAVED_MODEL_SUB_DIR, 'saved_model.pb'
    )
    with open(saved_model_pb_path, 'rb') as f:
      saved_model_proto = saved_model_pb2.SavedModel.FromString(f.read())

  return saved_model_proto.meta_graphs[0].signature_def[fn_name]

0 commit comments

Comments
 (0)