Skip to content

Get rid of unnecessary cogbk when running offline detector. #34656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions sdks/python/apache_beam/ml/anomaly/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
Expand Down Expand Up @@ -427,8 +426,8 @@ class RunOfflineDetector(beam.PTransform[beam.PCollection[KeyedInputT],
def __init__(self, offline_detector: OfflineDetector):
self._offline_detector = offline_detector

def unnest_and_convert(
self, nested: Tuple[Tuple[Any, Any], dict[str, List]]) -> KeyedOutputT:
def restore_and_convert(
self, elem: Tuple[Tuple[Any, Any, beam.Row], float]) -> KeyedOutputT:
"""Unnests and converts the model output to AnomalyResult.

Args:
Expand All @@ -438,15 +437,14 @@ def unnest_and_convert(
Returns:
A tuple containing the original key and AnomalyResult.
"""
key, value_dict = nested
score = value_dict['output'][0]
(orig_key, temp_key, row), score = elem
result = AnomalyResult(
example=value_dict['input'][0],
example=row,
predictions=[
AnomalyPrediction(
model_id=self._offline_detector._model_id, score=score)
])
return key[0], (key[1], result)
return orig_key, (temp_key, result)

def expand(
self,
Expand All @@ -458,23 +456,18 @@ def expand(
self._offline_detector._keyed_model_handler,
**self._offline_detector._run_inference_args)

# ((orig_key, temp_key), beam.Row)
# ((orig_key, temp_key, beam.Row), beam.Row)
rekeyed_model_input = input | "Rekey" >> beam.Map(
lambda x: ((x[0], x[1][0]), x[1][1]))
lambda x: ((x[0], x[1][0], x[1][1]), x[1][1]))

# ((orig_key, temp_key), float)
# ((orig_key, temp_key, beam.Row), float)
rekeyed_model_output = (
rekeyed_model_input
| f"Call RunInference ({model_uuid})" >> run_inference)

# ((orig_key, temp_key), {'input':[row], 'output:[float]})
rekeyed_cogbk = {
'input': rekeyed_model_input, 'output': rekeyed_model_output
} | beam.CoGroupByKey()

ret = (
rekeyed_cogbk |
"Unnest and convert model output" >> beam.Map(self.unnest_and_convert))
rekeyed_model_output | "Restore keys and convert model output" >>
beam.Map(self.restore_and_convert))

if self._offline_detector._threshold_criterion:
ret = (
Expand Down
Loading