From f3c6f47f297cf5c270a63b1e395e67cbf5e41f0e Mon Sep 17 00:00:00 2001 From: akrherz Date: Tue, 21 Apr 2026 14:36:38 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Allow=20ip=5Fthrottle=20to=20be?= =?UTF-8?q?=20callable?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pyiem/webutil.py | 43 +++++++++++++++++++++++-------------------- tests/test_webutil.py | 17 +++++++++++++++++ 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/pyiem/webutil.py b/src/pyiem/webutil.py index 347e8f1cd..92f75d603 100644 --- a/src/pyiem/webutil.py +++ b/src/pyiem/webutil.py @@ -507,20 +507,23 @@ def _mcall( return res -def ip_is_throttled(environ: dict, throttle_secs: float) -> bool: +def ip_is_throttled(environ: dict, throttle_secs: float | Callable) -> bool: """Return True if the REMOTE_ADDR is throttled.""" client_ip = environ.get("REMOTE_ADDR") if not client_ip or client_ip.startswith(("127.", "129.186.", "10.")): return False - try: - mc = Client("iem-memcached:11211") - key = f"throttle:{client_ip}" - res = mc.get(key) - if res: - return True - mc.set(key, "1", expire=int(throttle_secs) + 1) - except Exception: - pass + if isinstance(throttle_secs, Callable): + throttle_secs = throttle_secs(environ) + if throttle_secs > 0: + try: + mc = Client("iem-memcached:11211") + key = f"throttle:{client_ip}" + res = mc.get(key) + if res: + return True + mc.set(key, "1", expire=int(throttle_secs) + 1) + except Exception: + pass return False @@ -542,8 +545,9 @@ def iemapp(**kwargs): response. - allowed_as_list (list): CGI parameters that are permitted to be lists. - - ip_throttle_secs (float): Number of seconds between requests from - the same REMOTE_ADDR, 0 to disable, which is the default. + - ip_throttle_secs (float or callable): Number of seconds between + requests from the same REMOTE_ADDR, 0 to disable, + which is the default. What all this does: 1) Attempts to catch database connection errors and handle nicely @@ -589,14 +593,13 @@ def _handle_exp(errormsg, routine=False, code=500): ) return msg.encode("ascii", errors="replace") - if kwargs.get("ip_throttle_secs", 0) > 0: - if ip_is_throttled(environ, kwargs["ip_throttle_secs"]): - start_response( - "429 Too Many Requests", - [("Content-type", "text/plain")], - ) - yield b"Too many requests from your IP address, slow down." - return + if ip_is_throttled(environ, kwargs.get("ip_throttle_secs", 0)): + start_response( + "429 Too Many Requests", + [("Content-type", "text/plain")], + ) + yield b"Too many requests from your IP address, slow down." + return start_time = datetime.now(timezone.utc) status_code = 500 diff --git a/tests/test_webutil.py b/tests/test_webutil.py index f385121b4..582f90d1a 100644 --- a/tests/test_webutil.py +++ b/tests/test_webutil.py @@ -51,6 +51,23 @@ def test_xss_false_positive_ampersand(): assert not _is_xss_payload("Bread & Butter") +def test_ip_throttled_callable(): + """Test that the ip throttle is callable.""" + + @iemapp(allowed_as_list=["q"], ip_throttle_secs=lambda _x: 0) + def application(_environ, start_response): + """Test.""" + start_response("200 OK", [("Content-type", "text/plain")]) + return f"{random.random()}" + + eo = {"REMOTE_ADDR": "7.7.7.7"} + c = Client(application) + resp = c.get("/?q=1", environ_overrides=eo) + assert resp.status_code == 200 + resp = c.get("/?q=1", environ_overrides=eo) + assert resp.status_code == 200 + + def test_ip_throttled(): """Test how our throttle behaves."""