Skip to content

Commit 6c538cc

Browse files
committed
fix: guard RESP parser against deeply nested aggregate replies
The pure-Python RESP2 and RESP3 parsers recursively process nested aggregate replies (arrays, sets, maps, push responses) without any depth bound. A malicious or misbehaving server can craft deeply nested replies that exhaust the Python call stack, causing a client-side RecursionError denial of service. Add a _MAX_PARSE_DEPTH (256) guard to all four _read_response methods (sync/async x RESP2/RESP3). When the nesting depth exceeds the limit, raise InvalidResponse with a clear message instead of letting Python hit its recursion limit. Fixes #4116
1 parent 8210f32 commit 6c538cc

3 files changed

Lines changed: 208 additions & 14 deletions

File tree

redis/_parsers/resp2.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
from .base import _AsyncRESPBase, _RESPBase
77
from .socket import SERVER_CLOSED_CONNECTION_ERROR
88

9+
# Maximum nesting depth for aggregate RESP types (arrays, sets, maps).
10+
# A malicious or misbehaving server can craft deeply nested replies that
11+
# exhaust the Python call stack. Guard against this by bounding recursion.
12+
_MAX_PARSE_DEPTH = 256
13+
914

1015
class _RESP2Parser(_RESPBase):
1116
"""RESP2 protocol implementation"""
@@ -27,8 +32,15 @@ def read_response(
2732
return result
2833

2934
def _read_response(
30-
self, disable_decoding=False, timeout: Union[float, object] = SENTINEL
35+
self,
36+
disable_decoding=False,
37+
timeout: Union[float, object] = SENTINEL,
38+
_depth: int = 0,
3139
):
40+
if _depth > _MAX_PARSE_DEPTH:
41+
raise InvalidResponse(
42+
f"Exceeded maximum response nesting depth ({_MAX_PARSE_DEPTH})"
43+
)
3244
raw = self._buffer.readline(timeout=timeout)
3345
if not raw:
3446
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -64,7 +76,11 @@ def _read_response(
6476
return None
6577
elif byte == b"*":
6678
response = [
67-
self._read_response(disable_decoding=disable_decoding, timeout=timeout)
79+
self._read_response(
80+
disable_decoding=disable_decoding,
81+
timeout=timeout,
82+
_depth=_depth + 1,
83+
)
6884
for i in range(int(response))
6985
]
7086
else:
@@ -92,8 +108,14 @@ async def read_response(self, disable_decoding: bool = False):
92108
return response
93109

94110
async def _read_response(
95-
self, disable_decoding: bool = False
111+
self,
112+
disable_decoding: bool = False,
113+
_depth: int = 0,
96114
) -> Union[EncodableT, ResponseError, None]:
115+
if _depth > _MAX_PARSE_DEPTH:
116+
raise InvalidResponse(
117+
f"Exceeded maximum response nesting depth ({_MAX_PARSE_DEPTH})"
118+
)
97119
raw = await self._readline()
98120
response: Any
99121
byte, response = raw[:1], raw[1:]
@@ -128,7 +150,7 @@ async def _read_response(
128150
return None
129151
elif byte == b"*":
130152
response = [
131-
(await self._read_response(disable_decoding))
153+
(await self._read_response(disable_decoding, _depth=_depth + 1))
132154
for _ in range(int(response)) # noqa
133155
]
134156
else:

redis/_parsers/resp3.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
)
1313
from .socket import SERVER_CLOSED_CONNECTION_ERROR
1414

15+
# Maximum nesting depth for aggregate RESP types (arrays, sets, maps).
16+
# A malicious or misbehaving server can craft deeply nested replies that
17+
# exhaust the Python call stack. Guard against this by bounding recursion.
18+
_MAX_PARSE_DEPTH = 256
19+
1520

1621
class _RESP3Parser(_RESPBase, PushNotificationsParser):
1722
"""RESP3 protocol implementation"""
@@ -61,7 +66,12 @@ def _read_response(
6166
disable_decoding=False,
6267
push_request=False,
6368
timeout: Union[float, object] = SENTINEL,
69+
_depth: int = 0,
6470
):
71+
if _depth > _MAX_PARSE_DEPTH:
72+
raise InvalidResponse(
73+
f"Exceeded maximum response nesting depth ({_MAX_PARSE_DEPTH})"
74+
)
6575
raw = self._buffer.readline(timeout=timeout)
6676
if not raw:
6777
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -107,15 +117,23 @@ def _read_response(
107117
# array response
108118
elif byte == b"*":
109119
response = [
110-
self._read_response(disable_decoding=disable_decoding, timeout=timeout)
120+
self._read_response(
121+
disable_decoding=disable_decoding,
122+
timeout=timeout,
123+
_depth=_depth + 1,
124+
)
111125
for _ in range(int(response))
112126
]
113127
# set response
114128
elif byte == b"~":
115129
# redis can return unhashable types (like dict) in a set,
116130
# so we return sets as list, all the time, for predictability
117131
response = [
118-
self._read_response(disable_decoding=disable_decoding, timeout=timeout)
132+
self._read_response(
133+
disable_decoding=disable_decoding,
134+
timeout=timeout,
135+
_depth=_depth + 1,
136+
)
119137
for _ in range(int(response))
120138
]
121139
# map response
@@ -126,12 +144,15 @@ def _read_response(
126144
resp_dict = {}
127145
for _ in range(int(response)):
128146
key = self._read_response(
129-
disable_decoding=disable_decoding, timeout=timeout
147+
disable_decoding=disable_decoding,
148+
timeout=timeout,
149+
_depth=_depth + 1,
130150
)
131151
resp_dict[key] = self._read_response(
132152
disable_decoding=disable_decoding,
133153
push_request=push_request,
134154
timeout=timeout,
155+
_depth=_depth + 1,
135156
)
136157
response = resp_dict
137158
# push response
@@ -141,6 +162,7 @@ def _read_response(
141162
disable_decoding=disable_decoding,
142163
push_request=push_request,
143164
timeout=timeout,
165+
_depth=_depth + 1,
144166
)
145167
for _ in range(int(response))
146168
]
@@ -153,6 +175,7 @@ def _read_response(
153175
return self._read_response(
154176
disable_decoding=disable_decoding,
155177
push_request=push_request,
178+
_depth=_depth + 1,
156179
)
157180
else:
158181
raise InvalidResponse(f"Protocol Error: {raw!r}")
@@ -190,8 +213,15 @@ async def read_response(
190213
return response
191214

192215
async def _read_response(
193-
self, disable_decoding: bool = False, push_request: bool = False
216+
self,
217+
disable_decoding: bool = False,
218+
push_request: bool = False,
219+
_depth: int = 0,
194220
) -> Union[EncodableT, ResponseError, None]:
221+
if _depth > _MAX_PARSE_DEPTH:
222+
raise InvalidResponse(
223+
f"Exceeded maximum response nesting depth ({_MAX_PARSE_DEPTH})"
224+
)
195225
if not self._stream or not self.encoder:
196226
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
197227
raw = await self._readline()
@@ -241,15 +271,23 @@ async def _read_response(
241271
# array response
242272
elif byte == b"*":
243273
response = [
244-
(await self._read_response(disable_decoding=disable_decoding))
274+
(
275+
await self._read_response(
276+
disable_decoding=disable_decoding, _depth=_depth + 1
277+
)
278+
)
245279
for _ in range(int(response))
246280
]
247281
# set response
248282
elif byte == b"~":
249283
# redis can return unhashable types (like dict) in a set,
250284
# so we always convert to a list, to have predictable return types
251285
response = [
252-
(await self._read_response(disable_decoding=disable_decoding))
286+
(
287+
await self._read_response(
288+
disable_decoding=disable_decoding, _depth=_depth + 1
289+
)
290+
)
253291
for _ in range(int(response))
254292
]
255293
# map response
@@ -259,25 +297,33 @@ async def _read_response(
259297
# became defined to be left-right in version 3.8
260298
resp_dict = {}
261299
for _ in range(int(response)):
262-
key = await self._read_response(disable_decoding=disable_decoding)
300+
key = await self._read_response(
301+
disable_decoding=disable_decoding, _depth=_depth + 1
302+
)
263303
resp_dict[key] = await self._read_response(
264-
disable_decoding=disable_decoding, push_request=push_request
304+
disable_decoding=disable_decoding,
305+
push_request=push_request,
306+
_depth=_depth + 1,
265307
)
266308
response = resp_dict
267309
# push response
268310
elif byte == b">":
269311
response = [
270312
(
271313
await self._read_response(
272-
disable_decoding=disable_decoding, push_request=push_request
314+
disable_decoding=disable_decoding,
315+
push_request=push_request,
316+
_depth=_depth + 1,
273317
)
274318
)
275319
for _ in range(int(response))
276320
]
277321
response = await self.handle_push_response(response)
278322
if not push_request:
279323
return await self._read_response(
280-
disable_decoding=disable_decoding, push_request=push_request
324+
disable_decoding=disable_decoding,
325+
push_request=push_request,
326+
_depth=_depth + 1,
281327
)
282328
else:
283329
return response
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""Tests for the RESP parser recursion depth guard (issue #4116).
2+
3+
A malicious or misbehaving server can send deeply nested aggregate replies
4+
that exhaust the Python call stack. The parsers now bound recursion depth
5+
via ``_MAX_PARSE_DEPTH`` to prevent this.
6+
"""
7+
8+
from unittest.mock import MagicMock
9+
10+
import pytest
11+
12+
from redis._parsers.resp2 import _MAX_PARSE_DEPTH as _MAX_DEPTH_RESP2
13+
from redis._parsers.resp2 import _RESP2Parser
14+
from redis._parsers.resp3 import _MAX_PARSE_DEPTH as _MAX_DEPTH_RESP3
15+
from redis._parsers.resp3 import _RESP3Parser
16+
from redis._parsers.socket import SocketBuffer
17+
from redis.exceptions import InvalidResponse
18+
19+
20+
class _FakeSocket:
21+
"""Minimal socket stand-in that yields pre-built RESP bytes."""
22+
23+
def __init__(self, payload: bytes):
24+
self._payload = payload
25+
26+
def recv(self, n: int) -> bytes:
27+
if not self._payload:
28+
return b""
29+
chunk = self._payload[:n]
30+
self._payload = self._payload[n:]
31+
return chunk
32+
33+
def settimeout(self, _timeout: float) -> None:
34+
pass
35+
36+
37+
def _make_buf(payload: bytes) -> SocketBuffer:
38+
"""Create a SocketBuffer backed by a fake socket with the given payload."""
39+
return SocketBuffer(
40+
_FakeSocket(payload), socket_read_size=65536, socket_timeout=5.0
41+
)
42+
43+
44+
def _make_parser(cls, payload: bytes):
45+
"""Create a parser with buffer and encoder wired up."""
46+
parser = cls(65536)
47+
parser._buffer = _make_buf(payload)
48+
# Wire up a mock encoder so decode doesn't fail
49+
parser.encoder = MagicMock()
50+
parser.encoder.decode = lambda x: x.decode("utf-8") if isinstance(x, bytes) else x
51+
return parser
52+
53+
54+
def _make_nested_arrays(depth: int) -> bytes:
55+
"""Build ``depth`` nested single-element RESP arrays terminating in +OK."""
56+
return (b"*1\r\n" * depth) + b"+OK\r\n"
57+
58+
59+
class TestRESP2DepthGuard:
60+
"""Verify _RESP2Parser rejects excessively nested aggregate replies."""
61+
62+
def test_depth_exceeded_raises(self):
63+
"""Exceeding _MAX_PARSE_DEPTH raises InvalidResponse."""
64+
depth = _MAX_DEPTH_RESP2 + 5
65+
parser = _make_parser(_RESP2Parser, _make_nested_arrays(depth))
66+
67+
with pytest.raises(InvalidResponse, match="nesting depth"):
68+
parser._read_response()
69+
70+
def test_depth_exactly_at_limit_raises(self):
71+
"""Responses nested exactly at the limit still raise."""
72+
depth = _MAX_DEPTH_RESP2 + 1
73+
parser = _make_parser(_RESP2Parser, _make_nested_arrays(depth))
74+
75+
with pytest.raises(InvalidResponse, match="nesting depth"):
76+
parser._read_response()
77+
78+
def test_shallow_nesting_accepted(self):
79+
"""A shallow nested response (within limit) parses without error."""
80+
depth = 10
81+
parser = _make_parser(_RESP2Parser, _make_nested_arrays(depth))
82+
83+
result = parser._read_response()
84+
for _ in range(depth):
85+
assert isinstance(result, list)
86+
result = result[0]
87+
assert result == "OK"
88+
89+
def test_depth_guard_message_includes_limit(self):
90+
"""The error message includes the configured depth limit."""
91+
depth = _MAX_DEPTH_RESP2 + 1
92+
parser = _make_parser(_RESP2Parser, _make_nested_arrays(depth))
93+
94+
with pytest.raises(InvalidResponse, match=str(_MAX_DEPTH_RESP2)):
95+
parser._read_response()
96+
97+
98+
class TestRESP3DepthGuard:
99+
"""Verify _RESP3Parser rejects excessively nested aggregate replies."""
100+
101+
def test_depth_exceeded_raises(self):
102+
"""Exceeding _MAX_PARSE_DEPTH raises InvalidResponse."""
103+
depth = _MAX_DEPTH_RESP3 + 5
104+
parser = _make_parser(_RESP3Parser, _make_nested_arrays(depth))
105+
106+
with pytest.raises(InvalidResponse, match="nesting depth"):
107+
parser._read_response()
108+
109+
def test_shallow_nesting_accepted(self):
110+
"""A shallow nested response (within limit) parses without error."""
111+
depth = 10
112+
parser = _make_parser(_RESP3Parser, _make_nested_arrays(depth))
113+
114+
result = parser._read_response()
115+
for _ in range(depth):
116+
assert isinstance(result, list)
117+
result = result[0]
118+
assert result == "OK"
119+
120+
def test_depth_guard_message_includes_limit(self):
121+
"""The error message includes the configured depth limit."""
122+
depth = _MAX_DEPTH_RESP3 + 1
123+
parser = _make_parser(_RESP3Parser, _make_nested_arrays(depth))
124+
125+
with pytest.raises(InvalidResponse, match=str(_MAX_DEPTH_RESP3)):
126+
parser._read_response()

0 commit comments

Comments
 (0)