1515"""Converts TF concrete functions to OBM functions (allowing TF resources)."""
1616
1717from collections .abc import Mapping , Sequence
18+ import copy
1819import os
20+ import tempfile
1921from typing import Any , Dict , NamedTuple , Tuple
2022
2123from jax import tree_util as jax_tree_util
2628
2729from .learning .brain .contrib .tpu_modeling .inference_converter_v2 import converter_options_v2_pb2
2830from .learning .brain .contrib .tpu_modeling .inference_converter_v2 .python import converter
31+ from tensorflow .core .protobuf import meta_graph_pb2 # pylint: disable=g-direct-tensorflow-import
32+ from tensorflow .core .protobuf import saved_model_pb2 # pylint: disable=g-direct-tensorflow-import
2933
3034TF_CONCRETE_FUNCTION_HANDLE_MIME_TYPE = (
3135 'application/protobuf;'
@@ -52,6 +56,10 @@ def _is_args_kwargs_pattern(tree: utils.TfSignature) -> bool:
5256def convert_function (
5357 fn_name : str ,
5458 fn : tf .types .experimental .ConcreteFunction ,
59+ converter_options : (
60+ converter_options_v2_pb2 .ConverterOptionsV2 | None
61+ ) = None ,
62+ trackable_resources : Any | None = None ,
5563) -> obm .SerializableFunction :
5664 """Converts the TF concrete function to an OBM function.
5765
@@ -62,16 +70,36 @@ def convert_function(
6270 fn_name: The name to be used in the OBM manifest to refer to the TF
6371 function.
6472 fn: The TF concrete function.
73+ converter_options: The converter options to use for the TF SavedModel. If
74+ set, the TF SavedModel will be converted using Inference Converter V2 in
75+ order to get the correct types for the input and output signatures.
76+ trackable_resources: Trackable resources used by the function.
6577
6678 Returns:
6779 The OBM function referring to the original TF function in the TF SavedModel.
6880 """
69- input_signature = fn .structured_input_signature
70- output_signature = get_output_signature (fn )
7181
7282 input_names , _ , _ = _flat_input_signature (fn )
7383 output_names = _output_names (fn )
7484
85+ if converter_options is not None :
86+ converterted_signature_def = _get_converted_function_signature_def (
87+ fn_name , fn , trackable_resources , converter_options
88+ )
89+ input_signature = _copy_types_from_signature_def (
90+ fn .structured_input_signature ,
91+ converterted_signature_def .inputs ,
92+ input_names ,
93+ )
94+ output_signature = _copy_types_from_signature_def (
95+ get_output_signature (fn ),
96+ converterted_signature_def .outputs ,
97+ output_names ,
98+ )
99+ else :
100+ input_signature = fn .structured_input_signature
101+ output_signature = get_output_signature (fn )
102+
75103 unstructured_data = obm .manifest_pb2 .UnstructuredData (
76104 inlined_bytes = tf_concrete_function_handle_pb2 .TfConcreteFunctionHandle (
77105 fn_name = fn_name ,
@@ -406,12 +434,23 @@ def save_tf_functions(
406434
407435 target_path = os .path .join (model_dir , tf_saved_model_sub_dir )
408436 if converter_options is not None :
437+ # Inference Converter V2 modifies the converter_options in place, so we
438+ # need to deepcopy it to avoid modifying the original options and keep
439+ # them re-usable.
440+ converter_options_copy = copy .deepcopy (converter_options )
409441 pre_conversion_path = os .path .join (model_dir , 'tmp_tf_saved_model' )
410- tf .saved_model .save (tf_module , pre_conversion_path , signatures = wrapped_fns )
442+ tf .saved_model .save (
443+ tf_module ,
444+ pre_conversion_path ,
445+ signatures = wrapped_fns ,
446+ # Function aliases are used by the Inference Converter V2 to
447+ # identify XLA functions.
448+ options = tf .saved_model .SaveOptions (function_aliases = wrapped_fns ),
449+ )
411450 converter .ConvertSavedModel (
412451 pre_conversion_path ,
413452 target_path ,
414- converter_options ,
453+ converter_options_copy ,
415454 )
416455 tf .io .gfile .rmtree (pre_conversion_path )
417456 else :
@@ -422,3 +461,89 @@ def save_tf_functions(
422461 tf_saved_model_as_obm_supplemental (tf_saved_model_sub_dir )
423462 )
424463 }
464+
465+
466+ def _copy_types_from_signature_def (
467+ original_signature : Any ,
468+ signature_def_args : Mapping [str , meta_graph_pb2 .TensorInfo ],
469+ arg_names : Sequence [str ],
470+ ) -> Any :
471+ """Copies types from TF SignatureDef to the original signature.
472+
473+ Args:
474+ original_signature: The original signature that needs new types.
475+ signature_def_args: The TF SignatureDef arguments to copy types from.
476+ arg_names: The argument names of the original TF function. They are used to
477+ infer the input order in the original signature.
478+
479+ Returns:
480+ The original signature with types copied from the signature_def for the
481+ corresponding input names.
482+
483+ Raises:
484+ ValueError: If any of the argument names is not found in the SignatureDef.
485+ """
486+
487+ arg_names_iter = iter (arg_names )
488+
489+ def _copy_type (t : Any ) -> Any :
490+ arg_name = next (arg_names_iter )
491+ if arg_name not in signature_def_args :
492+ raise ValueError (
493+ f'Argument name { arg_name !r} not found in SignatureDef: '
494+ f'{ signature_def_args .keys ()!r} '
495+ )
496+
497+ if not isinstance (t , tf .TensorSpec ):
498+ return t
499+
500+ return tf .TensorSpec (
501+ shape = t .shape ,
502+ dtype = tf .as_dtype (signature_def_args [arg_name ].dtype ),
503+ name = arg_name ,
504+ )
505+
506+ return jax_tree_util .tree_map (
507+ _copy_type ,
508+ original_signature ,
509+ )
510+
511+
512+ def _get_converted_function_signature_def (
513+ fn_name : str ,
514+ fn : tf .types .experimental .ConcreteFunction ,
515+ trackable_resources : Any ,
516+ converter_options : converter_options_v2_pb2 .ConverterOptionsV2 ,
517+ ) -> meta_graph_pb2 .SignatureDef :
518+ """Saves the function, converts it, returns its SignatureDef.
519+
520+ Args:
521+ fn_name: The name of the function in the SavedModel.
522+ fn: The concrete function to save.
523+ trackable_resources: The trackable resources to save.
524+ converter_options: The converter options to use for the TF SavedModel.
525+
526+ Returns:
527+ The SignatureDef of the converted function.
528+ """
529+
530+ opts_copy = copy .deepcopy (converter_options )
531+ # There is no need to convert the checkpoint in this case, since we are only
532+ # interested in the signature.
533+ opts_copy .bfloat16_optimization_options .experimental .convert_checkpoint = (
534+ False
535+ )
536+
537+ with tempfile .TemporaryDirectory () as temp_dir :
538+ save_tf_functions (
539+ temp_dir ,
540+ {fn_name : fn },
541+ trackable_resources = trackable_resources ,
542+ converter_options = opts_copy ,
543+ )
544+
545+ converted_model_path = os .path .join (temp_dir , OBM_TF_SAVED_MODEL_SUB_DIR )
546+ with open (os .path .join (converted_model_path , 'saved_model.pb' ), 'rb' ) as f :
547+ saved_model_proto = saved_model_pb2 .SavedModel .FromString (f .read ())
548+
549+ return saved_model_proto .meta_graphs [0 ].signature_def [fn_name ]
0 commit comments