diff --git a/tfx_bsl/beam/run_inference.py b/tfx_bsl/beam/run_inference.py index 37f8e91f..2beb8e30 100644 --- a/tfx_bsl/beam/run_inference.py +++ b/tfx_bsl/beam/run_inference.py @@ -44,8 +44,8 @@ from tfx_bsl.beam import shared from tfx_bsl.public.proto import model_spec_pb2 from tfx_bsl.telemetry import util -from typing import Any, Generator, Iterable, List, Mapping, Sequence, Text, \ - Tuple, Union +from typing import Any, Generator, Iterable, List, Mapping, Optional, \ + Sequence, Text, Tuple, Union # TODO(b/140306674): stop using the internal TF API. from tensorflow.python.saved_model import loader_impl @@ -86,6 +86,15 @@ Tuple[tf.train.Example, classification_pb2.Classifications]] +# Public facing type aliases +ExampleType = Union[tf.train.Example, tf.train.SequenceExample] +QueryType = Tuple[Union[model_spec_pb2.InferenceSpecType, None], ExampleType] + +_QueryBatchType = Tuple[ + Union[model_spec_pb2.InferenceSpecType, None], + List[ExampleType] +] + # TODO(b/151468119): Converts this into enum once we stop supporting Python 2.7 class OperationType(object): @@ -96,8 +105,7 @@ class OperationType(object): @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) +@beam.typehints.with_input_types(ExampleType) @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) def RunInferenceImpl( # pylint: disable=invalid-name examples: beam.pvalue.PCollection, @@ -113,103 +121,172 @@ def RunInferenceImpl( # pylint: disable=invalid-name A PCollection containing prediction logs. Raises: - ValueError; when operation is not supported. + ValueError: when operation is not supported. """ logging.info('RunInference on model: %s', inference_spec_type) - batched_examples = examples | 'BatchExamples' >> beam.BatchElements() - operation_type = _get_operation_type(inference_spec_type) - if operation_type == OperationType.CLASSIFICATION: - return batched_examples | 'Classify' >> _Classify(inference_spec_type) - elif operation_type == OperationType.REGRESSION: - return batched_examples | 'Regress' >> _Regress(inference_spec_type) - elif operation_type == OperationType.PREDICTION: - return batched_examples | 'Predict' >> _Predict(inference_spec_type) - elif operation_type == OperationType.MULTIHEAD: - return (batched_examples - | 'MultiInference' >> _MultiInference(inference_spec_type)) - else: - raise ValueError('Unsupported operation_type %s' % operation_type) + queries = examples | 'FormatAsQueries' >> beam.Map(lambda x: (None, x)) + predictions = queries | '_RunInferenceCoreOnFixedModel' >> _RunInferenceCore( + fixed_inference_spec_type=inference_spec_type) - -_IOTensorSpec = collections.namedtuple( - '_IOTensorSpec', - ['input_tensor_alias', 'input_tensor_name', 'output_alias_tensor_names']) - -_Signature = collections.namedtuple('_Signature', ['name', 'signature_def']) + return predictions @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) +@beam.typehints.with_input_types(QueryType) @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) -def _Classify(pcoll: beam.pvalue.PCollection, # pylint: disable=invalid-name - inference_spec_type: model_spec_pb2.InferenceSpecType): - """Performs classify PTransform.""" - if _using_in_process_inference(inference_spec_type): - return (pcoll - | 'Classify' >> beam.ParDo( - _BatchClassifyDoFn(inference_spec_type, shared.Shared())) - | 'BuildPredictionLogForClassifications' >> beam.ParDo( - _BuildPredictionLogForClassificationsDoFn())) - 
else: - raise NotImplementedError +def _RunInferenceCore( + queries: beam.pvalue.PCollection, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None +) -> beam.pvalue.PCollection: + """Runs inference on queries and returns prediction logs. + This internal run inference implementation operates on queries. Internally, + these queries are grouped by model and inference runs in batches. If a + fixed_inference_spec_type is provided, this spec is used for all inference + requests which enables pre-configuring the model during pipeline + construction. If the fixed_inference_spec_type is not provided, each input + query must contain a valid InferenceSpecType and models will be loaded + dynamically at runtime. -@beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) -@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) -def _Regress(pcoll: beam.pvalue.PCollection, # pylint: disable=invalid-name - inference_spec_type: model_spec_pb2.InferenceSpecType): - """Performs regress PTransform.""" - if _using_in_process_inference(inference_spec_type): - return (pcoll - | 'Regress' >> beam.ParDo( - _BatchRegressDoFn(inference_spec_type, shared.Shared())) - | 'BuildPredictionLogForRegressions' >> beam.ParDo( - _BuildPredictionLogForRegressionsDoFn())) + Args: + queries: A PCollection containing QueryType tuples. + fixed_inference_spec_type: An optional model inference endpoint. If + specified, this is "preloaded" during inference and models specified in + query tuples are ignored. This requires the InferenceSpecType to be known + at pipeline creation time. If this fixed_inference_spec_type is not + provided, each input query must contain a valid InferenceSpecType and + models will be loaded dynamically at runtime. + + Returns: + A PCollection containing prediction logs. + + Raises: + ValueError: when operation is not supported. + """ + # TODO(BEAM-2717): Currently batching by inference spec is not supported and + # it is assumed that all queries share the same inference spec. Once + # BEAM-2717 is fixed, we can use beam.GroupIntoBatches and remove this + # constraint. 
+ batched_queries = queries | 'BatchQueries' >> _BatchQueries() + predictions = None + + if fixed_inference_spec_type is None: + # operation type is determined at runtime + split = batched_queries | 'SplitByOperation' >> _SplitByOperation() + + predictions = [ + split[OperationType.CLASSIFICATION] | 'Classify' >> _Classify(), + split[OperationType.REGRESSION] | 'Regress' >> _Regress(), + split[OperationType.PREDICTION] | 'Predict' >> _Predict(), + split[OperationType.MULTIHEAD] | 'MultiInference' >> _MultiInference() + ] | beam.Flatten() else: - raise NotImplementedError + # operation type is determined at pipeline construction time + operation_type = _get_operation_type(fixed_inference_spec_type) + + if operation_type == OperationType.CLASSIFICATION: + predictions = batched_queries | 'Classify' >> _Classify( + fixed_inference_spec_type=fixed_inference_spec_type) + elif operation_type == OperationType.REGRESSION: + predictions = batched_queries | 'Regress' >> _Regress( + fixed_inference_spec_type=fixed_inference_spec_type) + elif operation_type == OperationType.PREDICTION: + predictions = batched_queries | 'Predict' >> _Predict( + fixed_inference_spec_type=fixed_inference_spec_type) + elif operation_type == OperationType.MULTIHEAD: + predictions = (batched_queries | 'MultiInference' >> _MultiInference( + fixed_inference_spec_type=fixed_inference_spec_type)) + else: + raise ValueError('Unsupported operation_type %s' % operation_type) + + return predictions @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) -@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) -def _Predict(pcoll: beam.pvalue.PCollection, # pylint: disable=invalid-name - inference_spec_type: model_spec_pb2.InferenceSpecType): - """Performs predict PTransform.""" - if _using_in_process_inference(inference_spec_type): - predictions = ( - pcoll - | 'Predict' >> beam.ParDo( - _BatchPredictDoFn(inference_spec_type, shared.Shared()))) - else: - predictions = ( - pcoll - | 'RemotePredict' >> beam.ParDo( - _RemotePredictDoFn(inference_spec_type, pcoll.pipeline.options))) - return (predictions - | 'BuildPredictionLogForPredictions' >> beam.ParDo( - _BuildPredictionLogForPredictionsDoFn())) +@beam.typehints.with_input_types(QueryType) +@beam.typehints.with_output_types(_QueryBatchType) +def _BatchQueries(queries: beam.pvalue.PCollection) -> beam.pvalue.PCollection: + """Groups queries into batches.""" + + def _to_query_batch(query_list: List[QueryType]) -> _QueryBatchType: + """Converts a list of queries to a logical _QueryBatch.""" + inference_spec = query_list[0][0] + examples = [x[1] for x in query_list] + return (inference_spec, examples) + + batches = ( + queries + # TODO(hgarrereyn): Use of BatchElements is a temporary workaround to + # enable RunInference to run on Dataflow v1 runner until BEAM-2717 + # is fixed. BatchElements does not performing a grouping operation + # and therefore, _BatchQueries currently operates on queries that all + # contain the same inference spec. 
+ | 'Batch' >> beam.BatchElements() + | 'ToQueryBatch' >> beam.Map(_to_query_batch) + ) + return batches @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) -@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) -def _MultiInference(pcoll: beam.pvalue.PCollection, # pylint: disable=invalid-name - inference_spec_type: model_spec_pb2.InferenceSpecType): - """Performs multi inference PTransform.""" - if _using_in_process_inference(inference_spec_type): - return ( - pcoll - | 'MultiInference' >> beam.ParDo( - _BatchMultiInferenceDoFn(inference_spec_type, shared.Shared())) - | 'BuildMultiInferenceLog' >> beam.ParDo(_BuildMultiInferenceLogDoFn())) - else: - raise NotImplementedError +@beam.typehints.with_input_types(_QueryBatchType) +@beam.typehints.with_output_types(_QueryBatchType) +def _SplitByOperation(batches): + """A PTransform that splits a _QueryBatchType PCollection based on operation. + + Benchmarks demonstrated that this transform was a bottleneck (comprising + nearly 25% of the total RunInference walltime) since looking up the operation + type requires reading the saved model signature from disk. To improve + performance, we use a caching layer inside each DoFn instance that saves a + mapping of: + + {inference_spec.SerializeToString(): operation_type} + + In practice this cache reduces _SplitByOperation walltime by more than 90%. + + Returns a DoOutputsTuple with keys: + - OperationType.CLASSIFICATION + - OperationType.REGRESSION + - OperationType.PREDICTION + - OperationType.MULTIHEAD + + Raises: + ValueError: If any inference_spec_type is None. + """ + class _SplitDoFn(beam.DoFn): + def __init__(self): + self._cache = {} + + def process(self, batch): + inference_spec, _ = batch + + if inference_spec is None: + raise ValueError("InferenceSpecType cannot be None.") + + key = inference_spec.SerializeToString() + operation_type = self._cache.get(key) + + if operation_type is None: + operation_type = _get_operation_type(inference_spec) + self._cache[key] = operation_type + + return [beam.pvalue.TaggedOutput(operation_type, batch)] + + return ( + batches + | 'SplitDoFn' >> beam.ParDo(_SplitDoFn()).with_outputs( + OperationType.CLASSIFICATION, + OperationType.REGRESSION, + OperationType.PREDICTION, + OperationType.MULTIHEAD + )) + + +_IOTensorSpec = collections.namedtuple( + '_IOTensorSpec', + ['input_tensor_alias', 'input_tensor_name', 'output_alias_tensor_names']) + +_Signature = collections.namedtuple('_Signature', ['name', 'signature_def']) @six.add_metaclass(abc.ABCMeta) @@ -219,14 +296,17 @@ class _BaseDoFn(beam.DoFn): class _MetricsCollector(object): """A collector for beam metrics.""" - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType): - operation_type = _get_operation_type(inference_spec_type) - proximity_descriptor = ( - _METRICS_DESCRIPTOR_IN_PROCESS - if _using_in_process_inference(inference_spec_type) else - _METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION) - namespace = util.MakeTfxNamespace( - [_METRICS_DESCRIPTOR_INFERENCE, operation_type, proximity_descriptor]) + def __init__(self, operation_type: Text, proximity_descriptor: Text): + """Initializes a metrics collector. + + Args: + operation_type: A string describing the type of operation, e.g. + "CLASSIFICATION". + proximity_descriptor: A string describing the location of inference, + e.g. "InProcess". 
+ """ + namespace = util.MakeTfxNamespace([ + _METRICS_DESCRIPTOR_INFERENCE, operation_type, proximity_descriptor]) # Metrics self._inference_counter = beam.metrics.Metrics.counter( @@ -249,21 +329,45 @@ def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType): namespace, 'load_model_latency_milli_secs') # Metrics cache - self.load_model_latency_milli_secs_cache = None - self.model_byte_size_cache = None + self._load_model_latency_milli_secs_cache = None + self._model_byte_size_cache = None + + def commit_cached_metrics(self): + """Updates any cached metrics. - def update_metrics_with_cache(self): - if self.load_model_latency_milli_secs_cache is not None: + If there are no cached metrics, this has no effect. Cached metrics are + automatically cleared after use. + """ + if self._load_model_latency_milli_secs_cache is not None: self._load_model_latency_milli_secs.update( - self.load_model_latency_milli_secs_cache) - self.load_model_latency_milli_secs_cache = None - if self.model_byte_size_cache is not None: - self._model_byte_size.update(self.model_byte_size_cache) - self.model_byte_size_cache = None - - def update(self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]], - latency_micro_secs: int) -> None: + self._load_model_latency_milli_secs_cache) + self._load_model_latency_milli_secs_cache = None + if self._model_byte_size_cache is not None: + self._model_byte_size.update(self._model_byte_size_cache) + self._model_byte_size_cache = None + + def update_model_load( + self, load_model_latency_milli_secs: int, model_byte_size: int): + """Updates model loading metrics. + + Note: To commit model loading metrics, you must call + commit_cached_metrics() after storing values with this method. + + Args: + load_model_latency_milli_secs: Model loading latency in milliseconds. + model_byte_size: Approximate model size in bytes. + """ + self._load_model_latency_milli_secs_cache = load_model_latency_milli_secs + self._model_byte_size_cache = model_byte_size + + def update_inference( + self, elements: List[ExampleType], latency_micro_secs: int) -> None: + """Updates inference metrics. + + Args: + elements: A list of examples used for inference. + latency_micro_secs: Total inference latency in microseconds. 
+ """ self._inference_batch_latency_micro_secs.update(latency_micro_secs) self._num_instances.inc(len(elements)) self._inference_counter.inc(len(elements)) @@ -271,37 +375,35 @@ def update(self, elements: List[Union[tf.train.Example, self._inference_request_batch_byte_size.update( sum(element.ByteSize() for element in elements)) - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType): + def __init__(self, operation_type: Text, proximity_descriptor: Text): super(_BaseDoFn, self).__init__() self._clock = None - self._metrics_collector = self._MetricsCollector(inference_spec_type) + self._metrics_collector = self._MetricsCollector( + operation_type, proximity_descriptor) def setup(self): self._clock = _ClockFactory.make_clock() - def process( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] - ) -> Iterable[Any]: + def process(self, batch: _QueryBatchType) -> Iterable[Any]: + inference_spec, elements = batch batch_start_time = self._clock.get_current_time_in_microseconds() - outputs = self.run_inference(elements) + outputs = self.run_inference(inference_spec, elements) result = self._post_process(elements, outputs) - self._metrics_collector.update( + self._metrics_collector.update_inference( elements, self._clock.get_current_time_in_microseconds() - batch_start_time) return result - def finish_bundle(self): - self._metrics_collector.update_metrics_with_cache() - @abc.abstractmethod def run_inference( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] + self, + inference_spec: model_spec_pb2.InferenceSpecType, + elements: List[ExampleType] ) -> Union[Mapping[Text, np.ndarray], Sequence[Mapping[Text, Any]]]: raise NotImplementedError @abc.abstractmethod - def _post_process(self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]], + def _post_process(self, elements: List[ExampleType], outputs: Any) -> Iterable[Any]: raise NotImplementedError @@ -322,8 +424,7 @@ def _retry_on_unavailable_and_resource_error_filter(exception: Exception): exception.resp.status in (503, 429)) -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) # Using output typehints triggers NotImplementedError('BEAM-2717)' on # streaming mode on Dataflow runner. # TODO(b/151468119): Consider to re-batch with online serving request size @@ -349,16 +450,34 @@ class _RemotePredictDoFn(_BaseDoFn): without having access to cloud-hosted model's signatures. 
""" - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType, - pipeline_options: PipelineOptions): - super(_RemotePredictDoFn, self).__init__(inference_spec_type) + def __init__( + self, + pipeline_options: PipelineOptions, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_RemotePredictDoFn, self).__init__( + OperationType.PREDICTION, _METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION) + self._pipeline_options = pipeline_options + self._fixed_inference_spec_type = fixed_inference_spec_type + + self._ai_platform_prediction_model_spec = None + self._api_client = None + self._full_model_name = None + + def setup(self): + super(_RemotePredictDoFn, self).setup() + if self._fixed_inference_spec_type: + self._setup_model(self._fixed_inference_spec_type) + + def _setup_model( + self, inference_spec_type: model_spec_pb2.InferenceSpecType + ): self._ai_platform_prediction_model_spec = ( inference_spec_type.ai_platform_prediction_model_spec) - self._api_client = None project_id = ( inference_spec_type.ai_platform_prediction_model_spec.project_id or - pipeline_options.view_as(GoogleCloudOptions).project) + self._pipeline_options.view_as(GoogleCloudOptions).project) if not project_id: raise ValueError('Either a non-empty project id or project flag in ' ' beam pipeline options needs be provided.') @@ -377,8 +496,6 @@ def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType, self._full_model_name = name_spec.format(project_id, model_name, version_name) - def setup(self): - super(_RemotePredictDoFn, self).setup() # TODO(b/151468119): Add tfx_bsl_version and tfx_bsl_py_version to # user agent once custom header is supported in googleapiclient. self._api_client = discovery.build('ml', 'v1') @@ -401,7 +518,7 @@ def _make_request(self, body: Mapping[Text, List[Any]]) -> http.HttpRequest: name=self._full_model_name, body=body) def _prepare_instances_dict( - self, elements: List[tf.train.Example] + self, elements: List[ExampleType] ) -> Generator[Mapping[Text, Any], None, None]: """Prepare instances by converting features to dictionary.""" for example in elements: @@ -423,14 +540,14 @@ def _prepare_instances_dict( yield instance def _prepare_instances_serialized( - self, elements: List[tf.train.Example] + self, elements: List[ExampleType] ) -> Generator[Mapping[Text, Text], None, None]: """Prepare instances by base64 encoding serialized examples.""" for example in elements: yield {'b64': base64.b64encode(example.SerializeToString()).decode()} def _prepare_instances( - self, elements: List[tf.train.Example] + self, elements: List[ExampleType] ) -> Generator[Mapping[Text, Any], None, None]: if self._ai_platform_prediction_model_spec.use_serialization_config: return self._prepare_instances_serialized(elements) @@ -465,16 +582,19 @@ def _parse_feature_content(values: Sequence[Any], attr_name: Text, return list(values) def run_inference( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] + self, + inference_spec: model_spec_pb2.InferenceSpecType, + elements: List[ExampleType] ) -> Sequence[Mapping[Text, Any]]: + if not self._fixed_inference_spec_type: + self._setup_model(inference_spec) body = {'instances': list(self._prepare_instances(elements))} request = self._make_request(body) response = self._execute_request(request) return response['predictions'] def _post_process( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]], - outputs: Sequence[Mapping[Text, Any]] + self, elements: List[ExampleType], outputs: 
Sequence[Mapping[Text, Any]] ) -> Iterable[prediction_log_pb2.PredictLog]: result = [] for output in outputs: @@ -507,30 +627,51 @@ class _BaseBatchSavedModelDoFn(_BaseDoFn): def __init__( self, - inference_spec_type: model_spec_pb2.InferenceSpecType, shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None, + operation_type: Text = '' ): - super(_BaseBatchSavedModelDoFn, self).__init__(inference_spec_type) - self._inference_spec_type = inference_spec_type + super(_BaseBatchSavedModelDoFn, self).__init__( + operation_type, _METRICS_DESCRIPTOR_IN_PROCESS) self._shared_model_handle = shared_model_handle - self._model_path = inference_spec_type.saved_model_spec.model_path + self._fixed_inference_spec_type = fixed_inference_spec_type + + self._model_path = None self._tags = None - self._signatures = _get_signatures( - inference_spec_type.saved_model_spec.model_path, - inference_spec_type.saved_model_spec.signature_name, - _get_tags(inference_spec_type)) + self._signatures = None self._session = None self._io_tensor_spec = None def setup(self): """Load the model. - Note that worker may crash if exception is thrown in setup due to b/139207285. """ - super(_BaseBatchSavedModelDoFn, self).setup() - self._tags = _get_tags(self._inference_spec_type) + if self._fixed_inference_spec_type: + self._setup_model(self._fixed_inference_spec_type) + + def finish_bundle(self): + # If we are using a fixed model, _setup_model will be called in DoFn.setup + # and model loading metrics will be cached. To commit these metrics, we + # need to call _metrics_collector.commit_cached_metrics() once during the + # DoFn lifetime. DoFn.teardown() is not guaranteed to be called, so the + # next best option is to call this in finish_bundle(). + if self._fixed_inference_spec_type: + self._metrics_collector.commit_cached_metrics() + + def _setup_model( + self, inference_spec_type: model_spec_pb2.InferenceSpecType + ): + self._model_path = inference_spec_type.saved_model_spec.model_path + self._signatures = _get_signatures( + inference_spec_type.saved_model_spec.model_path, + inference_spec_type.saved_model_spec.signature_name, + _get_tags(inference_spec_type)) + + self._validate_model() + + self._tags = _get_tags(inference_spec_type) self._io_tensor_spec = self._pre_process() if self._has_tpu_tag(): @@ -538,6 +679,14 @@ def setup(self): raise ValueError('TPU inference is not supported yet.') self._session = self._load_model() + def _validate_model(self): + """Optional subclass model validation hook. + + Raises: + ValueError: if model is invalid. + """ + pass + def _load_model(self): """Load a saved model into memory. @@ -554,10 +703,14 @@ def load(): tf.compat.v1.saved_model.loader.load(result, self._tags, self._model_path) end_time = self._clock.get_current_time_in_microseconds() memory_after = _get_current_process_memory_in_bytes() - self._metrics_collector.load_model_latency_milli_secs_cache = ( - (end_time - start_time) / _MILLISECOND_TO_MICROSECOND) - self._metrics_collector.model_byte_size_cache = ( - memory_after - memory_before) + + # Compute model loading metrics. 
+ load_model_latency_milli_secs = ( + (end_time - start_time) / _MILLISECOND_TO_MICROSECOND) + model_byte_size = (memory_after - memory_before) + self._metrics_collector.update_model_load( + load_model_latency_milli_secs, model_byte_size) + return result if not self._model_path: @@ -600,14 +753,19 @@ def _has_tpu_tag(self) -> bool: tf.saved_model.TPU in self._tags) def run_inference( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] + self, + inference_spec_type: model_spec_pb2.InferenceSpecType, + elements: List[ExampleType] ) -> Mapping[Text, np.ndarray]: + if not self._fixed_inference_spec_type: + self._setup_model(inference_spec_type) + self._metrics_collector.commit_cached_metrics() self._check_elements(elements) outputs = self._run_tf_operations(elements) return outputs def _run_tf_operations( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] + self, elements: List[ExampleType] ) -> Mapping[Text, np.ndarray]: input_values = [] for element in elements: @@ -619,93 +777,111 @@ def _run_tf_operations( raise RuntimeError('Output length does not match fetches') return result - def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + def _check_elements(self, elements: List[ExampleType]) -> None: """Unimplemented.""" raise NotImplementedError -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) @beam.typehints.with_output_types(Tuple[tf.train.Example, classification_pb2.Classifications]) class _BatchClassifyDoFn(_BaseBatchSavedModelDoFn): """A DoFn that run inference on classification model.""" - def setup(self): + def __init__( + self, + shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_BatchClassifyDoFn, self).__init__( + shared_model_handle, fixed_inference_spec_type, + OperationType.CLASSIFICATION) + + def _validate_model(self): signature_def = self._signatures[0].signature_def if signature_def.method_name != tf.saved_model.CLASSIFY_METHOD_NAME: raise ValueError( 'BulkInferrerClassifyDoFn requires signature method ' 'name %s, got: %s' % tf.saved_model.CLASSIFY_METHOD_NAME, signature_def.method_name) - super(_BatchClassifyDoFn, self).setup() def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + self, elements: List[ExampleType]) -> None: if not all(isinstance(element, tf.train.Example) for element in elements): raise ValueError('Classify only supports tf.train.Example') def _post_process( - self, elements: Sequence[tf.train.Example], outputs: Mapping[Text, - np.ndarray] + self, elements: Sequence[ExampleType], outputs: Mapping[Text, np.ndarray] ) -> Iterable[Tuple[tf.train.Example, classification_pb2.Classifications]]: classifications = _post_process_classify( self._io_tensor_spec.output_alias_tensor_names, elements, outputs) return zip(elements, classifications) -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) @beam.typehints.with_output_types(Tuple[tf.train.Example, regression_pb2.Regression]) class _BatchRegressDoFn(_BaseBatchSavedModelDoFn): """A DoFn that run inference on regression model.""" - def setup(self): - super(_BatchRegressDoFn, self).setup() + def __init__( + self, + shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None 
+ ): + super(_BatchRegressDoFn, self).__init__( + shared_model_handle, fixed_inference_spec_type, + OperationType.REGRESSION) + + def _validate_model(self): + signature_def = self._signatures[0].signature_def + if signature_def.method_name != tf.saved_model.REGRESS_METHOD_NAME: + raise ValueError( + 'BulkInferrerRegressDoFn requires signature method ' + 'name %s, got: %s' % tf.saved_model.REGRESS_METHOD_NAME, + signature_def.method_name) def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + self, elements: List[ExampleType]) -> None: if not all(isinstance(element, tf.train.Example) for element in elements): raise ValueError('Regress only supports tf.train.Example') def _post_process( - self, elements: Sequence[tf.train.Example], outputs: Mapping[Text, - np.ndarray] + self, elements: Sequence[ExampleType], outputs: Mapping[Text, np.ndarray] ) -> Iterable[Tuple[tf.train.Example, regression_pb2.Regression]]: regressions = _post_process_regress(elements, outputs) return zip(elements, regressions) -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) @beam.typehints.with_output_types(prediction_log_pb2.PredictLog) class _BatchPredictDoFn(_BaseBatchSavedModelDoFn): """A DoFn that runs inference on predict model.""" - def setup(self): + def __init__( + self, + shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_BatchPredictDoFn, self).__init__( + shared_model_handle, fixed_inference_spec_type, + OperationType.PREDICTION) + + def _validate_model(self): signature_def = self._signatures[0].signature_def if signature_def.method_name != tf.saved_model.PREDICT_METHOD_NAME: raise ValueError( 'BulkInferrerPredictDoFn requires signature method ' 'name %s, got: %s' % tf.saved_model.PREDICT_METHOD_NAME, signature_def.method_name) - super(_BatchPredictDoFn, self).setup() def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + self, elements: List[ExampleType]) -> None: pass def _post_process( - self, elements: Union[Sequence[tf.train.Example], - Sequence[tf.train.SequenceExample]], - outputs: Mapping[Text, np.ndarray] + self, elements: Sequence[ExampleType], outputs: Mapping[Text, np.ndarray] ) -> Iterable[prediction_log_pb2.PredictLog]: input_tensor_alias = self._io_tensor_spec.input_tensor_alias signature_name = self._signatures[0].name @@ -741,22 +917,28 @@ def _post_process( return result -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) @beam.typehints.with_output_types(Tuple[tf.train.Example, inference_pb2.MultiInferenceResponse]) class _BatchMultiInferenceDoFn(_BaseBatchSavedModelDoFn): """A DoFn that runs inference on multi-head model.""" + def __init__( + self, + shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_BatchMultiInferenceDoFn, self).__init__( + shared_model_handle, fixed_inference_spec_type, + OperationType.MULTIHEAD) + def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + self, elements: List[ExampleType]) -> None: if not all(isinstance(element, tf.train.Example) for element in elements): raise ValueError('Multi inference only supports tf.train.Example') def _post_process( - self, elements: 
Sequence[tf.train.Example], outputs: Mapping[Text, - np.ndarray] + self, elements: Sequence[ExampleType], outputs: Mapping[Text, np.ndarray] ) -> Iterable[Tuple[tf.train.Example, inference_pb2.MultiInferenceResponse]]: classifications = None regressions = None @@ -862,6 +1044,99 @@ def process( yield result +def _BuildInferenceOperation( + name: str, + in_process_dofn: _BaseBatchSavedModelDoFn, + remote_dofn: Optional[_BaseDoFn], + build_prediction_log_dofn: beam.DoFn +): + """Construct an operation specific inference sub-pipeline. + + Args: + name: Name of the operation (e.g. "Classify"). + in_process_dofn: A _BaseBatchSavedModelDoFn class to use for in-process + inference. + remote_dofn: An optional DoFn that is used for remote inference. If not + provided, attempts at remote inference will throw a NotImplementedError. + build_prediction_log_dofn: A DoFn that can build prediction logs from the + output of `in_process_dofn` and `remote_dofn`. + + Returns: + A PTransform of the type (_QueryBatchType -> PredictionLog). + + Raises: + NotImplementedError: if remote inference is attempted and not supported. + """ + @beam.ptransform_fn + @beam.typehints.with_input_types(_QueryBatchType) + @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) + def _Op( + pcoll: beam.pvalue.PCollection, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): # pylint: disable=invalid-name + raw_result = None + + if fixed_inference_spec_type is None: + tagged = pcoll | ('TagInferenceType%s' % name) >> _TagUsingInProcessInference() + + in_process_result = ( + tagged['in_process'] + | ('InProcess%s' % name) >> beam.ParDo( + in_process_dofn(shared.Shared()))) + + if remote_dofn: + remote_result = ( + tagged['remote'] + | ('Remote%s' % name) >> beam.ParDo( + remote_dofn(pcoll.pipeline.options))) + + raw_result = ( + [in_process_result, remote_result] + | 'FlattenResult' >> beam.Flatten()) + else: + tagged['remote'] | 'NotImplemented' >> _NotImplementedTransform( + 'Remote inference is not supported for operation type: %s' % name) + raw_result = in_process_result + else: + if _using_in_process_inference(fixed_inference_spec_type): + raw_result = ( + pcoll + | ('InProcess%s' % name) >> beam.ParDo(in_process_dofn( + shared.Shared(), + fixed_inference_spec_type=fixed_inference_spec_type))) + else: + if remote_dofn: + raw_result = ( + pcoll + | ('Remote%s' % name) >> beam.ParDo(remote_dofn( + pcoll.pipeline.options, + fixed_inference_spec_type=fixed_inference_spec_type))) + else: + raise NotImplementedError('Remote inference is not supported for' + 'operation type: %s' % name) + + return ( + raw_result + | ('BuildPredictionLogFor%s' % name) >> beam.ParDo( + build_prediction_log_dofn())) + + return _Op + + +_Classify = _BuildInferenceOperation( + 'Classify', _BatchClassifyDoFn, None, + _BuildPredictionLogForClassificationsDoFn) +_Regress = _BuildInferenceOperation( + 'Regress', _BatchRegressDoFn, None, + _BuildPredictionLogForRegressionsDoFn) +_Predict = _BuildInferenceOperation( + 'Predict', _BatchPredictDoFn, _RemotePredictDoFn, + _BuildPredictionLogForPredictionsDoFn) +_MultiInference = _BuildInferenceOperation( + 'MultiInference', _BatchMultiInferenceDoFn, None, + _BuildMultiInferenceLogDoFn) + + def _post_process_classify( output_alias_tensor_names: Mapping[Text, Text], elements: Sequence[tf.train.Example], outputs: Mapping[Text, np.ndarray] @@ -1070,6 +1345,27 @@ def _using_in_process_inference( return inference_spec_type.WhichOneof('type') == 'saved_model_spec' +@beam.ptransform_fn 
+@beam.typehints.with_input_types(_QueryBatchType) +@beam.typehints.with_output_types(_QueryBatchType) +def _TagUsingInProcessInference( + queries: beam.pvalue.PCollection) -> beam.pvalue.DoOutputsTuple: + """Tags each query batch with 'in_process' or 'remote'.""" + return queries | 'TagBatches' >> beam.Map( + lambda query: beam.pvalue.TaggedOutput( + 'in_process' if _using_in_process_inference(query[0]) else 'remote', query) + ).with_outputs('in_process', 'remote') + + +@beam.ptransform_fn +def _NotImplementedTransform( + pcoll: beam.pvalue.PCollection, message: Text = ''): + """Raises NotImplementedError for each value in the input PCollection.""" + def _raise(x): + raise NotImplementedError(message) + pcoll | beam.Map(_raise) + + def _get_signatures(model_path: Text, signatures: Sequence[Text], tags: Sequence[Text]) -> Sequence[_Signature]: """Returns a sequence of {model_signature_name: signature}.""" diff --git a/tfx_bsl/beam/run_inference_test.py b/tfx_bsl/beam/run_inference_test.py index 5fb9adad..53236f55 100644 --- a/tfx_bsl/beam/run_inference_test.py +++ b/tfx_bsl/beam/run_inference_test.py @@ -28,6 +28,7 @@ import apache_beam as beam from apache_beam.metrics.metric import MetricsFilter +from apache_beam.options import pipeline_options from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from googleapiclient import discovery @@ -71,6 +72,16 @@ def _prepare_predict_examples(self, example_path): for example in self._predict_examples: output_file.write(example.SerializeToString()) + def _get_results(self, prediction_log_path): + results = [] + for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'): + record_iterator = tf.compat.v1.io.tf_record_iterator(path=f) + for record_string in record_iterator: + prediction_log = prediction_log_pb2.PredictionLog() + prediction_log.MergeFromString(record_string) + results.append(prediction_log) + return results + class RunOfflineInferenceTest(RunInferenceFixture): @@ -220,16 +231,6 @@ def _run_inference_with_beam(self, example_path, inference_spec_type, prediction_log_path, coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog))) - def _get_results(self, prediction_log_path): - results = [] - for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'): - record_iterator = tf.compat.v1.io.tf_record_iterator(path=f) - for record_string in record_iterator: - prediction_log = prediction_log_pb2.PredictionLog() - prediction_log.MergeFromString(record_string) - results.append(prediction_log) - return results - def testModelPathInvalid(self): example_path = self._get_output_data_dir('examples') self._prepare_predict_examples(example_path) @@ -609,7 +610,9 @@ def test_request_body_with_binary_data(self): project_id='test_project', model_name='test_model', version_name='test_version')) - remote_predict = run_inference._RemotePredictDoFn(inference_spec_type, None) + remote_predict = run_inference._RemotePredictDoFn( + None, fixed_inference_spec_type=inference_spec_type) + remote_predict._setup_model(remote_predict._fixed_inference_spec_type) result = list(remote_predict._prepare_instances([example])) self.assertEqual(result, [ { @@ -638,12 +641,121 @@ def test_request_serialized_example(self): model_name='test_model', version_name='test_version', use_serialization_config=True)) - remote_predict = run_inference._RemotePredictDoFn(inference_spec_type, None) + remote_predict = run_inference._RemotePredictDoFn( + None, fixed_inference_spec_type=inference_spec_type) + 
remote_predict._setup_model(remote_predict._fixed_inference_spec_type) result = list(remote_predict._prepare_instances([example])) self.assertEqual(result, [{ 'b64': base64.b64encode(example.SerializeToString()).decode() }]) +class RunInferenceCoreTest(RunInferenceFixture): + + def _build_keras_model(self, add): + """Builds a dummy keras model with one input and output.""" + inp = tf.keras.layers.Input((1,), name='input') + out = tf.keras.layers.Lambda(lambda x: x + add)(inp) + m = tf.keras.models.Model(inp, out) + return m + + def _new_model(self, model_path, add): + """Exports a keras model in the SavedModel format.""" + class WrapKerasModel(tf.keras.Model): + """Wrapper class to apply a signature to a keras model.""" + def __init__(self, model): + super().__init__() + self.model = model + + @tf.function(input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='inputs') + ]) + def call(self, serialized_example): + features = { + 'input': tf.compat.v1.io.FixedLenFeature( + [1], + dtype=tf.float32, + default_value=0 + ) + } + input_tensor_dict = tf.io.parse_example(serialized_example, features) + return {'output': self.model(input_tensor_dict)} + + model = self._build_keras_model(add) + wrapped_model = WrapKerasModel(model) + tf.compat.v1.keras.experimental.export_saved_model( + wrapped_model, model_path, serving_only=True + ) + return self._get_saved_model_spec(model_path) + + def _decode_value(self, pl): + """Returns output value from prediction log.""" + out_tensor = pl.predict_log.response.outputs['output'] + arr = tf.make_ndarray(out_tensor) + x = arr[0][0] + return x + + def _make_example(self, x): + """Builds a TFExample object with a single value.""" + feature = {} + feature['input'] = tf.train.Feature( + float_list=tf.train.FloatList(value=[x])) + return tf.train.Example(features=tf.train.Features(feature=feature)) + + def _get_saved_model_spec(self, model_path): + """Returns an InferenceSpecType object for a saved model path.""" + return model_spec_pb2.InferenceSpecType( + saved_model_spec=model_spec_pb2.SavedModelSpec( + model_path=model_path)) + + # TODO(hgarrereyn): Switch _BatchElements to use GroupIntoBatches once + # BEAM-2717 is fixed so examples are grouped by inference spec key. 
The + # following test indicates desired but currently unsupported behavior: + # + # def test_batch_queries_multiple_models(self): + # spec1 = self._get_saved_model_spec('/example/model1') + # spec2 = self._get_saved_model_spec('/example/model2') + # + # queries = [] + # for i in range(100): + # queries.append((spec1 if i % 2 == 0 else spec2, self._make_example(i))) + # + # correct = {example.SerializeToString(): spec for spec, example in queries} + # + # def _check_batch(batch): + # """Assert examples are grouped with the correct inference spec.""" + # spec, examples = batch + # assert all([correct[x.SerializeToString()] == spec for x in examples]) + # + # with beam.Pipeline() as p: + # queries = p | 'Queries' >> beam.Create(queries) + # batches = queries | '_BatchQueries' >> run_inference._BatchQueries() + # + # _ = batches | 'Check' >> beam.Map(_check_batch) + + def test_inference_on_queries(self): + spec = self._new_model(self._get_output_data_dir('model1'), 100) + predictions_path = self._get_output_data_dir('predictions') + queries = [(spec, self._make_example(i)) for i in range(10)] + + options = pipeline_options.PipelineOptions(streaming=False) + with beam.Pipeline(options=options) as p: + _ = ( + p + | 'Queries' >> beam.Create(queries) \ + | '_RunInferenceCore' >> run_inference._RunInferenceCore() \ + | 'WritePredictions' >> beam.io.WriteToTFRecord( + predictions_path, + coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)) + ) + + results = self._get_results(predictions_path) + values = [int(self._decode_value(x)) for x in results] + self.assertEqual( + values, + [100,101,102,103,104,105,106,107,108,109] + ) + + if __name__ == '__main__': tf.test.main()
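For reference, a minimal usage sketch of the API after this change (the model path, feature name, and pipeline wiring below are hypothetical and results are discarded): RunInferenceImpl remains the public entry point, while _RunInferenceCore is internal and consumes (InferenceSpecType, Example) query tuples, either with a fixed spec supplied at pipeline construction time or with per-query specs resolved at runtime.

# Usage sketch only; not part of the diff above.
import apache_beam as beam
import tensorflow as tf
from tfx_bsl.beam import run_inference
from tfx_bsl.public.proto import model_spec_pb2


def _make_example(x):
  # Single-feature example; 'input' is a hypothetical feature name that the
  # saved model is assumed to accept.
  return tf.train.Example(features=tf.train.Features(feature={
      'input': tf.train.Feature(float_list=tf.train.FloatList(value=[x]))}))


spec = model_spec_pb2.InferenceSpecType(
    saved_model_spec=model_spec_pb2.SavedModelSpec(
        model_path='/path/to/saved_model'))  # hypothetical path

examples = [_make_example(float(i)) for i in range(10)]

# Public path: the model is fixed and known at pipeline construction time.
with beam.Pipeline() as p:
  _ = (p
       | 'CreateExamples' >> beam.Create(examples)
       | 'RunInference' >> run_inference.RunInferenceImpl(spec))

# Internal path: (InferenceSpecType, Example) queries. With no
# fixed_inference_spec_type, the operation type is resolved per batch at
# runtime via _SplitByOperation.
with beam.Pipeline() as p:
  _ = (p
       | 'CreateQueries' >> beam.Create([(spec, e) for e in examples])
       | '_RunInferenceCore' >> run_inference._RunInferenceCore())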