17
17
18
18
from typing import Any
19
19
from typing import Dict
20
+ from typing import SupportsFloat
21
+ from typing import SupportsInt
22
+ from typing import Tuple
23
+ from typing import TypeVar
20
24
from typing import Optional
21
25
22
26
import apache_beam as beam
23
27
from apache_beam .ml .anomaly .base import AnomalyDetector
28
+ from apache_beam .ml .anomaly .base import AnomalyPrediction
24
29
from apache_beam .ml .anomaly .specifiable import specifiable
25
30
from apache_beam .ml .inference .base import KeyedModelHandler
31
+ from apache_beam .ml .inference .base import PredictionResult
32
+ from apache_beam .ml .inference .base import PredictionT
33
+
34
+ KeyT = TypeVar ('KeyT' )
26
35
27
36
28
37
@specifiable
@@ -31,14 +40,66 @@ class OfflineDetector(AnomalyDetector):
31
40
32
41
Args:
33
42
keyed_model_handler: The model handler to use for inference.
34
- Requires a `KeyModelHandler[Any, Row, float , Any]` instance.
43
+ Requires a `KeyModelHandler[Any, Row, PredictionT , Any]` instance.
35
44
run_inference_args: Optional arguments to pass to RunInference
36
45
**kwargs: Additional keyword arguments to pass to the base
37
46
AnomalyDetector class.
38
47
"""
48
+ @staticmethod
49
+ def score_prediction_adapter (
50
+ keyed_prediction : Tuple [KeyT , PredictionResult ]
51
+ ) -> Tuple [KeyT , AnomalyPrediction ]:
52
+ """Extracts a float score from `PredictionResult.inference` and wraps it.
53
+
54
+ Takes a keyed `PredictionResult` from common ModelHandler output, assumes
55
+ its `inference` attribute is a float-convertible score, and returns the key
56
+ paired with an `AnomalyPrediction` containing that float score.
57
+
58
+ Args:
59
+ keyed_prediction: Tuple of `(key, PredictionResult)`. `PredictionResult`
60
+ must have an `inference` attribute supporting float conversion.
61
+
62
+ Returns:
63
+ Tuple of `(key, AnomalyPrediction)` with the extracted score.
64
+
65
+ Raises:
66
+ AssertionError: If `PredictionResult.inference` doesn't support float().
67
+ """
68
+
69
+ key , prediction = keyed_prediction
70
+ score = prediction .inference
71
+ assert isinstance (score , SupportsFloat )
72
+ return key , AnomalyPrediction (score = float (score ))
73
+
74
+ @staticmethod
75
+ def label_prediction_adapter (
76
+ keyed_prediction : Tuple [KeyT , PredictionResult ]
77
+ ) -> Tuple [KeyT , AnomalyPrediction ]:
78
+ """Extracts an integer label from `PredictionResult.inference` and wraps it.
79
+
80
+ Takes a keyed `PredictionResult`, assumes its `inference` attribute is an
81
+ integer-convertible label, and returns the key paired with an
82
+ `AnomalyPrediction` containing that integer label.
83
+
84
+ Args:
85
+ keyed_prediction: Tuple of `(key, PredictionResult)`. `PredictionResult`
86
+ must have an `inference` attribute supporting int conversion.
87
+
88
+ Returns:
89
+ Tuple of `(key, AnomalyPrediction)` with the extracted label.
90
+
91
+ Raises:
92
+ AssertionError: If `PredictionResult.inference` doesn't support int().
93
+ """
94
+
95
+ key , prediction = keyed_prediction
96
+ label = prediction .inference
97
+ assert isinstance (label , SupportsInt )
98
+ return key , AnomalyPrediction (label = int (label ))
99
+
39
100
def __init__ (
40
101
self ,
41
- keyed_model_handler : KeyedModelHandler [Any , beam .Row , float , Any ],
102
+ keyed_model_handler : KeyedModelHandler [Any , beam .Row , PredictionT , Any ],
42
103
run_inference_args : Optional [Dict [str , Any ]] = None ,
43
104
** kwargs ):
44
105
super ().__init__ (** kwargs )
0 commit comments