Skip to content

Add adapters to address output conversion for OfflineDetector. #34662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions sdks/python/apache_beam/ml/anomaly/detectors/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,20 @@
from typing import Any
from typing import Dict
from typing import Optional
from typing import SupportsFloat
from typing import SupportsInt
from typing import Tuple
from typing import TypeVar

import apache_beam as beam
from apache_beam.ml.anomaly.base import AnomalyDetector
from apache_beam.ml.anomaly.base import AnomalyPrediction
from apache_beam.ml.anomaly.specifiable import specifiable
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import PredictionT

KeyT = TypeVar('KeyT')


@specifiable
Expand All @@ -31,14 +40,66 @@ class OfflineDetector(AnomalyDetector):

Args:
keyed_model_handler: The model handler to use for inference.
Requires a `KeyModelHandler[Any, Row, float, Any]` instance.
Requires a `KeyedModelHandler[Any, Row, PredictionT, Any]` instance.
run_inference_args: Optional arguments to pass to RunInference
**kwargs: Additional keyword arguments to pass to the base
AnomalyDetector class.
"""
@staticmethod
def score_prediction_adapter(
    keyed_prediction: Tuple[KeyT, PredictionResult]
) -> Tuple[KeyT, AnomalyPrediction]:
  """Converts a keyed `PredictionResult` into a keyed `AnomalyPrediction`.

  The `inference` attribute of the incoming `PredictionResult` is treated as
  an anomaly score, so it must be convertible to float. The key is passed
  through untouched.

  Args:
    keyed_prediction: Tuple of `(key, PredictionResult)`. `PredictionResult`
      must have an `inference` attribute supporting float conversion.

  Returns:
    Tuple of `(key, AnomalyPrediction)` with the extracted score.

  Raises:
    AssertionError: If `PredictionResult.inference` doesn't support float().
  """
  key, result = keyed_prediction
  raw_score = result.inference
  # Narrow the inference payload to something float() accepts before wrapping.
  assert isinstance(raw_score, SupportsFloat)
  return key, AnomalyPrediction(score=float(raw_score))

@staticmethod
def label_prediction_adapter(
    keyed_prediction: Tuple[KeyT, PredictionResult]
) -> Tuple[KeyT, AnomalyPrediction]:
  """Converts a keyed `PredictionResult` into a keyed `AnomalyPrediction`.

  The `inference` attribute of the incoming `PredictionResult` is treated as
  an anomaly label, so it must be convertible to int. The key is passed
  through untouched.

  Args:
    keyed_prediction: Tuple of `(key, PredictionResult)`. `PredictionResult`
      must have an `inference` attribute supporting int conversion.

  Returns:
    Tuple of `(key, AnomalyPrediction)` with the extracted label.

  Raises:
    AssertionError: If `PredictionResult.inference` doesn't support int().
  """
  key, result = keyed_prediction
  raw_label = result.inference
  # Narrow the inference payload to something int() accepts before wrapping.
  assert isinstance(raw_label, SupportsInt)
  return key, AnomalyPrediction(label=int(raw_label))

def __init__(
self,
keyed_model_handler: KeyedModelHandler[Any, beam.Row, float, Any],
keyed_model_handler: KeyedModelHandler[Any, beam.Row, PredictionT, Any],
run_inference_args: Optional[Dict[str, Any]] = None,
**kwargs):
super().__init__(**kwargs)
Expand Down
30 changes: 19 additions & 11 deletions sdks/python/apache_beam/ml/anomaly/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,23 +426,31 @@ class RunOfflineDetector(beam.PTransform[beam.PCollection[KeyedInputT],
def __init__(self, offline_detector: OfflineDetector):
self._offline_detector = offline_detector

def _restore_and_convert(
    self, elem: Tuple[Tuple[Any, Any, beam.Row], Any]) -> KeyedOutputT:
  """Converts the model output to AnomalyResult.

  Args:
    elem: A tuple containing the combined key (original key, temp key, row)
      and the output from RunInference.

  Returns:
    A tuple containing the keyed AnomalyResult.

  Raises:
    AssertionError: If the model handler output is not an
      `AnomalyPrediction`.
  """
  (orig_key, temp_key, row), prediction = elem
  # Fail fast with actionable guidance when the model handler was not
  # configured to emit AnomalyPrediction.
  # Fix: a space was missing after "type.", which produced the garbled
  # message "...output type.Expected: ...".
  assert isinstance(prediction, AnomalyPrediction), (
      "Wrong model handler output type. "
      f"Expected: 'AnomalyPrediction', but got '{type(prediction).__name__}'. "
      "Consider adding a post-processing function via `with_postprocess_fn` "
      f"to convert from '{type(prediction).__name__}' to 'AnomalyPrediction', "
      "or use `score_prediction_adapter` or `label_prediction_adapter` to "
      "perform the conversion.")

  # Stamp the detector's model id onto the prediction without mutating the
  # original dataclass instance.
  result = AnomalyResult(
      example=row,
      predictions=[
          dataclasses.replace(
              prediction, model_id=self._offline_detector._model_id)
      ])
  return orig_key, (temp_key, result)

Expand All @@ -460,14 +468,14 @@ def expand(
rekeyed_model_input = input | "Rekey" >> beam.Map(
lambda x: ((x[0], x[1][0], x[1][1]), x[1][1]))

# ((orig_key, temp_key, beam.Row), float)
# ((orig_key, temp_key, beam.Row), AnomalyPrediction)
rekeyed_model_output = (
rekeyed_model_input
| f"Call RunInference ({model_uuid})" >> run_inference)

ret = (
rekeyed_model_output | "Restore keys and convert model output" >>
beam.Map(self.restore_and_convert))
beam.Map(self._restore_and_convert))

if self._offline_detector._threshold_criterion:
ret = (
Expand Down
25 changes: 4 additions & 21 deletions sdks/python/apache_beam/ml/anomaly/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,11 @@
import tempfile
import unittest
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import SupportsFloat
from typing import Tuple

import mock
import numpy
from sklearn.base import BaseEstimator

import apache_beam as beam
from apache_beam.ml.anomaly.aggregations import AnyVote
Expand All @@ -51,7 +46,6 @@
from apache_beam.ml.anomaly.transforms import _StatefulThresholdDoFn
from apache_beam.ml.anomaly.transforms import _StatelessThresholdDoFn
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.base import _PostProcessingModelHandler
from apache_beam.ml.inference.base import _PreProcessingModelHandler
Expand Down Expand Up @@ -287,23 +281,11 @@ def predict(self, input_vector: numpy.ndarray):
return [input_vector[0][0] * 10 - input_vector[0][1]]


def alternate_numpy_inference_fn(
model: BaseEstimator,
batch: Sequence[numpy.ndarray],
inference_args: Optional[Dict[str, Any]] = None) -> Any:
return [0]


def _to_keyed_numpy_array(t: Tuple[Any, beam.Row]):
  """Converts the Row in a keyed `(key, beam.Row)` pair to a NumPy array."""
  key, row = t
  return key, numpy.array(list(row))


def _from_keyed_numpy_array(t: Tuple[Any, PredictionResult]):
assert isinstance(t[1].inference, SupportsFloat)
return t[0], float(t[1].inference)


class TestOfflineDetector(unittest.TestCase):
def setUp(self):
global SklearnModelHandlerNumpy, KeyedModelHandler
Expand All @@ -330,7 +312,8 @@ def test_default_inference_fn(self):

keyed_model_handler = KeyedModelHandler(
SklearnModelHandlerNumpy(model_uri=temp_file_name)).with_preprocess_fn(
_to_keyed_numpy_array).with_postprocess_fn(_from_keyed_numpy_array)
_to_keyed_numpy_array).with_postprocess_fn(
OfflineDetector.score_prediction_adapter)

detector = OfflineDetector(keyed_model_handler=keyed_model_handler)
detector_spec = detector.to_spec()
Expand All @@ -354,7 +337,7 @@ def test_default_inference_fn(self):
type='_to_keyed_numpy_array', config=None)
}),
'postprocess_fn': Spec(
type='_from_keyed_numpy_array', config=None)
type='score_prediction_adapter', config=None)
})
})
self.assertEqual(detector_spec, expected_spec)
Expand All @@ -363,7 +346,7 @@ def test_default_inference_fn(self):
self.assertEqual(_spec_type_to_subspace('_PreProcessingModelHandler'), '*')
self.assertEqual(_spec_type_to_subspace('_PostProcessingModelHandler'), '*')
self.assertEqual(_spec_type_to_subspace('_to_keyed_numpy_array'), '*')
self.assertEqual(_spec_type_to_subspace('_from_keyed_numpy_array'), '*')
self.assertEqual(_spec_type_to_subspace('score_prediction_adapter'), '*')

# Make sure the spec from the detector can be used to reconstruct the same
# detector
Expand Down
Loading