Skip to content

Commit e435ec2

Browse files
committed
Add pre-migration tests for networking internals
1 parent 7c3f88a commit e435ec2

File tree

5 files changed

+996
-2
lines changed

5 files changed

+996
-2
lines changed

tests/test_blacklisting.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,62 @@
99

1010
from unittest.mock import patch
1111

12+
from netaddr import IPAddress, IPSet
1213
from twisted.internet.error import DNSLookupError
1314
from twisted.test.proto_helpers import StringTransport
1415
from twisted.trial.unittest import TestCase
1516
from twisted.web.client import Agent
1617

17-
from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
18+
from sydent.http.blacklisting_reactor import (
19+
BlacklistingReactorWrapper,
20+
check_against_blacklist,
21+
)
1822
from sydent.http.srvresolver import Server
1923

2024
from tests.utils import AsyncMock, make_request, make_sydent
2125

2226

27+
class CheckAgainstBlacklistTest(TestCase):
28+
"""Tests for the check_against_blacklist() utility function."""
29+
30+
def test_blacklisted_ipv4(self) -> None:
31+
blacklist = IPSet(["5.0.0.0/8"])
32+
self.assertTrue(check_against_blacklist(IPAddress("5.1.2.3"), None, blacklist))
33+
34+
def test_not_blacklisted(self) -> None:
35+
blacklist = IPSet(["5.0.0.0/8"])
36+
self.assertFalse(check_against_blacklist(IPAddress("1.2.3.4"), None, blacklist))
37+
38+
def test_whitelisted_overrides_blacklist(self) -> None:
39+
blacklist = IPSet(["5.0.0.0/8"])
40+
whitelist = IPSet(["5.1.1.1"])
41+
self.assertFalse(
42+
check_against_blacklist(IPAddress("5.1.1.1"), whitelist, blacklist)
43+
)
44+
45+
def test_ipv6_loopback_blocked(self) -> None:
46+
blacklist = IPSet(["::1/128"])
47+
self.assertTrue(check_against_blacklist(IPAddress("::1"), None, blacklist))
48+
49+
def test_ipv6_link_local_blocked(self) -> None:
50+
blacklist = IPSet(["fe80::/10"])
51+
self.assertTrue(check_against_blacklist(IPAddress("fe80::1"), None, blacklist))
52+
53+
def test_ipv4_mapped_ipv6_blocked(self) -> None:
54+
"""IPv4-mapped IPv6 addresses (::ffff:127.0.0.1) should be blockable."""
55+
blacklist = IPSet(["127.0.0.0/8", "::ffff:127.0.0.0/104"])
56+
self.assertTrue(
57+
check_against_blacklist(IPAddress("::ffff:127.0.0.1"), None, blacklist)
58+
)
59+
60+
def test_ipv6_whitelist_overrides(self) -> None:
61+
blacklist = IPSet(["fe80::/10"])
62+
whitelist = IPSet(["fe80::1"])
63+
self.assertFalse(
64+
check_against_blacklist(IPAddress("fe80::1"), whitelist, blacklist)
65+
)
66+
67+
2368
class BlacklistingAgentTest(TestCase):
2469
def setUp(self):
2570
config = {

tests/test_httpcommon.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright 2025 New Vector Ltd.
2+
#
3+
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4+
# Please see LICENSE files in the repository root for full details.
5+
6+
from io import BytesIO
7+
from unittest.mock import MagicMock
8+
9+
from twisted.internet import defer
10+
from twisted.python.failure import Failure
11+
from twisted.trial import unittest
12+
from twisted.web.client import ResponseDone
13+
from twisted.web.http import PotentialDataLoss
14+
from twisted.web.iweb import UNKNOWN_LENGTH
15+
16+
from sydent.http.httpcommon import (
17+
BodyExceededMaxSize,
18+
SizeLimitingRequest,
19+
_DiscardBodyWithMaxSizeProtocol,
20+
_ReadBodyWithMaxSizeProtocol,
21+
read_body_with_max_size,
22+
)
23+
24+
25+
class ReadBodyWithMaxSizeProtocolTest(unittest.TestCase):
26+
"""Tests for _ReadBodyWithMaxSizeProtocol."""
27+
28+
def _make_protocol(
29+
self, max_size: int | None = None
30+
) -> tuple[_ReadBodyWithMaxSizeProtocol, "defer.Deferred[bytes]"]:
31+
d: defer.Deferred[bytes] = defer.Deferred()
32+
protocol = _ReadBodyWithMaxSizeProtocol(d, max_size)
33+
protocol.transport = MagicMock()
34+
return protocol, d
35+
36+
def test_reads_body_under_limit(self) -> None:
37+
"""Body under the limit is read successfully."""
38+
protocol, d = self._make_protocol(max_size=100)
39+
protocol.dataReceived(b"hello ")
40+
protocol.dataReceived(b"world")
41+
protocol.connectionLost(Failure(ResponseDone()))
42+
self.assertEqual(self.successResultOf(d), b"hello world")
43+
44+
def test_exceeds_max_size(self) -> None:
45+
"""Body exceeding max_size triggers BodyExceededMaxSize."""
46+
protocol, d = self._make_protocol(max_size=5)
47+
protocol.dataReceived(b"too much data")
48+
self.failureResultOf(d, BodyExceededMaxSize)
49+
protocol.transport.abortConnection.assert_called_once()
50+
51+
def test_exact_boundary(self) -> None:
52+
"""Body exactly at max_size triggers the error (>= check)."""
53+
protocol, d = self._make_protocol(max_size=5)
54+
protocol.dataReceived(b"12345")
55+
self.failureResultOf(d, BodyExceededMaxSize)
56+
57+
def test_no_max_size(self) -> None:
58+
"""With max_size=None, any amount of data is accepted."""
59+
protocol, d = self._make_protocol(max_size=None)
60+
protocol.dataReceived(b"x" * 10000)
61+
protocol.connectionLost(Failure(ResponseDone()))
62+
self.assertEqual(len(self.successResultOf(d)), 10000)
63+
64+
def test_potential_data_loss_succeeds(self) -> None:
65+
"""PotentialDataLoss is treated as success (same as ResponseDone)."""
66+
protocol, d = self._make_protocol(max_size=1000)
67+
protocol.dataReceived(b"partial data")
68+
protocol.connectionLost(Failure(PotentialDataLoss()))
69+
self.assertEqual(self.successResultOf(d), b"partial data")
70+
71+
def test_connection_error_propagates(self) -> None:
72+
"""An unexpected connection loss reason propagates the error."""
73+
protocol, d = self._make_protocol(max_size=1000)
74+
protocol.dataReceived(b"some data")
75+
error = Failure(Exception("connection reset"))
76+
protocol.connectionLost(error)
77+
f = self.failureResultOf(d)
78+
self.assertIsInstance(f.value, Exception)
79+
80+
81+
class DiscardBodyWithMaxSizeProtocolTest(unittest.TestCase):
82+
"""Tests for _DiscardBodyWithMaxSizeProtocol."""
83+
84+
def test_errors_on_data_received(self) -> None:
85+
"""Fires errback and aborts connection on first data."""
86+
d: defer.Deferred[bytes] = defer.Deferred()
87+
protocol = _DiscardBodyWithMaxSizeProtocol(d)
88+
protocol.transport = MagicMock()
89+
protocol.dataReceived(b"any data")
90+
self.failureResultOf(d, BodyExceededMaxSize)
91+
protocol.transport.abortConnection.assert_called_once()
92+
93+
def test_errors_on_connection_lost(self) -> None:
94+
"""Also fires errback if connectionLost fires first."""
95+
d: defer.Deferred[bytes] = defer.Deferred()
96+
protocol = _DiscardBodyWithMaxSizeProtocol(d)
97+
protocol.transport = MagicMock()
98+
protocol.connectionLost(Failure(ResponseDone()))
99+
self.failureResultOf(d, BodyExceededMaxSize)
100+
101+
def test_idempotent(self) -> None:
102+
"""Multiple calls to _maybe_fail don't double-fire the deferred."""
103+
d: defer.Deferred[bytes] = defer.Deferred()
104+
protocol = _DiscardBodyWithMaxSizeProtocol(d)
105+
protocol.transport = MagicMock()
106+
protocol.dataReceived(b"first")
107+
protocol.dataReceived(b"second") # should not raise
108+
protocol.connectionLost(Failure(ResponseDone()))
109+
self.failureResultOf(d, BodyExceededMaxSize)
110+
111+
112+
class ReadBodyWithMaxSizeFunctionTest(unittest.TestCase):
113+
"""Tests for the read_body_with_max_size() top-level function."""
114+
115+
def _make_response(self, length: int | object = UNKNOWN_LENGTH) -> MagicMock:
116+
response = MagicMock()
117+
response.length = length
118+
return response
119+
120+
def test_content_length_exceeds_limit_uses_discard(self) -> None:
121+
"""When Content-Length > max_size, the Discard protocol is used."""
122+
response = self._make_response(length=200)
123+
read_body_with_max_size(response, max_size=100)
124+
response.deliverBody.assert_called_once()
125+
protocol = response.deliverBody.call_args[0][0]
126+
self.assertIsInstance(protocol, _DiscardBodyWithMaxSizeProtocol)
127+
128+
def test_content_length_under_limit_uses_reader(self) -> None:
129+
"""When Content-Length <= max_size, the Read protocol is used."""
130+
response = self._make_response(length=50)
131+
read_body_with_max_size(response, max_size=100)
132+
response.deliverBody.assert_called_once()
133+
protocol = response.deliverBody.call_args[0][0]
134+
self.assertIsInstance(protocol, _ReadBodyWithMaxSizeProtocol)
135+
136+
def test_unknown_length_uses_reader(self) -> None:
137+
"""When Content-Length is unknown, the Read protocol is used."""
138+
response = self._make_response(length=UNKNOWN_LENGTH)
139+
read_body_with_max_size(response, max_size=100)
140+
response.deliverBody.assert_called_once()
141+
protocol = response.deliverBody.call_args[0][0]
142+
self.assertIsInstance(protocol, _ReadBodyWithMaxSizeProtocol)
143+
144+
def test_no_max_size_uses_reader(self) -> None:
145+
"""When max_size is None, always uses the Read protocol."""
146+
response = self._make_response(length=999999)
147+
read_body_with_max_size(response, max_size=None)
148+
response.deliverBody.assert_called_once()
149+
protocol = response.deliverBody.call_args[0][0]
150+
self.assertIsInstance(protocol, _ReadBodyWithMaxSizeProtocol)
151+
152+
153+
class SizeLimitingRequestTest(unittest.TestCase):
154+
"""Tests for SizeLimitingRequest.handleContentChunk()."""
155+
156+
def _make_request(self) -> SizeLimitingRequest:
157+
"""Create a minimal SizeLimitingRequest for testing."""
158+
req = SizeLimitingRequest.__new__(SizeLimitingRequest)
159+
req.content = BytesIO()
160+
req.transport = MagicMock()
161+
# The client attribute is accessed for logging.
162+
req.client = MagicMock()
163+
return req
164+
165+
def test_accepts_data_under_limit(self) -> None:
166+
"""Chunks totalling under MAX_REQUEST_SIZE are accepted."""
167+
req = self._make_request()
168+
data = b"x" * 1000
169+
req.handleContentChunk(data)
170+
self.assertEqual(req.content.tell(), 1000)
171+
req.transport.abortConnection.assert_not_called()
172+
173+
def test_aborts_on_oversize(self) -> None:
174+
"""Connection is aborted when cumulative data exceeds MAX_REQUEST_SIZE."""
175+
req = self._make_request()
176+
# Write data right up to the limit.
177+
req.content.write(b"x" * (512 * 1024))
178+
req.content.seek(512 * 1024)
179+
# The next chunk pushes over.
180+
req.handleContentChunk(b"x")
181+
req.transport.abortConnection.assert_called_once()

tests/test_replication.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
2-
from unittest.mock import Mock
2+
from unittest.mock import Mock, patch
33

4+
from OpenSSL import crypto
45
from twisted.internet import defer
56
from twisted.trial import unittest
67
from twisted.web.client import Response
@@ -190,3 +191,101 @@ def request(method, uri, headers, body):
190191
# will push will be 1, so we need to subtract 1 when figuring out which index
191192
# to lookup.
192193
self.assertDictEqual(assoc, signed_assocs[int(assoc_id) - 1])
194+
195+
196+
class ReplicationCNTest(unittest.TestCase):
197+
"""Tests for peer certificate CN extraction edge cases in the replication servlet."""
198+
199+
def setUp(self) -> None:
200+
self.sydent = make_sydent()
201+
202+
# Insert a known peer.
203+
cur = self.sydent.db.cursor()
204+
cur.execute(
205+
"INSERT INTO peers (name, port, lastSentVersion, active) VALUES (?, ?, ?, ?)",
206+
("fake.server", 1234, 0, 1),
207+
)
208+
peer_public_key_base64 = "+vB8mTaooD/MA8YYZM8t9+vnGhP1937q2icrqPV9JTs"
209+
cur.execute(
210+
"INSERT INTO peer_pubkeys (peername, alg, key) VALUES (?, ?, ?)",
211+
("fake.server", "ed25519", peer_public_key_base64),
212+
)
213+
self.sydent.db.commit()
214+
215+
def test_known_peer_cn_accepted(self) -> None:
216+
"""A peer cert with CN matching a known peer is accepted (existing test validates this,
217+
but let's have a focused unit-level check)."""
218+
self.sydent.run()
219+
220+
# The FakeChannel.getPeerCertificate() returns a cert with CN=fake.server,
221+
# and we inserted fake.server as a peer. A valid request should be accepted.
222+
body = {"sgAssocs": {}}
223+
request, channel = make_request(
224+
self.sydent.reactor,
225+
self.sydent.replicationHttpsServer.factory,
226+
"POST",
227+
"/_matrix/identity/replicate/v1/push",
228+
body,
229+
)
230+
self.assertEqual(channel.code, 200)
231+
232+
def test_unknown_peer_cn_rejected(self) -> None:
233+
"""A peer cert with CN that doesn't match any known peer returns 403."""
234+
self.sydent.run()
235+
236+
# Generate a cert with a CN that is NOT in the peers table.
237+
unknown_key = crypto.PKey()
238+
unknown_key.generate_key(crypto.TYPE_RSA, 2048)
239+
unknown_cert = crypto.X509()
240+
unknown_cert.get_subject().CN = "unknown.server"
241+
unknown_cert.set_serial_number(1000)
242+
unknown_cert.gmtime_adj_notBefore(0)
243+
unknown_cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
244+
unknown_cert.set_issuer(unknown_cert.get_subject())
245+
unknown_cert.set_pubkey(unknown_key)
246+
unknown_cert.sign(unknown_key, "sha256")
247+
248+
# Patch FakeChannel.getPeerCertificate to return our unknown cert.
249+
with patch(
250+
"tests.utils.FakeChannel.getPeerCertificate", return_value=unknown_cert
251+
):
252+
body = {"sgAssocs": {}}
253+
request, channel = make_request(
254+
self.sydent.reactor,
255+
self.sydent.replicationHttpsServer.factory,
256+
"POST",
257+
"/_matrix/identity/replicate/v1/push",
258+
body,
259+
)
260+
self.assertEqual(channel.code, 403)
261+
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_PEER")
262+
263+
def test_no_cn_rejected(self) -> None:
264+
"""A peer cert with no commonName returns 403."""
265+
self.sydent.run()
266+
267+
# Generate a cert with no CN set.
268+
no_cn_key = crypto.PKey()
269+
no_cn_key.generate_key(crypto.TYPE_RSA, 2048)
270+
no_cn_cert = crypto.X509()
271+
# Don't set CN — leave subject empty.
272+
no_cn_cert.set_serial_number(2000)
273+
no_cn_cert.gmtime_adj_notBefore(0)
274+
no_cn_cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
275+
no_cn_cert.set_issuer(no_cn_cert.get_subject())
276+
no_cn_cert.set_pubkey(no_cn_key)
277+
no_cn_cert.sign(no_cn_key, "sha256")
278+
279+
with patch(
280+
"tests.utils.FakeChannel.getPeerCertificate", return_value=no_cn_cert
281+
):
282+
body = {"sgAssocs": {}}
283+
request, channel = make_request(
284+
self.sydent.reactor,
285+
self.sydent.replicationHttpsServer.factory,
286+
"POST",
287+
"/_matrix/identity/replicate/v1/push",
288+
body,
289+
)
290+
self.assertEqual(channel.code, 403)
291+
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_PEER")

0 commit comments

Comments
 (0)