Skip to content

Commit 1f6a3e1

Browse files
committed
Updated rate limitis processing, refactored to use Bucket
1 parent 9f45bda commit 1f6a3e1

File tree

2 files changed

+196
-77
lines changed

2 files changed

+196
-77
lines changed

tb_device_mqtt.py

+181-76
Original file line numberDiff line numberDiff line change
@@ -161,70 +161,128 @@ def get(self):
161161
return self.rc()
162162

163163

164+
class GreedyTokenBucket:
165+
def __init__(self, capacity, duration_sec):
166+
self.capacity = capacity
167+
self.duration = duration_sec
168+
self.tokens = capacity
169+
self.last_updated = int(monotonic())
170+
171+
def refill(self):
172+
now = int(monotonic())
173+
elapsed = now - self.last_updated
174+
refill_rate = self.capacity / self.duration
175+
refill_amount = elapsed * refill_rate
176+
self.tokens = min(self.capacity, self.tokens + refill_amount)
177+
self.last_updated = now
178+
179+
def can_consume(self, amount=1):
180+
self.refill()
181+
return self.tokens >= amount
182+
183+
def consume(self, amount=1):
184+
self.refill()
185+
if self.tokens >= amount:
186+
self.tokens -= amount
187+
return True
188+
return False
189+
190+
def get_remaining_tokens(self):
191+
self.refill()
192+
return self.tokens
193+
194+
164195
class RateLimit:
165196
def __init__(self, rate_limit, name=None, percentage=80):
197+
self.__reached_limit_index = 0
198+
self.__reached_limit_index_time = 0
166199
self._no_limit = False
167-
self._rate_limit_dict = {}
200+
self._rate_buckets = {}
168201
self.__lock = RLock()
169202
self._minimal_timeout = DEFAULT_TIMEOUT
170-
self._minimal_limit = 1000000000
203+
self._minimal_limit = float("inf")
204+
171205
from_dict = isinstance(rate_limit, dict)
172-
if from_dict:
173-
self._rate_limit_dict = rate_limit.get('rateLimits', rate_limit)
174-
name = rate_limit.get('name', name)
175-
percentage = rate_limit.get('percentage', percentage)
176-
self._no_limit = rate_limit.get('no_limit', False)
177206
self.name = name
178207
self.percentage = percentage
179-
self.__start_time = int(monotonic())
180-
if not from_dict:
181-
if ''.join(c for c in rate_limit if c not in [' ', ',', ';']) in ("", "0:0"):
208+
209+
if from_dict:
210+
self._no_limit = rate_limit.get('no_limit', False)
211+
self.percentage = rate_limit.get('percentage', percentage)
212+
self.name = rate_limit.get('name', name)
213+
214+
rate_limits = rate_limit.get('rateLimits', {})
215+
for duration_str, bucket_info in rate_limits.items():
216+
try:
217+
duration = int(duration_str)
218+
capacity = bucket_info.get("capacity")
219+
tokens = bucket_info.get("tokens")
220+
last_updated = bucket_info.get("last_updated")
221+
222+
if capacity is None or tokens is None:
223+
continue
224+
225+
bucket = GreedyTokenBucket(capacity, duration)
226+
bucket.tokens = min(capacity, float(tokens))
227+
bucket.last_updated = float(last_updated) if last_updated is not None else monotonic()
228+
229+
self._rate_buckets[duration] = bucket
230+
self._minimal_limit = min(self._minimal_limit, capacity)
231+
self._minimal_timeout = min(self._minimal_timeout, duration + 1)
232+
except Exception as e:
233+
log.warning("Invalid bucket format for duration %s: %s", duration_str, e)
234+
235+
else:
236+
clean = ''.join(c for c in rate_limit if c not in [' ', ',', ';'])
237+
if clean in ("", "0:0"):
182238
self._no_limit = True
183239
return
184-
rate_configs = rate_limit.split(";")
185-
if "," in rate_limit:
186-
rate_configs = rate_limit.split(",")
240+
241+
rate_configs = rate_limit.replace(";", ",").split(",")
187242
for rate in rate_configs:
188-
if rate == "":
243+
if not rate.strip():
189244
continue
190-
rate = rate.split(":")
191-
self._rate_limit_dict[int(rate[1])] = {"counter": 0,
192-
"start": int(monotonic()),
193-
"limit": int(int(rate[0]) * self.percentage / 100)}
194-
log.debug("Rate limit %s set to values: " % self.name)
195-
with self.__lock:
196-
if not self._no_limit:
197-
for rate_limit_time in self._rate_limit_dict:
198-
log.debug("Time: %s, Limit: %s", rate_limit_time,
199-
self._rate_limit_dict[rate_limit_time]["limit"])
200-
if self._rate_limit_dict[rate_limit_time]["limit"] < self._minimal_limit:
201-
self._minimal_limit = self._rate_limit_dict[rate_limit_time]["limit"]
202-
if rate_limit_time < self._minimal_limit:
203-
self._minimal_timeout = rate_limit_time + 1
204-
else:
205-
log.debug("No rate limits.")
245+
try:
246+
limit_str, duration_str = rate.strip().split(":")
247+
limit = int(int(limit_str) * self.percentage / 100)
248+
duration = int(duration_str)
249+
bucket = GreedyTokenBucket(limit, duration)
250+
self._rate_buckets[duration] = bucket
251+
self._minimal_limit = min(self._minimal_limit, limit)
252+
self._minimal_timeout = min(self._minimal_timeout, duration + 1)
253+
except Exception as e:
254+
log.warning("Invalid rate limit format '%s': %s", rate, e)
255+
256+
log.debug("Rate limit %s set to values:", self.name)
257+
for duration, bucket in self._rate_buckets.items():
258+
log.debug("Window: %ss, Limit: %s", duration, bucket.capacity)
206259

207260
def increase_rate_limit_counter(self, amount=1):
208261
if self._no_limit:
209262
return
210263
with self.__lock:
211-
for rate_limit_time in self._rate_limit_dict:
212-
self._rate_limit_dict[rate_limit_time]["counter"] += amount
264+
for bucket in self._rate_buckets.values():
265+
bucket.refill()
266+
bucket.tokens = max(0.0, bucket.tokens - amount)
213267

214268
def check_limit_reached(self, amount=1):
215269
if self._no_limit:
216270
return False
217271
with self.__lock:
218-
current_time = int(monotonic())
219-
for rate_limit_time, rate_limit_info in self._rate_limit_dict.items():
220-
if self._rate_limit_dict[rate_limit_time]["start"] + rate_limit_time <= current_time:
221-
self._rate_limit_dict[rate_limit_time]["start"] = current_time
222-
self._rate_limit_dict[rate_limit_time]["counter"] = 0
223-
current_limit = rate_limit_info['limit']
224-
if rate_limit_info['counter'] + amount > current_limit:
225-
return current_limit, rate_limit_time
272+
for duration, bucket in self._rate_buckets.items():
273+
if not bucket.can_consume(amount):
274+
return bucket.capacity, duration
275+
276+
for duration, bucket in self._rate_buckets.items():
277+
log.debug("%s left tokens: %.2f per %r seconds",
278+
self.name,
279+
bucket.get_remaining_tokens(),
280+
duration)
281+
bucket.consume(amount)
282+
226283
return False
227284

285+
228286
def get_minimal_limit(self):
229287
return self._minimal_limit if self.has_limit() else 0
230288

@@ -234,43 +292,89 @@ def get_minimal_timeout(self):
234292
def has_limit(self):
235293
return not self._no_limit
236294

237-
def set_limit(self, rate_limit, percentage=100):
295+
def set_limit(self, rate_limit, percentage=80):
238296
with self.__lock:
239297
self._minimal_timeout = DEFAULT_TIMEOUT
240-
self._minimal_limit = 1000000000
241-
old_rate_limit_dict = deepcopy(self._rate_limit_dict)
242-
self._rate_limit_dict = {}
298+
self._minimal_limit = float("inf")
299+
300+
old_buckets = deepcopy(self._rate_buckets)
301+
self._rate_buckets = {}
243302
self.percentage = percentage if percentage > 0 else self.percentage
244-
rate_configs = rate_limit.split(";")
245-
if "," in rate_limit:
246-
rate_configs = rate_limit.split(",")
247-
if len(rate_configs) == 2 and rate_configs[0] == "0:0":
303+
304+
clean = ''.join(c for c in rate_limit if c not in [' ', ',', ';'])
305+
if clean in ("", "0:0"):
248306
self._no_limit = True
249307
return
308+
309+
rate_configs = rate_limit.replace(";", ",").split(",")
310+
250311
for rate in rate_configs:
251-
if rate == "":
312+
if not rate.strip():
252313
continue
253-
rate = rate.split(":")
254-
rate_limit_time = int(rate[1])
255-
limit = int(int(rate[0]) * percentage / 100)
256-
self._rate_limit_dict[int(rate[1])] = {
257-
"counter": old_rate_limit_dict.get(rate_limit_time, {}).get('counter', 0),
258-
"start": old_rate_limit_dict.get(rate_limit_time, {}).get('start', int(monotonic())),
259-
"limit": limit}
260-
if rate_limit_time < self._minimal_limit:
261-
self._minimal_timeout = rate_limit_time + 1
262-
if limit < self._minimal_limit:
263-
self._minimal_limit = limit
264-
if self._rate_limit_dict:
265-
self._no_limit = False
266-
log.debug("Rate limit set to values: ")
267-
for rate_limit_time in self._rate_limit_dict:
268-
log.debug("Time: %s, Limit: %s", rate_limit_time, self._rate_limit_dict[rate_limit_time]["limit"])
314+
try:
315+
limit_str, duration_str = rate.strip().split(":")
316+
duration = int(duration_str)
317+
new_capacity = int(int(limit_str) * self.percentage / 100)
318+
319+
previous_bucket = old_buckets.get(duration)
320+
new_bucket = GreedyTokenBucket(new_capacity, duration)
321+
322+
if previous_bucket:
323+
previous_bucket.refill()
324+
used = previous_bucket.capacity - previous_bucket.tokens
325+
new_bucket.tokens = max(0.0, new_capacity - used)
326+
new_bucket.last_updated = monotonic()
327+
else:
328+
new_bucket.tokens = new_capacity
329+
new_bucket.last_updated = monotonic()
330+
331+
self._rate_buckets[duration] = new_bucket
332+
self._minimal_limit = min(self._minimal_limit, new_bucket.capacity)
333+
self._minimal_timeout = min(self._minimal_timeout, duration + 1)
334+
335+
except Exception as e:
336+
log.warning("Invalid rate limit format '%s': %s", rate, e)
337+
338+
self._no_limit = not bool(self._rate_buckets)
339+
log.debug("Rate limit set to values:")
340+
for duration, bucket in self._rate_buckets.items():
341+
log.debug("Duration: %ss, Limit: %s", duration, bucket.capacity)
342+
343+
def reach_limit(self):
344+
if self._no_limit or not self._rate_buckets:
345+
return
346+
347+
with self.__lock:
348+
durations = sorted(self._rate_buckets.keys())
349+
current_monotonic = int(monotonic())
350+
if self.__reached_limit_index_time >= current_monotonic - self._rate_buckets[durations[-1]].duration:
351+
self.__reached_limit_index = 0
352+
self.__reached_limit_index_time = current_monotonic
353+
if self.__reached_limit_index >= len(durations):
354+
self.__reached_limit_index = 0
355+
self.__reached_limit_index_time = current_monotonic
356+
357+
target_duration = durations[self.__reached_limit_index]
358+
bucket = self._rate_buckets[target_duration]
359+
bucket.refill()
360+
bucket.tokens = 0.0
361+
362+
self.__reached_limit_index += 1
363+
log.info("Received disconnection due to rate limit for \"%s\" rate limit, waiting for tokens in bucket for %s seconds",
364+
self.name,
365+
target_duration)
269366

270367
@property
271368
def __dict__(self):
369+
rate_limits_dict = {}
370+
for duration, bucket in self._rate_buckets.items():
371+
rate_limits_dict[str(duration)] = {
372+
"capacity": bucket.capacity,
373+
"tokens": bucket.get_remaining_tokens(),
374+
"last_updated": bucket.last_updated
375+
}
272376
return {
273-
"rateLimits": self._rate_limit_dict,
377+
"rateLimits": rate_limits_dict,
274378
"name": self.name,
275379
"percentage": self.percentage,
276380
"no_limit": self._no_limit
@@ -574,6 +678,8 @@ def _on_decoded_message(self, content, message):
574678
callback[0](content, None, callback[1])
575679
elif callback is not None:
576680
callback(content, None)
681+
else:
682+
log.debug("Message received with topic: %s", message.topic)
577683

578684
if message.topic.startswith("v1/devices/me/attributes"):
579685
self._messages_rate_limit.increase_rate_limit_counter()
@@ -769,8 +875,8 @@ def _wait_for_rate_limit_released(self, timeout, message_rate_limit, dp_rate_lim
769875
limit_reached_check = (message_rate_limit_check
770876
or datapoints_rate_limit_check
771877
or not self.is_connected())
772-
if timeout < limit_reached_check:
773-
timeout = limit_reached_check
878+
if isinstance(limit_reached_check, tuple) and timeout < limit_reached_check[1]:
879+
timeout = limit_reached_check[1]
774880
if not timeout_updated and limit_reached_check:
775881
timeout += 10
776882
timeout_updated = True
@@ -791,14 +897,13 @@ def _wait_for_rate_limit_released(self, timeout, message_rate_limit, dp_rate_lim
791897
datapoints_rate_limit_check)
792898
return TBPublishInfo(paho.MQTTMessageInfo(None))
793899
if not log_posted and limit_reached_check:
794-
if message_rate_limit_check:
795-
log.debug("Rate limit for messages [%r:%r] - reached, waiting for rate limit to be released...",
796-
message_rate_limit_check,
797-
message_rate_limit_check)
798-
elif datapoints_rate_limit_check:
799-
log.debug("Rate limit for data points [%r:%r] - reached, waiting for rate limit to be released...",
800-
datapoints_rate_limit_check,
801-
datapoints_rate_limit_check)
900+
if log.isEnabledFor(logging.DEBUG):
901+
if isinstance(message_rate_limit_check, tuple):
902+
log.debug("Rate limit for messages (%r messages per %r second(s)) - almost reached, waiting for rate limit to be released...",
903+
*message_rate_limit_check)
904+
if isinstance(datapoints_rate_limit_check, tuple):
905+
log.debug("Rate limit for data points (%r data points per %r second(s)) - almost reached, waiting for rate limit to be released...",
906+
*datapoints_rate_limit_check)
802907
waited = True
803908
log_posted = True
804909
if limit_reached_check:

tb_gateway_mqtt.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
GATEWAY_ATTRIBUTES_TOPIC = "v1/gateway/attributes"
2626
GATEWAY_TELEMETRY_TOPIC = "v1/gateway/telemetry"
27+
GATEWAY_DISCONNECT_TOPIC = "v1/gateway/disconnect"
2728
GATEWAY_ATTRIBUTES_REQUEST_TOPIC = "v1/gateway/attributes/request"
2829
GATEWAY_ATTRIBUTES_RESPONSE_TOPIC = "v1/gateway/attributes/response"
2930
GATEWAY_MAIN_TOPIC = "v1/gateway/"
@@ -73,6 +74,7 @@ def __init__(self, host, port=1883, username=None, password=None, gateway=None,
7374
self.__sub_dict = {}
7475
self.__connected_devices = set("*")
7576
self.devices_server_side_rpc_request_handler = None
77+
self.device_disconnect_callback = None
7678
self._client.on_connect = self._on_connect
7779
self._client.on_message = self._on_message
7880
self._client.on_subscribe = self._on_subscribe
@@ -162,6 +164,18 @@ def _on_decoded_message(self, content, message, **kwargs):
162164
self._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter(1)
163165
if self.devices_server_side_rpc_request_handler:
164166
self.devices_server_side_rpc_request_handler(self, content)
167+
elif message.topic == GATEWAY_DISCONNECT_TOPIC:
168+
if content.get("reason"):
169+
reason = content["reason"]
170+
log.info("Device \"%s\" disconnected with reason %s", content["device"], content["reason"])
171+
if reason == 150: # 150 - Rate limit reached
172+
self._devices_connected_through_gateway_messages_rate_limit.reach_limit()
173+
self._devices_connected_through_gateway_telemetry_messages_rate_limit.reach_limit()
174+
self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.reach_limit()
175+
if self.device_disconnect_callback is not None:
176+
self.device_disconnect_callback(self, content)
177+
else:
178+
log.warning("Unknown message from topic %s", message.topic)
165179

166180
def __request_attributes(self, device, keys, callback, type_is_client=False):
167181
if not keys:
@@ -323,4 +337,4 @@ def __on_service_configuration(self, _, response, *args, **kwargs):
323337
{'rateLimits': gateway_device_itself_rate_limit_config, **service_config},
324338
*args,
325339
**kwargs)
326-
log.info("Current gateway limits: %r", service_config)
340+
log.info("Current limits for devices connected through the gateway: %r", gateway_devices_rate_limit_config)

0 commit comments

Comments
 (0)