@@ -82,24 +82,6 @@ def _to_microseconds(time_ns: int) -> int:
   return int(time_ns / _NANOSECOND_TO_MICROSECOND)
 
 
-def _convert_to_result(
-    batch: Iterable,
-    predictions: Union[Iterable, Dict[Any, Iterable]],
-    drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
-  if isinstance(predictions, dict):
-    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
-    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
-    # length batch_size, to a list of dictionaries:
-    # [{key_type1: value_type1, key_type2: value_type2}]
-    predictions = [
-        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
-    ]
-  if drop_example:
-    return [PredictionResult(None, y) for x, y in zip(batch, predictions)]
-
-  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
-
-
 class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
   """Has the ability to load and apply an ML model."""
   def load_model(self) -> ModelT:
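For reference, the dict-of-iterables to list-of-dicts transposition that the removed _convert_to_result helper performed can be sketched in isolation. The values below are hypothetical, not part of this diff:

    # Each dict value is an iterable of length batch_size.
    predictions = {'label': [0, 1], 'score': [0.9, 0.4]}
    # Transpose into one dict per example, as the removed helper did.
    per_example = [
        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
    ]
    # per_example == [{'label': 0, 'score': 0.9}, {'label': 1, 'score': 0.4}]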
@@ -110,17 +92,15 @@ def run_inference(
       self,
       batch: Sequence[ExampleT],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None,
-      drop_example: Optional[bool] = False) -> Iterable[PredictionT]:
+      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples.
 
     Args:
       batch: A sequence of examples or features.
       model: The model used to make inferences.
       inference_args: Extra arguments for models whose inference call requires
         extra parameters.
-      drop_example: Boolean flag indicating whether to
-        drop the example from PredictionResult
+
     Returns:
       An Iterable of Predictions.
     """
@@ -190,8 +170,7 @@ def run_inference(
       self,
       batch: Sequence[Tuple[KeyT, ExampleT]],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None,
-      drop_example: Optional[bool] = False
+      inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[Tuple[KeyT, PredictionT]]:
     keys, unkeyed_batch = zip(*batch)
     return zip(
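The key-splitting pattern in this hunk, shown on plain data (hypothetical values, with a toy prediction in place of a model call):

    batch = [('a', 1), ('b', 2)]
    keys, unkeyed_batch = zip(*batch)  # ('a', 'b') and (1, 2)
    # Re-attach each key to the prediction made for its example.
    keyed = list(zip(keys, (x * 10 for x in unkeyed_batch)))
    # keyed == [('a', 10), ('b', 20)]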
@@ -246,8 +225,7 @@ def run_inference(
       self,
       batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None,
-      drop_example: Optional[bool] = False
+      inference_args: Optional[Dict[str, Any]] = None
   ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
     # Really the input should be
     # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
@@ -295,9 +273,7 @@ def __init__(
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock=time,
       inference_args: Optional[Dict[str, Any]] = None,
-      metrics_namespace: Optional[str] = None,
-      drop_example: Optional[bool] = False,
-  ):
+      metrics_namespace: Optional[str] = None):
     """A transform that takes a PCollection of examples (or features) to be used
     on an ML model. It will then output inferences (or predictions) for those
     examples in a PCollection of PredictionResults, containing the input
@@ -315,13 +291,11 @@ def __init__(
       inference_args: Extra arguments for models whose inference call requires
         extra parameters.
       metrics_namespace: Namespace of the transform to collect metrics.
-      drop_example: Boolean flag indicating whether to
-        drop the example from PredictionResult """
+    """
     self._model_handler = model_handler
     self._inference_args = inference_args
     self._clock = clock
     self._metrics_namespace = metrics_namespace
-    self._drop_example = drop_example
 
   # TODO(BEAM-14046): Add and link to help documentation.
   @classmethod
@@ -353,10 +327,7 @@ def expand(
         | 'BeamML_RunInference' >> (
             beam.ParDo(
                 _RunInferenceDoFn(
-                    model_handler=self._model_handler,
-                    clock=self._clock,
-                    metrics_namespace=self._metrics_namespace,
-                    drop_example=self._drop_example),
+                    self._model_handler, self._clock, self._metrics_namespace),
                 self._inference_args).with_resource_hints(**resource_hints)))
 
 
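Hypothetical pipeline usage after this change; note that no drop_example argument is passed to RunInference, and MyModelHandler is the sketch from earlier, not a real handler:

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([1, 2, 3])  # hypothetical input examples
          | RunInference(MyModelHandler()))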
@@ -414,22 +385,19 @@ def __init__(
       self,
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock,
-      metrics_namespace: Optional[str],
-      drop_example: Optional[bool] = False):
+      metrics_namespace):
     """A DoFn implementation generic to frameworks.
 
     Args:
       model_handler: An implementation of ModelHandler.
       clock: A clock implementing time_ns. *Used for unit testing.*
       metrics_namespace: Namespace of the transform to collect metrics.
-      drop_example: Boolean flag indicating whether to
-        drop the example from PredictionResult """
+    """
     self._model_handler = model_handler
     self._shared_model_handle = shared.Shared()
     self._clock = clock
     self._model = None
     self._metrics_namespace = metrics_namespace
-    self._drop_example = drop_example
 
   def _load_model(self):
     def load():
@@ -459,7 +427,7 @@ def setup(self):
   def process(self, batch, inference_args):
     start_time = _to_microseconds(self._clock.time_ns())
     result_generator = self._model_handler.run_inference(
-        batch, self._model, inference_args)
+        batch, self._model, inference_args)
     predictions = list(result_generator)
 
     end_time = _to_microseconds(self._clock.time_ns())
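The timing bracket around run_inference reduces to this pattern; a sketch using the module's default clock, time, and assuming _NANOSECOND_TO_MICROSECOND is the ns-to-us factor of 1000:

    import time

    start_time = int(time.time_ns() / 1000)  # microseconds
    # ... run_inference(batch, model, inference_args) ...
    end_time = int(time.time_ns() / 1000)
    inference_latency_us = end_time - start_time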