
Commit 81642eb

Introduce RemoteModelHandler abstract base class (apache#34379)
* Stash first functioning version
* unit tests
* Unit tests, documentation
* linting and formatting
* adjust line-too-long
* move external doc to shortlink
* create_client + __init_subclass__
* fix exception message to match symbol change
* Fix linting
* make RemoteModelHandler an explicit abstract base class
1 parent fdb285d commit 81642eb

3 files changed, +326 −1 lines

sdks/python/apache_beam/ml/inference/base.py

+140 lines
@@ -27,13 +27,16 @@
 collection, sharing model between threads, and batching elements.
 """

+import functools
 import logging
 import os
 import pickle
 import sys
 import threading
 import time
 import uuid
+from abc import ABC
+from abc import abstractmethod
 from collections import OrderedDict
 from collections import defaultdict
 from copy import deepcopy
@@ -56,7 +59,10 @@
 from typing import Union

 import apache_beam as beam
+from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
+from apache_beam.metrics.metric import Metrics
 from apache_beam.utils import multi_process_shared
+from apache_beam.utils import retry
 from apache_beam.utils import shared

 try:
@@ -67,6 +73,7 @@

 _NANOSECOND_TO_MILLISECOND = 1_000_000
 _NANOSECOND_TO_MICROSECOND = 1_000
+_MILLISECOND_TO_SECOND = 1_000

 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
@@ -339,6 +346,139 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()


+class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
+  """Has the ability to call a model at a remote endpoint."""
+  def __init__(
+      self,
+      namespace: str = '',
+      num_retries: int = 5,
+      throttle_delay_secs: int = 5,
+      retry_filter: Callable[[Exception], bool] = lambda x: True,
+      *,
+      window_ms: int = 1 * _MILLISECOND_TO_SECOND,
+      bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
+      overload_ratio: float = 2):
+    """Initializes metrics tracking + an AdaptiveThrottler class for enabling
+    client-side throttling for remote calls to an inference service.
+    See https://s.apache.org/beam-client-side-throttling for more details
+    on the configuration of the throttling and retry
+    mechanics.
+
+    Args:
+      namespace: the metrics and logging namespace
+      num_retries: the maximum number of times to retry a request on retriable
+        errors before failing
+      throttle_delay_secs: the amount of time to throttle when the client-side
+        elects to throttle
+      retry_filter: a function accepting an exception as an argument and
+        returning a boolean. On a true return, the run_inference call will
+        be retried. Defaults to always retrying.
+      window_ms: length of history to consider, in ms, to set throttling.
+      bucket_ms: granularity of time buckets that we store data in, in ms.
+      overload_ratio: the target ratio between requests sent and successful
+        requests. This is "K" in the formula in
+        https://landing.google.com/sre/book/chapters/handling-overload.html.
+    """
+    # Configure AdaptiveThrottler and throttling metrics for client-side
+    # throttling behavior.
+    self.throttled_secs = Metrics.counter(
+        namespace, "cumulativeThrottlingSeconds")
+    self.throttler = AdaptiveThrottler(
+        window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)
+    self.logger = logging.getLogger(namespace)
+
+    self.num_retries = num_retries
+    self.throttle_delay_secs = throttle_delay_secs
+    self.retry_filter = retry_filter
+
+  def __init_subclass__(cls):
+    if cls.load_model is not RemoteModelHandler.load_model:
+      raise Exception(
+          "Cannot override RemoteModelHandler.load_model, ",
+          "implement create_client instead.")
+    if cls.run_inference is not RemoteModelHandler.run_inference:
+      raise Exception(
+          "Cannot override RemoteModelHandler.run_inference, ",
+          "implement request instead.")
+
+  @abstractmethod
+  def create_client(self) -> ModelT:
+    """Creates the client that is used to make the remote inference request
+    in request(). All relevant arguments should be passed to __init__().
+    """
+    raise NotImplementedError(type(self))
+
+  def load_model(self) -> ModelT:
+    return self.create_client()
+
+  def retry_on_exception(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+      return retry.with_exponential_backoff(
+          num_retries=self.num_retries,
+          retry_filter=self.retry_filter)(func)(self, *args, **kwargs)
+
+    return wrapper
+
+  @retry_on_exception
+  def run_inference(
+      self,
+      batch: Sequence[ExampleT],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
+    """Runs inferences on a batch of examples. Calls a remote model for
+    predictions and will retry if a retryable exception is raised.
+
+    Args:
+      batch: A sequence of examples or features.
+      model: The model used to make inferences.
+      inference_args: Extra arguments for models whose inference call requires
+        extra parameters.
+
+    Returns:
+      An Iterable of Predictions.
+    """
+    while self.throttler.throttle_request(time.time() * _MILLISECOND_TO_SECOND):
+      self.logger.info(
+          "Delaying request for %d seconds due to previous failures",
+          self.throttle_delay_secs)
+      time.sleep(self.throttle_delay_secs)
+      self.throttled_secs.inc(self.throttle_delay_secs)
+
+    try:
+      req_time = time.time()
+      predictions = self.request(batch, model, inference_args)
+      self.throttler.successful_request(req_time * _MILLISECOND_TO_SECOND)
+      return predictions
+    except Exception as e:
+      self.logger.error("exception raised as part of request, got %s", e)
+      raise
+
+  @abstractmethod
+  def request(
+      self,
+      batch: Sequence[ExampleT],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
+    """Makes a request to a remote inference service and returns the response.
+    Should raise an exception of some kind if there is an error to enable the
+    retry and client-side throttling logic to work. Returns an iterable of the
+    desired prediction type. This method should return the values directly, as
+    handling return values as a generator can prevent the retry logic from
+    functioning correctly.
+
+    Args:
+      batch: A sequence of examples or features.
+      model: The model used to make inferences.
+      inference_args: Extra arguments for models whose inference call requires
+        extra parameters.
+
+    Returns:
+      An Iterable of Predictions.
+    """
+    raise NotImplementedError(type(self))
+
+
 class _ModelManager:
   """
   A class for efficiently managing copies of multiple models. Will load a
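To make the extension contract concrete: a subclass implements create_client() and request(), and inherits load_model() plus the throttled, retrying run_inference(). The following is a minimal sketch only; EchoClient, its predict() method, and the endpoint argument are invented placeholders, not part of this commit or of Beam.

from typing import Any, Dict, Iterable, Optional, Sequence

from apache_beam.ml.inference import base


class EchoClient:
  """Hypothetical stand-in for a real remote inference client."""
  def __init__(self, endpoint: str):
    self.endpoint = endpoint

  def predict(self, example: int) -> int:
    # A real client would make an HTTP/RPC call to self.endpoint here.
    return example + 1


class EchoModelHandler(base.RemoteModelHandler[int, int, EchoClient]):
  def __init__(self, endpoint: str, **kwargs):
    self._endpoint = endpoint
    super().__init__(namespace='EchoModelHandler', **kwargs)

  def create_client(self) -> EchoClient:
    # Called by the inherited load_model(); construct the client here.
    return EchoClient(self._endpoint)

  def request(
      self,
      batch: Sequence[int],
      model: EchoClient,
      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[int]:
    # Return a materialized list rather than a generator so failures are
    # raised inside run_inference() and its retry/throttling logic applies.
    return [model.predict(example) for example in batch]

Because load_model() and run_inference() cannot be overridden (enforced by __init_subclass__), every subclass gets the same client-side throttling and exponential-backoff retry behavior for free.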

sdks/python/apache_beam/ml/inference/base_test.py

+183 lines
@@ -1870,5 +1870,188 @@ def test_model_status_provides_valid_garbage_collection(self):
     self.assertEqual(0, len(tags))


+def _always_retry(e: Exception) -> bool:
+  return True
+
+
+class FakeRemoteModelHandler(base.RemoteModelHandler[int, int, FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    self._multi_process_shared = multi_process_shared
+    super().__init__(
+        namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    responses = []
+    for example in batch:
+      responses.append(model.predict(example))
+    return responses
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class FakeAlwaysFailsRemoteModelHandler(base.RemoteModelHandler[int,
+                                                                int,
+                                                                FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    super().__init__(
+        namespace='FakeRemoteModelHandler',
+        retry_filter=retry_filter,
+        num_retries=2,
+        throttle_delay_secs=1)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    raise Exception
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class FakeFailsOnceRemoteModelHandler(base.RemoteModelHandler[int,
+                                                              int,
+                                                              FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    self._should_fail = True
+    super().__init__(
+        namespace='FakeRemoteModelHandler',
+        retry_filter=retry_filter,
+        num_retries=2,
+        throttle_delay_secs=1)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    if self._should_fail:
+      self._should_fail = False
+      raise Exception
+    else:
+      self._should_fail = True
+      responses = []
+      for example in batch:
+        responses.append(model.predict(example))
+      return responses
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class RunInferenceRemoteTest(unittest.TestCase):
+  def test_normal_model_execution(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeRemoteModelHandler())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_repeated_requests_fail(self):
+    test_pipeline = TestPipeline()
+    with self.assertRaises(Exception):
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(FakeAlwaysFailsRemoteModelHandler()))
+      test_pipeline.run()
+
+  def test_works_on_retry(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeFailsOnceRemoteModelHandler())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_exception_on_load_model_override(self):
+    with self.assertRaises(Exception):
+
+      class _(base.RemoteModelHandler[int, int, FakeModel]):
+        def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
+          self._fake_clock = clock
+          self._min_batch_size = 1
+          self._max_batch_size = 1
+          self._env_vars = kwargs.get('env_vars', {})
+          super().__init__(
+              namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+        def load_model(self):
+          return FakeModel()
+
+        def request(self, batch, model, inference_args=None) -> Iterable[int]:
+          responses = []
+          for example in batch:
+            responses.append(model.predict(example))
+          return responses
+
+  def test_exception_on_run_inference_override(self):
+    with self.assertRaises(Exception):
+
+      class _(base.RemoteModelHandler[int, int, FakeModel]):
+        def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
+          self._fake_clock = clock
+          self._min_batch_size = 1
+          self._max_batch_size = 1
+          self._env_vars = kwargs.get('env_vars', {})
+          super().__init__(
+              namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+        def create_client(self):
+          return FakeModel()
+
+        def run_inference(self,
+                          batch,
+                          model,
+                          inference_args=None) -> Iterable[int]:
+          responses = []
+          for example in batch:
+            responses.append(model.predict(example))
+          return responses
+
+
 if __name__ == '__main__':
   unittest.main()
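Note that in the last two tests the assertRaises context wraps a class definition, not an instantiation: __init_subclass__ runs as soon as the class statement executes. A minimal sketch of the same behavior outside the test harness (BadHandler is an invented name):

from apache_beam.ml.inference import base

try:
  class BadHandler(base.RemoteModelHandler[int, int, object]):
    def load_model(self):  # forbidden override
      return object()
except Exception as e:
  # Raised at class-definition time, before any instance exists.
  print(e)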

sdks/python/apache_beam/utils/retry.py

+3 −1 lines
@@ -274,7 +274,9 @@ def with_exponential_backoff(
   The decorator is intended to be used on callables that make HTTP or RPC
   requests that can temporarily timeout or have transient errors. For instance
   the make_http_request() call below will be retried 16 times with exponential
-  backoff and fuzzing of the delay interval (default settings).
+  backoff and fuzzing of the delay interval (default settings). The callable
+  should return values directly instead of yielding them, as generators are not
+  evaluated within the try-catch block and will not be retried on exception.

     from apache_beam.utils import retry
     # ...
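A brief illustration of the pitfall this note warns about; make_http_request is the same placeholder used by the docstring's example, not a real function in this module:

from apache_beam.utils import retry

@retry.with_exponential_backoff(num_retries=3, retry_filter=lambda exn: True)
def fetch_all(urls):
  # Any failure is raised inside the decorated call, so it is retried.
  return [make_http_request(url) for url in urls]

@retry.with_exponential_backoff(num_retries=3, retry_filter=lambda exn: True)
def fetch_lazily(urls):
  # The generator object is created and returned without running the body;
  # exceptions surface later, during iteration, outside the decorator's
  # try/except, and are never retried.
  for url in urls:
    yield make_http_request(url)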
