Skip to content

Commit 979fa37

Browse files
committed
Add adapters to address output conversion for OfflineDetector.
1 parent 38192de commit 979fa37

File tree

3 files changed

+83
-28
lines changed

3 files changed

+83
-28
lines changed

sdks/python/apache_beam/ml/anomaly/detectors/offline.py

+63-2
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,21 @@
1717

1818
from typing import Any
1919
from typing import Dict
20+
from typing import SupportsFloat
21+
from typing import SupportsInt
22+
from typing import Tuple
23+
from typing import TypeVar
2024
from typing import Optional
2125

2226
import apache_beam as beam
2327
from apache_beam.ml.anomaly.base import AnomalyDetector
28+
from apache_beam.ml.anomaly.base import AnomalyPrediction
2429
from apache_beam.ml.anomaly.specifiable import specifiable
2530
from apache_beam.ml.inference.base import KeyedModelHandler
31+
from apache_beam.ml.inference.base import PredictionResult
32+
from apache_beam.ml.inference.base import PredictionT
33+
34+
KeyT = TypeVar('KeyT')
2635

2736

2837
@specifiable
@@ -31,14 +40,66 @@ class OfflineDetector(AnomalyDetector):
3140
3241
Args:
3342
keyed_model_handler: The model handler to use for inference.
34-
Requires a `KeyModelHandler[Any, Row, float, Any]` instance.
43+
Requires a `KeyedModelHandler[Any, Row, PredictionT, Any]` instance.
3544
run_inference_args: Optional arguments to pass to RunInference
3645
**kwargs: Additional keyword arguments to pass to the base
3746
AnomalyDetector class.
3847
"""
48+
@staticmethod
def score_prediction_adapter(
    keyed_prediction: Tuple[KeyT, PredictionResult]
) -> Tuple[KeyT, AnomalyPrediction]:
  """Extracts a float score from `PredictionResult.inference` and wraps it.

  Takes a keyed `PredictionResult` from common ModelHandler output, checks
  that its `inference` attribute is a float-convertible score, and returns the
  key paired with an `AnomalyPrediction` containing that float score.

  Args:
    keyed_prediction: Tuple of `(key, PredictionResult)`. `PredictionResult`
      must have an `inference` attribute supporting float conversion.

  Returns:
    Tuple of `(key, AnomalyPrediction)` with the extracted score.

  Raises:
    AssertionError: If `PredictionResult.inference` doesn't support float().
  """
  key, prediction = keyed_prediction
  score = prediction.inference
  # Raise explicitly instead of using `assert`, so the documented check still
  # runs when Python is started with -O (which strips assert statements).
  if not isinstance(score, SupportsFloat):
    raise AssertionError(
        "Expected `PredictionResult.inference` to support float(), "
        f"but got '{type(score).__name__}'.")
  return key, AnomalyPrediction(score=float(score))
73+
74+
@staticmethod
def label_prediction_adapter(
    keyed_prediction: Tuple[KeyT, PredictionResult]
) -> Tuple[KeyT, AnomalyPrediction]:
  """Extracts an integer label from `PredictionResult.inference` and wraps it.

  Takes a keyed `PredictionResult`, checks that its `inference` attribute is
  an integer-convertible label, and returns the key paired with an
  `AnomalyPrediction` containing that integer label.

  Args:
    keyed_prediction: Tuple of `(key, PredictionResult)`. `PredictionResult`
      must have an `inference` attribute supporting int conversion.

  Returns:
    Tuple of `(key, AnomalyPrediction)` with the extracted label.

  Raises:
    AssertionError: If `PredictionResult.inference` doesn't support int().
  """
  key, prediction = keyed_prediction
  label = prediction.inference
  # Raise explicitly instead of using `assert`, so the documented check still
  # runs when Python is started with -O (which strips assert statements).
  if not isinstance(label, SupportsInt):
    raise AssertionError(
        "Expected `PredictionResult.inference` to support int(), "
        f"but got '{type(label).__name__}'.")
  return key, AnomalyPrediction(label=int(label))
99+
39100
def __init__(
40101
self,
41-
keyed_model_handler: KeyedModelHandler[Any, beam.Row, float, Any],
102+
keyed_model_handler: KeyedModelHandler[Any, beam.Row, PredictionT, Any],
42103
run_inference_args: Optional[Dict[str, Any]] = None,
43104
**kwargs):
44105
super().__init__(**kwargs)

sdks/python/apache_beam/ml/anomaly/transforms.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -426,23 +426,28 @@ class RunOfflineDetector(beam.PTransform[beam.PCollection[KeyedInputT],
426426
def __init__(self, offline_detector: OfflineDetector):
427427
self._offline_detector = offline_detector
428428

429-
def restore_and_convert(
430-
self, elem: Tuple[Tuple[Any, Any, beam.Row], float]) -> KeyedOutputT:
431-
"""Unnests and converts the model output to AnomalyResult.
429+
def _restore_and_convert(
    self, elem: Tuple[Tuple[Any, Any, beam.Row], Any]) -> KeyedOutputT:
  """Converts the model output to AnomalyResult.

  Args:
    elem: A tuple containing the combined key (original key, temp key, row)
      and the output from RunInference.

  Returns:
    A tuple `(original key, (temp key, AnomalyResult))`.

  Raises:
    AssertionError: If the RunInference output is not an
      `AnomalyPrediction`.
  """
  (orig_key, temp_key, row), prediction = elem
  # Raise explicitly instead of using `assert`, so the check still runs when
  # Python is started with -O (which strips assert statements).
  if not isinstance(prediction, AnomalyPrediction):
    raise AssertionError(
        "Wrong model handler output type. "
        "Expected: 'AnomalyPrediction', "
        f"but got '{type(prediction).__name__}'. "
        "Consider adding a post-processing function via "
        "`with_postprocess_fn`.")

  # Copy the prediction with this detector's model id filled in.
  tagged_prediction = dataclasses.replace(
      prediction, model_id=self._offline_detector._model_id)
  result = AnomalyResult(example=row, predictions=[tagged_prediction])
  return orig_key, (temp_key, result)
448453

@@ -460,14 +465,14 @@ def expand(
460465
rekeyed_model_input = input | "Rekey" >> beam.Map(
461466
lambda x: ((x[0], x[1][0], x[1][1]), x[1][1]))
462467

463-
# ((orig_key, temp_key, beam.Row), float)
468+
# ((orig_key, temp_key, beam.Row), AnomalyPrediction)
464469
rekeyed_model_output = (
465470
rekeyed_model_input
466471
| f"Call RunInference ({model_uuid})" >> run_inference)
467472

468473
ret = (
469474
rekeyed_model_output | "Restore keys and convert model output" >>
470-
beam.Map(self.restore_and_convert))
475+
beam.Map(self._restore_and_convert))
471476

472477
if self._offline_detector._threshold_criterion:
473478
ret = (

sdks/python/apache_beam/ml/anomaly/transforms_test.py

+4-15
Original file line numberDiff line numberDiff line change
@@ -287,23 +287,11 @@ def predict(self, input_vector: numpy.ndarray):
287287
return [input_vector[0][0] * 10 - input_vector[0][1]]
288288

289289

290-
def alternate_numpy_inference_fn(
291-
model: BaseEstimator,
292-
batch: Sequence[numpy.ndarray],
293-
inference_args: Optional[Dict[str, Any]] = None) -> Any:
294-
return [0]
295-
296-
297290
def _to_keyed_numpy_array(t: Tuple[Any, beam.Row]):
  """Turns a keyed Beam Row into a keyed NumPy array.

  The first element of `t` is passed through unchanged; the second element
  is iterated into a list and wrapped in a `numpy.array`.
  """
  key = t[0]
  values = list(t[1])
  return key, numpy.array(values)
300293

301294

302-
def _from_keyed_numpy_array(t: Tuple[Any, PredictionResult]):
303-
assert isinstance(t[1].inference, SupportsFloat)
304-
return t[0], float(t[1].inference)
305-
306-
307295
class TestOfflineDetector(unittest.TestCase):
308296
def setUp(self):
309297
global SklearnModelHandlerNumpy, KeyedModelHandler
@@ -330,7 +318,8 @@ def test_default_inference_fn(self):
330318

331319
keyed_model_handler = KeyedModelHandler(
332320
SklearnModelHandlerNumpy(model_uri=temp_file_name)).with_preprocess_fn(
333-
_to_keyed_numpy_array).with_postprocess_fn(_from_keyed_numpy_array)
321+
_to_keyed_numpy_array).with_postprocess_fn(
322+
OfflineDetector.score_prediction_adapter)
334323

335324
detector = OfflineDetector(keyed_model_handler=keyed_model_handler)
336325
detector_spec = detector.to_spec()
@@ -354,7 +343,7 @@ def test_default_inference_fn(self):
354343
type='_to_keyed_numpy_array', config=None)
355344
}),
356345
'postprocess_fn': Spec(
357-
type='_from_keyed_numpy_array', config=None)
346+
type='score_prediction_adapter', config=None)
358347
})
359348
})
360349
self.assertEqual(detector_spec, expected_spec)
@@ -363,7 +352,7 @@ def test_default_inference_fn(self):
363352
self.assertEqual(_spec_type_to_subspace('_PreProcessingModelHandler'), '*')
364353
self.assertEqual(_spec_type_to_subspace('_PostProcessingModelHandler'), '*')
365354
self.assertEqual(_spec_type_to_subspace('_to_keyed_numpy_array'), '*')
366-
self.assertEqual(_spec_type_to_subspace('_from_keyed_numpy_array'), '*')
355+
self.assertEqual(_spec_type_to_subspace('score_prediction_adapter'), '*')
367356

368357
# Make sure the spec from the detector can be used to reconstruct the same
369358
# detector

0 commit comments

Comments
 (0)