Commit 72237d6

Revert "Add drop_example flag to the RunInference and Model Handler (#23266)" (#23392)
This reverts commit f477b85.
1 parent b5ed548 · commit 72237d6

File tree: 5 files changed, +70 −118 lines

sdks/python/apache_beam/ml/inference/base.py

+10 −42
@@ -82,24 +82,6 @@ def _to_microseconds(time_ns: int) -> int:
   return int(time_ns / _NANOSECOND_TO_MICROSECOND)
 
 
-def _convert_to_result(
-    batch: Iterable,
-    predictions: Union[Iterable, Dict[Any, Iterable]],
-    drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
-  if isinstance(predictions, dict):
-    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
-    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
-    # length batch_size, to a list of dictionaries:
-    # [{key_type1: value_type1, key_type2: value_type2}]
-    predictions = [
-        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
-    ]
-  if drop_example:
-    return [PredictionResult(None, y) for x, y in zip(batch, predictions)]
-
-  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
-
-
 class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
   """Has the ability to load and apply an ML model."""
   def load_model(self) -> ModelT:
@@ -110,17 +92,15 @@ def run_inference(
       self,
       batch: Sequence[ExampleT],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None,
-      drop_example: Optional[bool] = False) -> Iterable[PredictionT]:
+      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples.
 
     Args:
       batch: A sequence of examples or features.
       model: The model used to make inferences.
       inference_args: Extra arguments for models whose inference call requires
         extra parameters.
-      drop_example: Boolean flag indicating whether to
-        drop the example from PredictionResult
+
     Returns:
       An Iterable of Predictions.
     """
@@ -190,8 +170,7 @@ def run_inference(
       self,
       batch: Sequence[Tuple[KeyT, ExampleT]],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None,
-      drop_example: Optional[bool] = False
+      inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[Tuple[KeyT, PredictionT]]:
     keys, unkeyed_batch = zip(*batch)
     return zip(
@@ -246,8 +225,7 @@ def run_inference(
       self,
       batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None,
-      drop_example: Optional[bool] = False
+      inference_args: Optional[Dict[str, Any]] = None
   ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
     # Really the input should be
     # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
@@ -295,9 +273,7 @@ def __init__(
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock=time,
       inference_args: Optional[Dict[str, Any]] = None,
-      metrics_namespace: Optional[str] = None,
-      drop_example: Optional[bool] = False,
-  ):
+      metrics_namespace: Optional[str] = None):
     """A transform that takes a PCollection of examples (or features) to be used
     on an ML model. It will then output inferences (or predictions) for those
     examples in a PCollection of PredictionResults, containing the input
@@ -315,13 +291,11 @@ def __init__(
       inference_args: Extra arguments for models whose inference call requires
         extra parameters.
       metrics_namespace: Namespace of the transform to collect metrics.
-      drop_example: Boolean flag indicating whether to
-        drop the example from PredictionResult """
+    """
     self._model_handler = model_handler
     self._inference_args = inference_args
     self._clock = clock
     self._metrics_namespace = metrics_namespace
-    self._drop_example = drop_example
 
   # TODO(BEAM-14046): Add and link to help documentation.
   @classmethod
@@ -353,10 +327,7 @@ def expand(
         | 'BeamML_RunInference' >> (
             beam.ParDo(
                 _RunInferenceDoFn(
-                    model_handler=self._model_handler,
-                    clock=self._clock,
-                    metrics_namespace=self._metrics_namespace,
-                    drop_example=self._drop_example),
+                    self._model_handler, self._clock, self._metrics_namespace),
                 self._inference_args).with_resource_hints(**resource_hints)))
 
 
@@ -414,22 +385,19 @@ def __init__(
       self,
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock,
-      metrics_namespace: Optional[str],
-      drop_example: Optional[bool] = False):
+      metrics_namespace):
     """A DoFn implementation generic to frameworks.
 
     Args:
       model_handler: An implementation of ModelHandler.
      clock: A clock implementing time_ns. *Used for unit testing.*
       metrics_namespace: Namespace of the transform to collect metrics.
-      drop_example: Boolean flag indicating whether to
-        drop the example from PredictionResult """
+    """
     self._model_handler = model_handler
     self._shared_model_handle = shared.Shared()
     self._clock = clock
     self._model = None
     self._metrics_namespace = metrics_namespace
-    self._drop_example = drop_example
 
   def _load_model(self):
     def load():
@@ -459,7 +427,7 @@ def setup(self):
   def process(self, batch, inference_args):
     start_time = _to_microseconds(self._clock.time_ns())
     result_generator = self._model_handler.run_inference(
-        batch, self._model, inference_args, self._drop_example)
+        batch, self._model, inference_args)
     predictions = list(result_generator)
 
     end_time = _to_microseconds(self._clock.time_ns())
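
With the flag reverted, a ModelHandler subclass implements run_inference with only batch, model, and inference_args. A minimal sketch against the reverted signature (SquareModel and SquareModelHandler are invented names for illustration, not part of this commit):

from typing import Any, Dict, Iterable, Optional, Sequence

from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult


class SquareModel:
  """Toy 'model' that squares integers."""
  def predict(self, example: int) -> int:
    return example * example


class SquareModelHandler(ModelHandler[int, PredictionResult, SquareModel]):
  def load_model(self) -> SquareModel:
    return SquareModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: SquareModel,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    # Pair each input example with its prediction; there is no longer a
    # drop_example switch at this level after the revert.
    return [PredictionResult(x, model.predict(x)) for x in batch]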

sdks/python/apache_beam/ml/inference/base_test.py

+4 −45
@@ -48,38 +48,13 @@ def run_inference(
       self,
       batch: Sequence[int],
       model: FakeModel,
-      inference_args=None,
-      drop_example=False) -> Iterable[int]:
+      inference_args=None) -> Iterable[int]:
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
     for example in batch:
       yield model.predict(example)
 
 
-class FakeModelHandlerReturnsPredictionResult(
-    base.ModelHandler[int, base.PredictionResult, FakeModel]):
-  def __init__(self, clock=None):
-    self._fake_clock = clock
-
-  def load_model(self):
-    if self._fake_clock:
-      self._fake_clock.current_time_ns += 500_000_000  # 500ms
-    return FakeModel()
-
-  def run_inference(
-      self,
-      batch: Sequence[int],
-      model: FakeModel,
-      inference_args=None,
-      drop_example=False) -> Iterable[base.PredictionResult]:
-    if self._fake_clock:
-      self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
-
-    predictions = [model.predict(example) for example in batch]
-    return base._convert_to_result(
-        batch=batch, predictions=predictions, drop_example=drop_example)
-
-
 class FakeClock:
   def __init__(self):
     # Start at 10 seconds.
@@ -95,8 +70,7 @@ def process(self, prediction_result):
 
 
 class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
-  def run_inference(
-      self, batch, model, inference_args=None, drop_example=False):
+  def run_inference(self, batch, unused_model, inference_args=None):
     if len(batch) < 100:
       raise ValueError('Unexpectedly small batch')
     return batch
@@ -106,16 +80,14 @@ def batch_elements_kwargs(self):
 
 
 class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
-  def run_inference(
-      self, batch, model, inference_args=None, drop_example=False):
+  def run_inference(self, batch, unused_model, inference_args=None):
     raise ValueError(
         'run_inference should not be called because error should already be '
         'thrown from the validate_inference_args check.')
 
 
 class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
-  def run_inference(
-      self, batch, model, inference_args=None, drop_example=False):
+  def run_inference(self, batch, unused_model, inference_args=None):
     if not inference_args:
       raise ValueError('inference_args should exist')
     return batch
@@ -279,19 +251,6 @@ def test_run_inference_keyed_examples_with_unkeyed_model_handler(self):
         | 'RunKeyed' >> base.RunInference(model_handler))
     pipeline.run()
 
-  def test_drop_example_prediction_result(self):
-    def assert_drop_example(prediction_result):
-      assert prediction_result.example is None
-
-    pipeline = TestPipeline()
-    examples = [1, 3, 5]
-    model_handler = FakeModelHandlerReturnsPredictionResult()
-    _ = (
-        pipeline | 'keyed' >> beam.Create(examples)
-        | 'RunKeyed' >> base.RunInference(model_handler, drop_example=True)
-        | beam.Map(assert_drop_example))
-    pipeline.run()
-
 
 if __name__ == '__main__':
   unittest.main()
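
The deleted test exercised drop_example end to end; with the flag gone, a pipeline that does not want to carry inputs in its output can strip them after RunInference instead. A minimal sketch (not part of this commit), assuming PredictionResult remains the two-field NamedTuple defined in base.py:

from apache_beam.ml.inference.base import PredictionResult


def drop_example(result: PredictionResult) -> PredictionResult:
  # PredictionResult is a NamedTuple, so _replace returns a copy with the
  # example field set to None instead of mutating the original result.
  return result._replace(example=None)

# Hypothetical placement downstream of inference:
#   ... | base.RunInference(model_handler) | beam.Map(drop_example) | ...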

sdks/python/apache_beam/ml/inference/pytorch_inference.py

+24 −11
@@ -25,12 +25,12 @@
 from typing import Iterable
 from typing import Optional
 from typing import Sequence
+from typing import Union
 
 import torch
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
-from apache_beam.ml.inference.base import _convert_to_result
 from apache_beam.utils.annotations import experimental
 
 __all__ = [
@@ -83,6 +83,23 @@ def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
   return examples
 
 
+def _convert_to_result(
+    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
+) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions_per_tensor = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+    return [
+        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
+    ]
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
 class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
                                              PredictionResult,
                                              torch.nn.Module]):
@@ -135,8 +152,8 @@ def run_inference(
       self,
       batch: Sequence[torch.Tensor],
       model: torch.nn.Module,
-      inference_args: Optional[Dict[str, Any]] = None,
-      drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
     """
     Runs inferences on a batch of Tensors and returns an Iterable of
     Tensor Predictions.
@@ -153,8 +170,7 @@ def run_inference(
       inference_args: Non-batchable arguments required as inputs to the model's
         forward() function. Unlike Tensors in `batch`, these parameters will
        not be dynamically batched
-      drop_example: Boolean flag indicating whether to
-        drop the example from PredictionResult
+
     Returns:
       An Iterable of type PredictionResult.
     """
@@ -166,7 +182,7 @@ def run_inference(
     batched_tensors = torch.stack(batch)
     batched_tensors = _convert_to_device(batched_tensors, self._device)
     predictions = model(batched_tensors, **inference_args)
-    return _convert_to_result(batch, predictions, drop_example=drop_example)
+    return _convert_to_result(batch, predictions)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
@@ -243,8 +259,7 @@ def run_inference(
       self,
       batch: Sequence[Dict[str, torch.Tensor]],
       model: torch.nn.Module,
-      inference_args: Optional[Dict[str, Any]] = None,
-      drop_example: Optional[bool] = False,
+      inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
     """
     Runs inferences on a batch of Keyed Tensors and returns an Iterable of
@@ -262,8 +277,6 @@ def run_inference(
       inference_args: Non-batchable arguments required as inputs to the model's
         forward() function. Unlike Tensors in `batch`, these parameters will
        not be dynamically batched
-      drop_example: Boolean flag indicating whether to
-        drop the example from PredictionResult
 
     Returns:
       An Iterable of type PredictionResult.
@@ -287,7 +300,7 @@ def run_inference(
       key_to_batched_tensors[key] = batched_tensors
     predictions = model(**key_to_batched_tensors, **inference_args)
 
-    return _convert_to_result(batch, predictions, drop_example=drop_example)
+    return _convert_to_result(batch, predictions)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
