Skip to content

Commit 4b4ae21

Browse files
committed
Add test
1 parent fa70db5 commit 4b4ae21

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

tests/test_protocol.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import struct
2323

2424
import edgedb
25+
from gel import con_utils
2526

2627
from edb.server import args as srv_args
2728
from edb.server import compiler
@@ -30,6 +31,7 @@
3031
from edb.testbase import server as tb
3132
from edb.testbase import connection as tconn
3233
from edb.testbase.protocol.test import ProtocolTestCase
34+
from edb.tools import test
3335

3436

3537
def pack_i32s(*args):
@@ -896,6 +898,54 @@ async def _test_proto_discard_prepared_statement_in_script(self):
896898
finally:
897899
await self.con.recv_match(protocol.ReadyForCommand)
898900

901+
@test.xerror("FIXME")
902+
async def test_proto_tls_close_notify(self):
903+
# Setup connection with custom protocols
904+
args = self.get_connect_args(database=self.get_database_name())
905+
args.setdefault('dsn', None)
906+
args.setdefault('host', None)
907+
args.setdefault('port', None)
908+
args.setdefault('user', None)
909+
args.setdefault('password', None)
910+
args.setdefault('secret_key', None)
911+
args.setdefault('branch', None)
912+
args.setdefault('database', None)
913+
timeout = args.setdefault('timeout', 60)
914+
args.setdefault('tls_ca', None)
915+
args.setdefault('tls_ca_file', None)
916+
args.setdefault('tls_security', 'default')
917+
args.setdefault('credentials', None)
918+
args.setdefault('credentials_file', None)
919+
connect_config, client_config = con_utils.parse_connect_arguments(
920+
**args,
921+
command_timeout=None,
922+
server_settings=None,
923+
tls_server_name=None,
924+
wait_until_available=timeout,
925+
)
926+
loop = asyncio.get_running_loop()
927+
gel_protocol = GelProtocol(connect_config, loop)
928+
protocol_factory = lambda: StrictTlsClientProtocol(
929+
loop, gel_protocol, connect_config.ssl_ctx, None
930+
)
931+
addr = connect_config.address
932+
if isinstance(addr, str):
933+
connector = loop.create_unix_connection(protocol_factory, addr)
934+
else:
935+
connector = loop.create_connection(protocol_factory, *addr)
936+
tls_transport, tls_protocol = await connector
937+
938+
# Complete the Gel handshake
939+
await gel_protocol.connect()
940+
941+
# Now, close the connection without sending a `Terminate`, but with
942+
# only a TLS `close_notify`.
943+
tls_transport.close()
944+
945+
# We expect the server to reply with a `close_notify` too. If not,
946+
# this will fail with the error in StrictTlsClientProtocol.
947+
await gel_protocol.wait_closed()
948+
899949

900950
class TestServerCancellation(tb.TestCase):
901951
@contextlib.asynccontextmanager
@@ -1029,3 +1079,33 @@ async def test_proto_gh3170_connection_lost_error(self):
10291079
except Exception:
10301080
await con.aclose()
10311081
raise
1082+
1083+
1084+
class GelProtocol(protocol.protocol.Protocol):
1085+
_close_fut = None
1086+
1087+
def connection_lost(self, exc):
1088+
if self._close_fut is not None:
1089+
if exc is None:
1090+
self._close_fut.set_result(None)
1091+
else:
1092+
self._close_fut.set_exception(exc)
1093+
super().connection_lost(exc)
1094+
1095+
async def wait_closed(self):
1096+
self._close_fut = asyncio.Future()
1097+
try:
1098+
await self._close_fut
1099+
finally:
1100+
self._close_fut = None
1101+
1102+
1103+
class StrictTlsClientProtocol(asyncio.sslproto.SSLProtocol):
1104+
def connection_lost(self, exc):
1105+
if self._state == asyncio.sslproto.SSLProtocolState.WRAPPED:
1106+
if exc is None:
1107+
exc = ConnectionResetError(
1108+
'peer closed connection without sending '
1109+
'TLS close_notify'
1110+
)
1111+
super().connection_lost(exc)

0 commit comments

Comments
 (0)