
Commit 0c4d81b

Cherry Pick RateLimiter SDK changes to Beam 2.71 release (#37306)

* Support for RateLimiter in Beam Remote Model Handler (#37218)
  * Support for EnvoyRateLimiter in Apache Beam
  * Support RateLimiter through RemoteModelHandler
  * Add custom RateLimitExceeded exception
  * Add dependency based on Python version
  * Revert setup changes to a separate PR
  * Fix formatting, lint, tests, and docs; resolve review comments
* Update RateLimiter execution function name (#37287)
* Catch breaking import error (#37295)
  * Catch ImportError
  * Fix import order
1 parent c77c61e commit 0c4d81b

6 files changed: +225 -21 lines changed

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A simple example demonstrating usage of the EnvoyRateLimiter with Vertex AI.
"""

import argparse
import logging

import apache_beam as beam
from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions


def run(argv=None):
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--project',
      dest='project',
      help='The Google Cloud project ID for Vertex AI.')
  parser.add_argument(
      '--location',
      dest='location',
      help='The Google Cloud location (e.g. us-central1) for Vertex AI.')
  parser.add_argument(
      '--endpoint_id',
      dest='endpoint_id',
      help='The ID of the Vertex AI endpoint.')
  parser.add_argument(
      '--rls_address',
      dest='rls_address',
      help='The address of the Envoy Rate Limit Service (e.g. localhost:8081).')

  known_args, pipeline_args = parser.parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = True

  # Initialize the EnvoyRateLimiter
  rate_limiter = EnvoyRateLimiter(
      service_address=known_args.rls_address,
      domain="mongo_cps",
      descriptors=[{
          "database": "users"
      }],
      namespace='example_pipeline')

  # Initialize the VertexAIModelHandler with the rate limiter
  model_handler = VertexAIModelHandlerJSON(
      endpoint_id=known_args.endpoint_id,
      project=known_args.project,
      location=known_args.location,
      rate_limiter=rate_limiter)

  # Input features for the model
  features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
              [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]

  with beam.Pipeline(options=pipeline_options) as p:
    _ = (
        p
        | 'CreateInputs' >> beam.Create(features)
        | 'RunInference' >> RunInference(model_handler)
        | 'PrintPredictions' >> beam.Map(logging.info))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()
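
The new example file's path is not shown in this capture. A minimal local invocation sketch of its run() entry point; the module name and every argument value below are hypothetical placeholders, and it assumes an Envoy Rate Limit Service and a Vertex AI endpoint are reachable:

# Hypothetical invocation; module name and argument values are illustrative,
# not part of the commit.
from rate_limiter_vertex_ai_example import run

run([
    '--project=my-gcp-project',
    '--location=us-central1',
    '--endpoint_id=1234567890',
    '--rls_address=localhost:8081',
    '--runner=DirectRunner',  # unrecognized args flow through to PipelineOptions
])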

sdks/python/apache_beam/examples/rate_limiter_simple.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def init_limiter():
     self.rate_limiter = self._shared.acquire(init_limiter)
 
   def process(self, element):
-    self.rate_limiter.throttle()
+    self.rate_limiter.allow()
 
     # Process the element (mock API call)
     logging.info("Processing element: %s", element)
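
The context around this one-line rename shows the example's DoFn acquiring the limiter through shared.Shared, so every DoFn instance in a worker process reuses a single limiter. A self-contained sketch of that pattern; the DoFn class itself is illustrative, and only the acquire/allow calls mirror the diff:

import apache_beam as beam
from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
from apache_beam.utils import shared


class RateLimitedDoFn(beam.DoFn):
  """Illustrative DoFn; one limiter is shared per worker process."""
  def __init__(self, limiter: EnvoyRateLimiter):
    self._limiter = limiter
    self._shared = shared.Shared()

  def setup(self):
    def init_limiter():
      return self._limiter

    # shared.Shared hands every DoFn instance in the process the same object.
    self.rate_limiter = self._shared.acquire(init_limiter)

  def process(self, element):
    # Blocks (by default) until the rate limit service admits the request.
    self.rate_limiter.allow()
    yield element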

sdks/python/apache_beam/io/components/rate_limiter.py

Lines changed: 37 additions & 7 deletions
@@ -61,8 +61,13 @@ def __init__(self, namespace: str = ""):
     self.rpc_latency = Metrics.distribution(namespace, 'RatelimitRpcLatencyMs')
 
   @abc.abstractmethod
-  def throttle(self, **kwargs) -> bool:
-    """Check if request should be throttled.
+  def allow(self, **kwargs) -> bool:
+    """Applies rate limiting to the request.
+
+    This method checks if the request is permitted by the rate limiting policy.
+    Depending on the implementation and configuration, it may block (sleep)
+    until the request is allowed, or return False if the retry limit is
+    exceeded.
 
     Args:
       **kwargs: Keyword arguments specific to the RateLimiter implementation.
@@ -78,8 +83,12 @@ def throttle(self, **kwargs) -> bool:
 
 
 class EnvoyRateLimiter(RateLimiter):
-  """
-  Rate limiter implementation that uses an external Envoy Rate Limit Service.
+  """Rate limiter implementation that uses an external Envoy Rate Limit Service.
+
+  This limiter connects to a gRPC Envoy Rate Limit Service (RLS) to determine
+  whether a request should be allowed. It supports defining a domain and a
+  list of descriptors that correspond to the rate limit configuration in the
+  RLS.
   """
   def __init__(
       self,
@@ -89,7 +98,7 @@ def __init__(
       timeout: float = 5.0,
       block_until_allowed: bool = True,
       retries: int = 3,
-      namespace: str = ""):
+      namespace: str = ''):
     """
     Args:
       service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
@@ -139,8 +148,16 @@ def init_connection(self):
     channel = grpc.insecure_channel(self.service_address)
     self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)
 
-  def throttle(self, hits_added: int = 1) -> bool:
-    """Calls the Envoy RLS to check for rate limits.
+  def allow(self, hits_added: int = 1) -> bool:
+    """Calls the Envoy RLS to apply rate limits.
+
+    Sends a rate limit request to the configured Envoy Rate Limit Service.
+    If 'block_until_allowed' is True, this method will sleep and retry
+    if the limit is exceeded, effectively blocking until the request is
+    permitted.
+
+    If 'block_until_allowed' is False, it will return False after the retry
+    limit is exceeded.
 
     Args:
       hits_added: Number of hits to add to the rate limit.
@@ -224,3 +241,16 @@ def throttle(self, hits_added: int = 1) -> bool:
             response.overall_code)
         break
     return throttled
+
+  def __getstate__(self):
+    state = self.__dict__.copy()
+    if '_lock' in state:
+      del state['_lock']
+    if '_stub' in state:
+      del state['_stub']
+    return state
+
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+    self._lock = threading.Lock()
+    self._stub = None
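
The new __getstate__/__setstate__ pair lets the limiter cross Beam's pickling boundary: the gRPC stub and the lock are dropped before serialization, and deserialization restores a fresh lock and leaves the stub to be re-created (init_connection is shown in the diff) on the worker. A quick round-trip sketch under that reading; the constructor arguments are placeholders:

import pickle

from apache_beam.io.components.rate_limiter import EnvoyRateLimiter

limiter = EnvoyRateLimiter(
    service_address='localhost:8081',
    domain='example_domain',
    descriptors=[{'database': 'users'}])

# The unpickled copy carries no live channel or lock state; a fresh lock is
# installed by __setstate__ and the stub is rebuilt on the receiving worker.
clone = pickle.loads(pickle.dumps(limiter))
assert clone._stub is None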

sdks/python/apache_beam/io/components/rate_limiter_test.py

Lines changed: 12 additions & 12 deletions
@@ -42,7 +42,7 @@ def setUp(self):
         namespace='test_namespace')
 
   @mock.patch('grpc.insecure_channel')
-  def test_throttle_allowed(self, mock_channel):
+  def test_allow_success(self, mock_channel):
     # Mock successful OK response
     mock_stub = mock.Mock()
     mock_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
@@ -51,13 +51,13 @@ def test_throttle_allowed(self, mock_channel):
     # Inject mock stub
     self.limiter._stub = mock_stub
 
-    throttled = self.limiter.throttle()
+    allowed = self.limiter.allow()
 
-    self.assertTrue(throttled)
+    self.assertTrue(allowed)
     mock_stub.ShouldRateLimit.assert_called_once()
 
   @mock.patch('grpc.insecure_channel')
-  def test_throttle_over_limit_retries_exceeded(self, mock_channel):
+  def test_allow_over_limit_retries_exceeded(self, mock_channel):
     # Mock OVER_LIMIT response
     mock_stub = mock.Mock()
     mock_response = RateLimitResponse(
@@ -69,9 +69,9 @@ def test_throttle_over_limit_retries_exceeded(self, mock_channel):
 
     # We mock time.sleep to run fast
     with mock.patch('time.sleep'):
-      throttled = self.limiter.throttle()
+      allowed = self.limiter.allow()
 
-    self.assertFalse(throttled)
+    self.assertFalse(allowed)
     # Should be called 1 (initial) + 2 (retries) + 1 (last check > retries
     # logic depends on loop)
     # Logic: attempt starts at 0.
@@ -83,7 +83,7 @@ def test_throttle_over_limit_retries_exceeded(self, mock_channel):
     self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)
 
   @mock.patch('grpc.insecure_channel')
-  def test_throttle_rpc_error_retry(self, mock_channel):
+  def test_allow_rpc_error_retry(self, mock_channel):
     # Mock RpcError then Success
     mock_stub = mock.Mock()
     mock_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
@@ -95,13 +95,13 @@ def test_throttle_rpc_error_retry(self, mock_channel):
     self.limiter._stub = mock_stub
 
     with mock.patch('time.sleep'):
-      throttled = self.limiter.throttle()
+      allowed = self.limiter.allow()
 
-    self.assertTrue(throttled)
+    self.assertTrue(allowed)
     self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)
 
   @mock.patch('grpc.insecure_channel')
-  def test_throttle_rpc_error_fail(self, mock_channel):
+  def test_allow_rpc_error_fail(self, mock_channel):
     # Mock Persistent RpcError
     mock_stub = mock.Mock()
     error = grpc.RpcError()
@@ -111,7 +111,7 @@ def test_throttle_rpc_error_fail(self, mock_channel):
 
     with mock.patch('time.sleep'):
       with self.assertRaises(grpc.RpcError):
-        self.limiter.throttle()
+        self.limiter.allow()
 
     # The inner loop tries 5 times for connection errors
     self.assertEqual(mock_stub.ShouldRateLimit.call_count, 5)
@@ -134,7 +134,7 @@ def test_extract_duration_from_response(self, mock_random, mock_channel):
     self.limiter.retries = 0  # Single attempt
 
     with mock.patch('time.sleep') as mock_sleep:
-      self.limiter.throttle()
+      self.limiter.allow()
       # Should sleep for 5 seconds (jitter is 0.0)
       mock_sleep.assert_called_with(5.0)
 
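
These tests exercise the Envoy-backed limiter against a mocked gRPC stub. For local experiments without a rate limit service, any subclass implementing allow() satisfies the same contract. A minimal in-process token-bucket sketch, assuming only the RateLimiter base class from this commit; everything else here is illustrative and not part of the change:

import threading
import time

from apache_beam.io.components.rate_limiter import RateLimiter


class LocalTokenBucketLimiter(RateLimiter):
  """Illustrative process-local limiter; not part of this commit."""
  def __init__(self, rate_per_sec: float, burst: int, namespace: str = ''):
    super().__init__(namespace=namespace)
    self._rate = rate_per_sec
    self._capacity = float(burst)
    self._tokens = float(burst)
    self._last = time.monotonic()
    self._lock = threading.Lock()

  def allow(self, hits_added: int = 1) -> bool:
    # Count the attempt, matching how the tests use the base-class metric.
    self.requests_counter.inc()
    # Refill, then either admit the batch or sleep until enough tokens
    # accumulate (mirrors block_until_allowed=True semantics).
    while True:
      with self._lock:
        now = time.monotonic()
        self._tokens = min(
            self._capacity, self._tokens + (now - self._last) * self._rate)
        self._last = now
        if self._tokens >= hits_added:
          self._tokens -= hits_added
          return True
        deficit = hits_added - self._tokens
      time.sleep(deficit / self._rate)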

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 29 additions & 1 deletion
@@ -60,6 +60,11 @@
 from apache_beam.utils import retry
 from apache_beam.utils import shared
 
+try:
+  from apache_beam.io.components.rate_limiter import RateLimiter
+except ImportError:
+  RateLimiter = None
+
 try:
   # pylint: disable=wrong-import-order, wrong-import-position
   import resource
@@ -102,6 +107,11 @@ def __new__(cls, example, inference, model_id=None):
 PredictionResult.model_id.__doc__ = """Model ID used to run the prediction."""
 
 
+class RateLimitExceeded(RuntimeError):
+  """Raised when the rate limit is exceeded while processing a batch."""
+  pass
+
+
 class ModelMetadata(NamedTuple):
   model_id: str
   model_name: str
@@ -349,7 +359,8 @@ def __init__(
       *,
       window_ms: int = 1 * _MILLISECOND_TO_SECOND,
       bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
-      overload_ratio: float = 2):
+      overload_ratio: float = 2,
+      rate_limiter: Optional[RateLimiter] = None):
     """Initializes a ReactiveThrottler class for enabling
     client-side throttling for remote calls to an inference service. Also wraps
     provided calls to the service with retry logic.
@@ -372,6 +383,7 @@ def __init__(
       overload_ratio: the target ratio between requests sent and successful
         requests. This is "K" in the formula in
         https://landing.google.com/sre/book/chapters/handling-overload.html.
+      rate_limiter: A RateLimiter object for setting a global rate limit.
     """
     # Configure ReactiveThrottler for client-side throttling behavior.
     self.throttler = ReactiveThrottler(
@@ -383,6 +395,9 @@ def __init__(
     self.logger = logging.getLogger(namespace)
     self.num_retries = num_retries
     self.retry_filter = retry_filter
+    self._rate_limiter = rate_limiter
+    self._shared_rate_limiter = None
+    self._shared_handle = shared.Shared()
 
   def __init_subclass__(cls):
     if cls.load_model is not RemoteModelHandler.load_model:
@@ -431,6 +446,19 @@ def run_inference(
     Returns:
       An Iterable of Predictions.
     """
+    if self._rate_limiter:
+      if self._shared_rate_limiter is None:
+
+        def init_limiter():
+          return self._rate_limiter
+
+        self._shared_rate_limiter = self._shared_handle.acquire(init_limiter)
+
+      if not self._shared_rate_limiter.allow(hits_added=len(batch)):
+        raise RateLimitExceeded(
+            "Rate limit exceeded; "
+            "could not process this batch.")
+
     self.throttler.throttle()
 
     try:
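
With this change, run_inference consults the shared rate limiter before the reactive throttler, and raises RateLimitExceeded when a non-blocking limiter rejects the batch. A sketch of wiring a limiter into a concrete handler, following the pattern the new tests use; MyClient and its predict call are hypothetical stand-ins:

from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
from apache_beam.ml.inference import base


class MyClient:
  """Hypothetical stand-in for a remote inference client."""
  def predict(self, example):
    return example


class MyRemoteHandler(base.RemoteModelHandler):
  def create_client(self):
    return MyClient()

  def request(self, batch, model, inference_args=None):
    return [model.predict(example) for example in batch]


handler = MyRemoteHandler(
    namespace='my_pipeline',
    # With block_until_allowed=False, rejections surface as RateLimitExceeded
    # from run_inference instead of allow() sleeping and retrying.
    rate_limiter=EnvoyRateLimiter(
        service_address='localhost:8081',
        domain='my_domain',
        descriptors=[{'api': 'predict'}],
        block_until_allowed=False))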

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 61 additions & 0 deletions
@@ -2071,6 +2071,67 @@ def run_inference(self,
       responses.append(model.predict(example))
     return responses
 
+  def test_run_inference_with_rate_limiter(self):
+    class FakeRateLimiter(base.RateLimiter):
+      def __init__(self):
+        super().__init__(namespace='test_namespace')
+
+      def allow(self, hits_added=1):
+        self.requests_counter.inc()
+        return True
+
+    limiter = FakeRateLimiter()
+
+    with TestPipeline() as pipeline:
+      examples = [1, 5]
+
+      class ConcreteRemoteModelHandler(base.RemoteModelHandler):
+        def create_client(self):
+          return FakeModel()
+
+        def request(self, batch, model, inference_args=None):
+          return [model.predict(example) for example in batch]
+
+      model_handler = ConcreteRemoteModelHandler(
+          rate_limiter=limiter, namespace='test_namespace')
+
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(model_handler)
+
+      expected = [2, 6]
+      assert_that(actual, equal_to(expected))
+
+      result = pipeline.run()
+      result.wait_until_finish()
+
+      metrics_filter = MetricsFilter().with_name(
+          'RatelimitRequestsTotal').with_namespace('test_namespace')
+      metrics = result.metrics().query(metrics_filter)
+      self.assertGreaterEqual(metrics['counters'][0].committed, 0)
+
+  def test_run_inference_with_rate_limiter_exceeded(self):
+    class FakeRateLimiter(base.RateLimiter):
+      def __init__(self):
+        super().__init__(namespace='test_namespace')
+
+      def allow(self, hits_added=1):
+        return False
+
+    class ConcreteRemoteModelHandler(base.RemoteModelHandler):
+      def create_client(self):
+        return FakeModel()
+
+      def request(self, batch, model, inference_args=None):
+        return [model.predict(example) for example in batch]
+
+    model_handler = ConcreteRemoteModelHandler(
+        rate_limiter=FakeRateLimiter(),
+        namespace='test_namespace',
+        num_retries=0)
+
+    with self.assertRaises(base.RateLimitExceeded):
+      model_handler.run_inference([1], FakeModel())
+
 
 if __name__ == '__main__':
   unittest.main()
