Add drop_example flag to the RunInference and Model Handler #23266

Merged: 12 commits, Sep 18, 2022
52 changes: 42 additions & 10 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -82,6 +82,24 @@ def _to_microseconds(time_ns: int) -> int:
return int(time_ns / _NANOSECOND_TO_MICROSECOND)


def _convert_to_result(
batch: Iterable,
predictions: Union[Iterable, Dict[Any, Iterable]],
drop_example: bool = False) -> Iterable[PredictionResult]:
if isinstance(predictions, dict):
# Go from one dictionary of type: {key_type1: Iterable<val_type1>,
# key_type2: Iterable<val_type2>, ...} where each Iterable is of
# length batch_size, to a list of dictionaries:
# [{key_type1: value_type1, key_type2: value_type2}]
predictions = [
dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
]
if drop_example:
return [PredictionResult(None, y) for x, y in zip(batch, predictions)]

return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load and apply an ML model."""
def load_model(self) -> ModelT:
@@ -92,15 +110,17 @@ def run_inference(
self,
batch: Sequence[ExampleT],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
inference_args: Optional[Dict[str, Any]] = None,
drop_example: Optional[bool] = False) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples.

Args:
batch: A sequence of examples or features.
model: The model used to make inferences.
inference_args: Extra arguments for models whose inference call requires
extra parameters.

drop_example: Boolean flag indicating whether to drop the example from
the PredictionResult.
Returns:
An Iterable of Predictions.
"""
@@ -170,7 +190,8 @@ def run_inference(
self,
batch: Sequence[Tuple[KeyT, ExampleT]],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None
inference_args: Optional[Dict[str, Any]] = None,
drop_example: Optional[bool] = False
) -> Iterable[Tuple[KeyT, PredictionT]]:
keys, unkeyed_batch = zip(*batch)
return zip(
@@ -225,7 +246,8 @@ def run_inference(
self,
batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None
inference_args: Optional[Dict[str, Any]] = None,
drop_example: Optional[bool] = False
) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
# Really the input should be
# Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
@@ -273,7 +295,9 @@ def __init__(
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock=time,
inference_args: Optional[Dict[str, Any]] = None,
metrics_namespace: Optional[str] = None):
metrics_namespace: Optional[str] = None,
drop_example: Optional[bool] = False,
):
"""A transform that takes a PCollection of examples (or features) to be used
on an ML model. It will then output inferences (or predictions) for those
examples in a PCollection of PredictionResults, containing the input
@@ -291,11 +315,13 @@ def __init__(
inference_args: Extra arguments for models whose inference call requires
extra parameters.
metrics_namespace: Namespace of the transform to collect metrics.
"""
drop_example: Boolean flag indicating whether to drop the example from
the PredictionResult.
"""
self._model_handler = model_handler
self._inference_args = inference_args
self._clock = clock
self._metrics_namespace = metrics_namespace
self._drop_example = drop_example

# TODO(BEAM-14046): Add and link to help documentation.
@classmethod
@@ -327,7 +353,10 @@ def expand(
| 'BeamML_RunInference' >> (
beam.ParDo(
_RunInferenceDoFn(
self._model_handler, self._clock, self._metrics_namespace),
model_handler=self._model_handler,
clock=self._clock,
metrics_namespace=self._metrics_namespace,
drop_example=self._drop_example),
self._inference_args).with_resource_hints(**resource_hints)))


@@ -385,19 +414,22 @@ def __init__(
self,
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock,
metrics_namespace):
metrics_namespace: Optional[str],
drop_example: Optional[bool] = False):
"""A DoFn implementation generic to frameworks.

Args:
model_handler: An implementation of ModelHandler.
clock: A clock implementing time_ns. *Used for unit testing.*
metrics_namespace: Namespace of the transform to collect metrics.
"""
drop_example: Boolean flag indicating whether to drop the example from
the PredictionResult.
"""
self._model_handler = model_handler
self._shared_model_handle = shared.Shared()
self._clock = clock
self._model = None
self._metrics_namespace = metrics_namespace
self._drop_example = drop_example

def _load_model(self):
def load():
@@ -427,7 +459,7 @@ def setup(self):
def process(self, batch, inference_args):
start_time = _to_microseconds(self._clock.time_ns())
result_generator = self._model_handler.run_inference(
batch, self._model, inference_args)
batch, self._model, inference_args, self._drop_example)
predictions = list(result_generator)

end_time = _to_microseconds(self._clock.time_ns())
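
For reference, a minimal sketch (not part of this diff) of what the new _convert_to_result helper does with dict-shaped predictions and the drop_example flag; it assumes Beam's PredictionResult exposes `example` and `inference` fields:

# Illustrative sketch only; assumes apache_beam is installed.
from apache_beam.ml.inference import base

batch = ['first input', 'second input']
# One dict of per-key iterables, each of length batch_size...
predictions = {'label': [0, 1], 'score': [0.9, 0.1]}

# ...becomes one dict per example, paired with its input:
results = base._convert_to_result(batch, predictions)
# [PredictionResult(example='first input', inference={'label': 0, 'score': 0.9}),
#  PredictionResult(example='second input', inference={'label': 1, 'score': 0.1})]

# With drop_example=True, the input is replaced by None in each result:
dropped = base._convert_to_result(batch, predictions, drop_example=True)
assert all(r.example is None for r in dropped)
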
49 changes: 45 additions & 4 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -48,13 +48,38 @@ def run_inference(
self,
batch: Sequence[int],
model: FakeModel,
inference_args=None) -> Iterable[int]:
inference_args=None,
drop_example=False) -> Iterable[int]:
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
for example in batch:
yield model.predict(example)


class FakeModelHandlerReturnsPredictionResult(
base.ModelHandler[int, base.PredictionResult, FakeModel]):
def __init__(self, clock=None):
self._fake_clock = clock

def load_model(self):
if self._fake_clock:
self._fake_clock.current_time_ns += 500_000_000 # 500ms
return FakeModel()

def run_inference(
self,
batch: Sequence[int],
model: FakeModel,
inference_args=None,
drop_example=False) -> Iterable[base.PredictionResult]:
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds

predictions = [model.predict(example) for example in batch]
return base._convert_to_result(
batch=batch, predictions=predictions, drop_example=drop_example)


class FakeClock:
def __init__(self):
# Start at 10 seconds.
@@ -70,7 +95,8 @@ def process(self, prediction_result):


class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
def run_inference(self, batch, unused_model, inference_args=None):
def run_inference(
self, batch, model, inference_args=None, drop_example=False):
if len(batch) < 100:
raise ValueError('Unexpectedly small batch')
return batch
@@ -80,14 +106,16 @@ def batch_elements_kwargs(self):


class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
def run_inference(self, batch, unused_model, inference_args=None):
def run_inference(
self, batch, model, inference_args=None, drop_example=False):
raise ValueError(
'run_inference should not be called because error should already be '
'thrown from the validate_inference_args check.')


class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
def run_inference(self, batch, unused_model, inference_args=None):
def run_inference(
self, batch, model, inference_args=None, drop_example=False):
if not inference_args:
raise ValueError('inference_args should exist')
return batch
@@ -251,6 +279,19 @@ def test_run_inference_keyed_examples_with_unkeyed_model_handler(self):
| 'RunKeyed' >> base.RunInference(model_handler))
pipeline.run()

def test_drop_example_prediction_result(self):
def assert_drop_example(prediction_result):
assert prediction_result.example is None

pipeline = TestPipeline()
examples = [1, 3, 5]
model_handler = FakeModelHandlerReturnsPredictionResult()
_ = (
pipeline | 'keyed' >> beam.Create(examples)
| 'RunKeyed' >> base.RunInference(model_handler, drop_example=True)
| beam.Map(assert_drop_example))
pipeline.run()


if __name__ == '__main__':
unittest.main()
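
A hedged end-to-end sketch of how the flag is intended to be passed at the transform level; EchoModel and EchoModelHandler below are hypothetical stand-ins for a real handler and are not part of the Beam SDK or of this PR:

# Hypothetical usage sketch, assuming apache_beam is installed.
import apache_beam as beam
from apache_beam.ml.inference import base


class EchoModel:
  def predict(self, example: int) -> int:
    return example * 10


class EchoModelHandler(base.ModelHandler[int, base.PredictionResult, EchoModel]):
  def load_model(self) -> EchoModel:
    return EchoModel()

  def run_inference(self, batch, model, inference_args=None, drop_example=False):
    predictions = [model.predict(example) for example in batch]
    # Reuse the shared helper so drop_example behaves consistently.
    return base._convert_to_result(
        batch=batch, predictions=predictions, drop_example=drop_example)


with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([1, 2, 3])
      # drop_example=True yields PredictionResult(None, prediction) elements.
      | base.RunInference(EchoModelHandler(), drop_example=True)
      | beam.Map(print))
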
35 changes: 11 additions & 24 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -25,12 +25,12 @@
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Union

import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import _convert_to_result
from apache_beam.utils.annotations import experimental

__all__ = [
@@ -83,23 +83,6 @@ def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
return examples


def _convert_to_result(
batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
) -> Iterable[PredictionResult]:
if isinstance(predictions, dict):
# Go from one dictionary of type: {key_type1: Iterable<val_type1>,
# key_type2: Iterable<val_type2>, ...} where each Iterable is of
# length batch_size, to a list of dictionaries:
# [{key_type1: value_type1, key_type2: value_type2}]
predictions_per_tensor = [
dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
]
return [
PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
]
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
PredictionResult,
torch.nn.Module]):
@@ -152,8 +135,8 @@ def run_inference(
self,
batch: Sequence[torch.Tensor],
model: torch.nn.Module,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
inference_args: Optional[Dict[str, Any]] = None,
drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Tensors and returns an Iterable of
Tensor Predictions.
@@ -170,7 +153,8 @@ def run_inference(
inference_args: Non-batchable arguments required as inputs to the model's
forward() function. Unlike Tensors in `batch`, these parameters will
not be dynamically batched

drop_example: Boolean flag indicating whether to drop the example from
the PredictionResult.
Returns:
An Iterable of type PredictionResult.
"""
@@ -182,7 +166,7 @@ def run_inference(
batched_tensors = torch.stack(batch)
batched_tensors = _convert_to_device(batched_tensors, self._device)
predictions = model(batched_tensors, **inference_args)
return _convert_to_result(batch, predictions)
return _convert_to_result(batch, predictions, drop_example=drop_example)

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
@@ -259,7 +243,8 @@ def run_inference(
self,
batch: Sequence[Dict[str, torch.Tensor]],
model: torch.nn.Module,
inference_args: Optional[Dict[str, Any]] = None
inference_args: Optional[Dict[str, Any]] = None,
drop_example: Optional[bool] = False,
) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Keyed Tensors and returns an Iterable of
@@ -277,6 +262,8 @@ def run_inference(
inference_args: Non-batchable arguments required as inputs to the model's
forward() function. Unlike Tensors in `batch`, these parameters will
not be dynamically batched
drop_example: Boolean flag indicating whether to drop the example from
the PredictionResult.

Returns:
An Iterable of type PredictionResult.
@@ -300,7 +287,7 @@ def run_inference(
key_to_batched_tensors[key] = batched_tensors
predictions = model(**key_to_batched_tensors, **inference_args)

return _convert_to_result(batch, predictions)
return _convert_to_result(batch, predictions, drop_example=drop_example)

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
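
Finally, a small sketch of the multi-output (dict of tensors) path that _convert_to_result now covers for the PyTorch handlers; the two-headed model and tensor shapes below are assumptions chosen for illustration only:

# Illustrative sketch; assumes torch and apache_beam are installed.
import torch
from apache_beam.ml.inference.base import _convert_to_result


class TwoHeadedModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(4, 2)

  def forward(self, batched_tensors):
    hidden = self.linear(batched_tensors)
    # A dict of tensors, each with a leading batch dimension, mirrors the
    # {key: Iterable} shape that _convert_to_result unpacks per example.
    return {'logits': hidden, 'argmax': hidden.argmax(dim=1)}


batch = [torch.rand(4) for _ in range(3)]
model = TwoHeadedModel()
predictions = model(torch.stack(batch))

# Each PredictionResult carries {'logits': ..., 'argmax': ...} for one example;
# with drop_example=True the original input tensor is replaced by None.
for result in _convert_to_result(batch, predictions, drop_example=True):
  assert result.example is None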