Skip to content

Commit 0e347b8

Browse files
authored
Merge pull request #1199 from akrherz/callable_ip_throttle
🎨 Allow ip_throttle to be callable
2 parents d3e4697 + f3c6f47 commit 0e347b8

2 files changed

Lines changed: 40 additions & 20 deletions

File tree

src/pyiem/webutil.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -507,20 +507,23 @@ def _mcall(
507507
return res
508508

509509

510-
def ip_is_throttled(environ: dict, throttle_secs: float) -> bool:
510+
def ip_is_throttled(environ: dict, throttle_secs: float | Callable) -> bool:
511511
"""Return True if the REMOTE_ADDR is throttled."""
512512
client_ip = environ.get("REMOTE_ADDR")
513513
if not client_ip or client_ip.startswith(("127.", "129.186.", "10.")):
514514
return False
515-
try:
516-
mc = Client("iem-memcached:11211")
517-
key = f"throttle:{client_ip}"
518-
res = mc.get(key)
519-
if res:
520-
return True
521-
mc.set(key, "1", expire=int(throttle_secs) + 1)
522-
except Exception:
523-
pass
515+
if isinstance(throttle_secs, Callable):
516+
throttle_secs = throttle_secs(environ)
517+
if throttle_secs > 0:
518+
try:
519+
mc = Client("iem-memcached:11211")
520+
key = f"throttle:{client_ip}"
521+
res = mc.get(key)
522+
if res:
523+
return True
524+
mc.set(key, "1", expire=int(throttle_secs) + 1)
525+
except Exception:
526+
pass
524527
return False
525528

526529

@@ -542,8 +545,9 @@ def iemapp(**kwargs):
542545
response.
543546
- allowed_as_list (list): CGI parameters that are permitted to be
544547
lists.
545-
- ip_throttle_secs (float): Number of seconds between requests from
546-
the same REMOTE_ADDR, 0 to disable, which is the default.
548+
- ip_throttle_secs (float or callable): Number of seconds between
549+
requests from the same REMOTE_ADDR, 0 to disable,
550+
which is the default.
547551
548552
What all this does:
549553
1) Attempts to catch database connection errors and handle nicely
@@ -589,14 +593,13 @@ def _handle_exp(errormsg, routine=False, code=500):
589593
)
590594
return msg.encode("ascii", errors="replace")
591595

592-
if kwargs.get("ip_throttle_secs", 0) > 0:
593-
if ip_is_throttled(environ, kwargs["ip_throttle_secs"]):
594-
start_response(
595-
"429 Too Many Requests",
596-
[("Content-type", "text/plain")],
597-
)
598-
yield b"Too many requests from your IP address, slow down."
599-
return
596+
if ip_is_throttled(environ, kwargs.get("ip_throttle_secs", 0)):
597+
start_response(
598+
"429 Too Many Requests",
599+
[("Content-type", "text/plain")],
600+
)
601+
yield b"Too many requests from your IP address, slow down."
602+
return
600603

601604
start_time = datetime.now(timezone.utc)
602605
status_code = 500

tests/test_webutil.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@ def test_xss_false_positive_ampersand():
5151
assert not _is_xss_payload("Bread & Butter")
5252

5353

54+
def test_ip_throttled_callable():
55+
"""Test that the ip throttle is callable."""
56+
57+
@iemapp(allowed_as_list=["q"], ip_throttle_secs=lambda _x: 0)
58+
def application(_environ, start_response):
59+
"""Test."""
60+
start_response("200 OK", [("Content-type", "text/plain")])
61+
return f"{random.random()}"
62+
63+
eo = {"REMOTE_ADDR": "7.7.7.7"}
64+
c = Client(application)
65+
resp = c.get("/?q=1", environ_overrides=eo)
66+
assert resp.status_code == 200
67+
resp = c.get("/?q=1", environ_overrides=eo)
68+
assert resp.status_code == 200
69+
70+
5471
def test_ip_throttled():
5572
"""Test how our throttle behaves."""
5673

0 commit comments

Comments
 (0)