1
1
import asyncio
2
+ from typing import List
2
3
3
4
import electrum_ecc as ecc
4
5
5
6
from electrum import util
7
+ from electrum import lntransport
6
8
from electrum .lntransport import LNPeerAddr , LNResponderTransport , LNTransport , extract_nodeid , split_host_port , ConnStringFormatError
7
9
from electrum .util import OldTaskGroup
8
10
@@ -71,6 +73,7 @@ async def write_messages(transport, expected_messages):
71
73
72
74
async def cb (reader , writer ):
73
75
t = LNResponderTransport (responder_key .get_secret_bytes (), reader , writer )
76
+ transports .append (t )
74
77
self .assertEqual (await t .handshake (), initiator_key .get_public_key_bytes ())
75
78
async with OldTaskGroup () as group :
76
79
await group .spawn (read_messages (t , messages_sent_by_client ))
@@ -79,12 +82,14 @@ async def cb(reader, writer):
79
82
async def connect (port : int ):
80
83
peer_addr = LNPeerAddr ('127.0.0.1' , port , responder_key .get_public_key_bytes ())
81
84
t = LNTransport (initiator_key .get_secret_bytes (), peer_addr , e_proxy = None )
85
+ transports .append (t )
82
86
await t .handshake ()
83
87
async with OldTaskGroup () as group :
84
88
await group .spawn (read_messages (t , messages_sent_by_server ))
85
89
await group .spawn (write_messages (t , messages_sent_by_client ))
86
90
server_shaked .set ()
87
91
92
+ transports = [] # type: List[lntransport.LNTransportBase]
88
93
async def f ():
89
94
server = await asyncio .start_server (cb , '127.0.0.1' , port = None )
90
95
server_port = server .sockets [0 ].getsockname ()[1 ]
@@ -94,6 +99,8 @@ async def f():
94
99
await group .spawn (responder_shaked .wait ())
95
100
await group .spawn (server_shaked .wait ())
96
101
finally :
102
+ for t in transports :
103
+ t .close ()
97
104
server .close ()
98
105
99
106
await f ()
0 commit comments