Skip to content

Commit d158a5e

Browse files
committed
fix: review fixes — 100% coverage, typo, regex hoist, docstring path
- Fix typo: validatated_headers → validated_headers - Hoist RFC 9110 token regex to module-level compiled constant - Consolidate redundant CRLF check into unified CTL validation - Fix integration test docstring path (was unit test path) - Add tests for _parse_rate minute/hour/error branches - Add tests for unlimited (no-limit) prompt and tool paths - Achieve 100% differential test coverage on rate_limiter.py Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>
1 parent d559ad9 commit d158a5e

File tree

3 files changed

+108
-20
lines changed

3 files changed

+108
-20
lines changed

mcpgateway/main.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,14 @@ async def database_exception_handler(_request: Request, exc: IntegrityError):
14521452
return ORJSONResponse(status_code=409, content=ErrorFormatter.format_database_error(exc))
14531453

14541454

1455+
# RFC 9110 §5.6.2 'token' pattern for header field names:
1456+
# token = 1*tchar
1457+
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
1458+
# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
1459+
# / DIGIT / ALPHA
1460+
_RFC9110_TOKEN_RE = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$")
1461+
1462+
14551463
def _validate_http_headers(headers: dict[str, str]) -> Optional[dict[str, str]]:
14561464
"""Validate headers according to RFC 9110.
14571465
@@ -1464,34 +1472,24 @@ def _validate_http_headers(headers: dict[str, str]) -> Optional[dict[str, str]]:
14641472
Rules enforced:
14651473
- Header name must match RFC 9110 'token'.
14661474
- No whitespace before colon (enforced by dictionary usage).
1467-
- Header value must not contain CTL characters (0x00–0x1F, 0x7F).
1475+
- Header value must not contain CTL characters (0x00–0x1F, 0x7F),
1476+
except SP (0x20) and HTAB (0x09) which are allowed.
14681477
"""
1469-
1470-
# RFC 9110 'token' definition:
1471-
# token = 1*tchar
1472-
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
1473-
# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
1474-
# / DIGIT / ALPHA
1475-
header_key = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$")
14761478
validated: dict[str, str] = {}
14771479
for key, value in headers.items():
1478-
# Validate header name (RFC 9110)
1479-
if not re.match(header_key, key):
1480+
# Validate header name (RFC 9110 token)
1481+
if not _RFC9110_TOKEN_RE.match(key):
14801482
logger.warning(f"Invalid header name: {key}")
14811483
continue
1482-
# Validate header value (no CRLF)
1483-
if "\r" in value or "\n" in value:
1484-
logger.warning(f"Header value contains CRLF: {key}")
1485-
continue
14861484
# RFC 9110: Reject CTLs (0x00–0x1F, 0x7F). Allow SP (0x20) and HTAB (0x09).
1487-
# Further structure (quoted-string, lists, parameters) is left to higher-level parsers.
14881485
valid = True
14891486
for ch in value:
14901487
code = ord(ch)
14911488
if (0 <= code <= 31 or code == 127) and code not in (9, 32):
14921489
valid = False
14931490
break
14941491
if not valid:
1492+
logger.warning(f"Header value contains invalid characters: {key}")
14951493
continue
14961494
validated[key] = value
14971495
return validated if validated else None
@@ -1576,9 +1574,9 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
15761574

15771575
response = ORJSONResponse(status_code=http_status, content={"error": json_rpc_error.model_dump()})
15781576
if headers:
1579-
validatated_headers = _validate_http_headers(headers)
1580-
if validatated_headers:
1581-
response.headers.update(validatated_headers)
1577+
validated_headers = _validate_http_headers(headers)
1578+
if validated_headers:
1579+
response.headers.update(validated_headers)
15821580
return response
15831581

15841582

tests/integration/test_rate_limiter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
"""Location: ./tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py
2+
"""Location: ./tests/integration/test_rate_limiter.py
33
Copyright 2025
44
SPDX-License-Identifier: Apache-2.0
55
Authors: Mihai Criveti

tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
PromptPrehookPayload,
1818
ToolHookType
1919
)
20-
from plugins.rate_limiter.rate_limiter import RateLimiterPlugin, _make_headers, _select_most_restrictive, _store
20+
from plugins.rate_limiter.rate_limiter import RateLimiterPlugin, _make_headers, _parse_rate, _select_most_restrictive, _store
2121

2222

2323
@pytest.fixture(autouse=True)
@@ -509,3 +509,93 @@ def test_tenant_unlimited_user_tool_limited(self):
509509
assert allowed is True
510510
assert limit == 50 # Tool is most restrictive
511511
assert remaining == 30
512+
513+
514+
# ============================================================================
515+
# _parse_rate Tests
516+
# ============================================================================
517+
518+
519+
class TestParseRate:
520+
"""Tests for _parse_rate helper covering all time units."""
521+
522+
def test_seconds_short(self):
523+
assert _parse_rate("10/s") == (10, 1)
524+
525+
def test_seconds_medium(self):
526+
assert _parse_rate("10/sec") == (10, 1)
527+
528+
def test_seconds_long(self):
529+
assert _parse_rate("10/second") == (10, 1)
530+
531+
def test_minutes_short(self):
532+
assert _parse_rate("60/m") == (60, 60)
533+
534+
def test_minutes_medium(self):
535+
assert _parse_rate("60/min") == (60, 60)
536+
537+
def test_minutes_long(self):
538+
assert _parse_rate("60/minute") == (60, 60)
539+
540+
def test_hours_short(self):
541+
assert _parse_rate("100/h") == (100, 3600)
542+
543+
def test_hours_medium(self):
544+
assert _parse_rate("100/hr") == (100, 3600)
545+
546+
def test_hours_long(self):
547+
assert _parse_rate("100/hour") == (100, 3600)
548+
549+
def test_unsupported_unit_raises(self):
550+
with pytest.raises(ValueError, match="Unsupported rate unit"):
551+
_parse_rate("10/d")
552+
553+
def test_whitespace_stripped(self):
554+
assert _parse_rate("5/ M ") == (5, 60)
555+
556+
557+
# ============================================================================
558+
# Unlimited (no-limit) path tests
559+
# ============================================================================
560+
561+
562+
def _mk_unlimited() -> RateLimiterPlugin:
563+
"""Create a plugin with no rate limits configured."""
564+
return RateLimiterPlugin(
565+
PluginConfig(
566+
name="rl",
567+
kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin",
568+
hooks=[PromptHookType.PROMPT_PRE_FETCH, ToolHookType.TOOL_PRE_INVOKE],
569+
config={}, # No limits
570+
)
571+
)
572+
573+
574+
@pytest.mark.asyncio
575+
async def test_prompt_pre_fetch_unlimited_returns_no_headers():
576+
"""When no limits are configured, prompt_pre_fetch returns metadata without http_headers."""
577+
plugin = _mk_unlimited()
578+
ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
579+
payload = PromptPrehookPayload(prompt_id="p", args={})
580+
581+
result = await plugin.prompt_pre_fetch(payload, ctx)
582+
assert result.violation is None
583+
assert result.http_headers is None
584+
assert result.metadata is not None
585+
assert result.metadata.get("limited") is False
586+
587+
588+
@pytest.mark.asyncio
589+
async def test_tool_pre_invoke_unlimited_returns_no_headers():
590+
"""When no limits are configured, tool_pre_invoke returns metadata without http_headers."""
591+
from mcpgateway.plugins.framework import ToolPreInvokePayload
592+
593+
plugin = _mk_unlimited()
594+
ctx = PluginContext(global_context=GlobalContext(request_id="r1", user="u1"))
595+
payload = ToolPreInvokePayload(name="test_tool", arguments={})
596+
597+
result = await plugin.tool_pre_invoke(payload, ctx)
598+
assert result.violation is None
599+
assert result.http_headers is None
600+
assert result.metadata is not None
601+
assert result.metadata.get("limited") is False

0 commit comments

Comments
 (0)