Skip to content

Commit 46314d6

Browse files
committed
Add dummy coordinator to tests
1 parent 2440266 commit 46314d6

File tree

3 files changed

+237
-13
lines changed

3 files changed

+237
-13
lines changed

tests/fixtures.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from concurrent import futures
22
from ephemeral_port_reserve import reserve
33
from test_framework.bitcoind import Bitcoind, BitcoindRpcProxy
4+
from test_framework.coordinator import SecureChannel, DummyCoordinator
45
from test_framework.miradord import Miradord
56
from test_framework.utils import (
67
get_descriptors,
@@ -114,7 +115,27 @@ def bitcoind(directory):
114115

115116

116117
@pytest.fixture
117-
def miradord(request, bitcoind, directory):
118+
def secure_channel():
119+
# The Noise keys are interdependant, so generate everything in advance
120+
# to avoid roundtrips
121+
secure_channel = SecureChannel(os.urandom(32), os.urandom(32))
122+
yield secure_channel
123+
124+
125+
@pytest.fixture
126+
def coordinator(secure_channel):
127+
coordinator_port = reserve()
128+
coordinator = DummyCoordinator(
129+
coordinator_port,
130+
secure_channel.coordinator_privkey,
131+
[secure_channel.client_pubkey],
132+
)
133+
coordinator.start()
134+
yield coordinator
135+
136+
137+
@pytest.fixture
138+
def miradord(request, bitcoind, coordinator, secure_channel, directory):
118139
"""If a 'mock_bitcoind' pytest marker is set, it will create a proxy for the communication
119140
from the miradord process to the bitcoind process. An optional 'mocks' parameter can be set
120141
for this marker in order to specify some pre-registered mock of RPC commands.
@@ -130,10 +151,6 @@ def miradord(request, bitcoind, directory):
130151
)
131152
emer_addr = "bcrt1qewc2348370pgw8kjz8gy09z8xyh0d9fxde6nzamd3txc9gkmjqmq8m4cdq"
132153

133-
coordinator_noise_key = (
134-
"d91563973102454a7830137e92d0548bc83b4ea2799f1df04622ca1307381402"
135-
)
136-
137154
bitcoind_cookie = os.path.join(bitcoind.bitcoin_dir, "regtest", ".cookie")
138155
bitcoind_rpcport = bitcoind.rpcport
139156

@@ -151,10 +168,10 @@ def miradord(request, bitcoind, directory):
151168
cpfp_desc,
152169
emer_addr,
153170
reserve(),
171+
secure_channel.client_privkey,
154172
os.urandom(32),
155-
os.urandom(32),
156-
coordinator_noise_key, # Unused yet
157-
reserve(), # Unused yet
173+
coordinator.coordinator_pubkey,
174+
coordinator.port,
158175
bitcoind_rpcport,
159176
bitcoind_cookie,
160177
)
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import cryptography
2+
import json
3+
import os
4+
import select
5+
import socket
6+
import threading
7+
8+
from nacl.public import PrivateKey as Curve25519Private
9+
from noise.connection import NoiseConnection, Keypair
10+
from test_framework.utils import (
11+
TIMEOUT,
12+
)
13+
14+
HANDSHAKE_MSG = b"practical_revault_0"
15+
16+
17+
class SecureChannel:
18+
"""An exchange of paired keys"""
19+
20+
def __init__(
21+
self,
22+
coordinator_privkey,
23+
client_privkey,
24+
):
25+
self.coordinator_privkey = coordinator_privkey
26+
self.coordinator_pubkey = bytes(
27+
Curve25519Private(coordinator_privkey).public_key
28+
)
29+
self.client_privkey = client_privkey
30+
self.client_pubkey = bytes(Curve25519Private(client_privkey).public_key)
31+
32+
33+
class DummyCoordinator:
34+
"""A simple in-RAM synchronization server."""
35+
36+
def __init__(
37+
self,
38+
port,
39+
coordinator_privkey,
40+
client_pubkeys,
41+
):
42+
self.port = port
43+
self.coordinator_privkey = coordinator_privkey
44+
self.coordinator_pubkey = bytes(
45+
Curve25519Private(coordinator_privkey).public_key
46+
)
47+
self.client_pubkeys = client_pubkeys
48+
49+
# Spin up the coordinator proxy
50+
self.s = socket.socket()
51+
self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
52+
self.s.bind(("localhost", self.port))
53+
self.s.listen(1_000)
54+
# Use a pipe to communicate to threads to stop
55+
self.r_close_chann, self.w_close_chann = os.pipe()
56+
57+
# A mapping from txid to pubkey to signature
58+
self.sigs = {}
59+
# A mapping from deposit_outpoint to base64 PSBT
60+
self.spend_txs = {}
61+
62+
def __del__(self):
63+
self.cleanup()
64+
65+
def start(self):
66+
self.server_thread = threading.Thread(target=self.run)
67+
self.server_thread.start()
68+
69+
def cleanup(self):
70+
# Write to the pipe to notify the thread it needs to stop
71+
os.write(self.w_close_chann, b".")
72+
self.server_thread.join()
73+
74+
def run(self):
75+
"""Accept new connections until we are told to stop."""
76+
while True:
77+
r_fds, _, _ = select.select([self.r_close_chann, self.s.fileno()], [], [])
78+
79+
# First check if we've been told to stop, then spawn a new thread per connection
80+
if self.r_close_chann in r_fds:
81+
break
82+
if self.s.fileno() in r_fds:
83+
t = threading.Thread(target=self.connection_handle, daemon=True)
84+
t.start()
85+
86+
def connection_handle(self):
87+
"""Read and treat requests from this client. Blocking."""
88+
client_fd, _ = self.s.accept()
89+
client_fd.settimeout(TIMEOUT // 2)
90+
client_noise = self.server_noise_conn(client_fd)
91+
92+
while True:
93+
# Manually do the select to check if we've been told to stop
94+
r_fds, _, _ = select.select([self.r_close_chann, client_fd], [], [])
95+
if self.r_close_chann in r_fds:
96+
break
97+
req = self.read_msg(client_fd, client_noise)
98+
if req is None:
99+
break
100+
request = json.loads(req)
101+
method, params = request["method"], request["params"]
102+
103+
if method == "sig":
104+
# TODO: mutex
105+
if params["txid"] not in self.sigs:
106+
self.sigs[params["txid"]] = {}
107+
self.sigs[params["txid"]][params["pubkey"]] = params["signature"]
108+
# TODO: remove this useless response from the protocol
109+
resp = {"result": {"ack": True}, "id": request["id"]}
110+
self.send_msg(client_fd, client_noise, json.dumps(resp))
111+
112+
elif method == "get_sigs":
113+
txid = params["txid"]
114+
sigs = self.sigs.get(txid, {})
115+
resp = {"result": {"signatures": sigs}, "id": request["id"]}
116+
self.send_msg(client_fd, client_noise, json.dumps(resp))
117+
118+
elif method == "set_spend_tx":
119+
for outpoint in params["deposit_outpoints"]:
120+
self.spend_txs[outpoint] = params["spend_tx"]
121+
# TODO: remove this useless response from the protocol
122+
resp = {"result": {"ack": True}, "id": request["id"]}
123+
self.send_msg(client_fd, client_noise, json.dumps(resp))
124+
125+
elif method == "get_spend_tx":
126+
spend_tx = self.spend_txs.get(params["deposit_outpoint"])
127+
resp = {"result": {"spend_tx": spend_tx}, "id": request["id"]}
128+
self.send_msg(client_fd, client_noise, json.dumps(resp))
129+
130+
else:
131+
assert False, "Invalid request '{}'".format(method)
132+
133+
def server_noise_conn(self, fd):
134+
"""Do practical-revault's Noise handshake with a given client connection."""
135+
# Read the first message of the handshake only once
136+
data = self.read_data(fd, 32 + len(HANDSHAKE_MSG) + 16)
137+
138+
# We brute force all pubkeys. FIXME!
139+
for pubkey in self.client_pubkeys:
140+
# Set the local and remote static keys
141+
conn = NoiseConnection.from_name(b"Noise_KK_25519_ChaChaPoly_SHA256")
142+
conn.set_as_responder()
143+
conn.set_keypair_from_private_bytes(
144+
Keypair.STATIC, self.coordinator_privkey
145+
)
146+
conn.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pubkey)
147+
148+
# Now, get the first message of the handshake
149+
conn.start_handshake()
150+
try:
151+
plaintext = conn.read_message(data)
152+
except cryptography.exceptions.InvalidTag:
153+
continue
154+
else:
155+
assert plaintext[: len(HANDSHAKE_MSG)] == HANDSHAKE_MSG
156+
157+
# If it didn't fail it was the right key! Finalize the handshake.
158+
resp = conn.write_message()
159+
fd.sendall(resp)
160+
assert conn.handshake_finished
161+
162+
return conn
163+
164+
raise Exception(
165+
f"Unknown client key. Keys: {','.join(k.hex() for k in self.client_pubkeys)}"
166+
)
167+
168+
def read_msg(self, fd, noise_conn):
169+
"""read a noise-encrypted message from this stream.
170+
171+
Returns None if the socket closed.
172+
"""
173+
# Read first the length prefix
174+
cypher_header = self.read_data(fd, 2 + 16)
175+
if cypher_header == b"":
176+
return None
177+
msg_header = noise_conn.decrypt(cypher_header)
178+
msg_len = int.from_bytes(msg_header, "big")
179+
180+
# And then the message
181+
cypher_msg = self.read_data(fd, msg_len)
182+
assert len(cypher_msg) == msg_len
183+
msg = noise_conn.decrypt(cypher_msg).decode("utf-8")
184+
return msg
185+
186+
def send_msg(self, fd, noise_conn, msg):
187+
"""Write a noise-encrypted message from this stream."""
188+
assert isinstance(msg, str)
189+
190+
# Compute the message header
191+
msg_raw = msg.encode("utf-8")
192+
length_prefix = (len(msg_raw) + 16).to_bytes(2, "big")
193+
encrypted_header = noise_conn.encrypt(length_prefix)
194+
encrypted_body = noise_conn.encrypt(msg_raw)
195+
196+
# Then send both the header and the message concatenated
197+
fd.sendall(encrypted_header + encrypted_body)
198+
199+
def read_data(self, fd, max_len):
200+
"""Read data from the given fd until there is nothing to read."""
201+
data = b""
202+
while True:
203+
d = fd.recv(max_len)
204+
if len(d) == max_len:
205+
return d
206+
if d == b"":
207+
return data
208+
data += d

tests/test_framework/miradord.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,12 @@ def __init__(
7676
f.write("daemon = false\n")
7777
f.write(f"log_level = '{LOG_LEVEL}'\n")
7878

79+
f.write(f'listen = "127.0.0.1:{listen_port}"\n')
7980
f.write(f'stakeholder_noise_key = "{stk_noise_key.hex()}"\n')
8081

81-
f.write(f'coordinator_host = "127.0.0.1:{coordinator_port}"\n')
82-
f.write(f'coordinator_noise_key = "{coordinator_noise_key}"\n')
83-
f.write("coordinator_poll_seconds = 5\n")
84-
85-
f.write(f'listen = "127.0.0.1:{listen_port}"\n')
82+
f.write("[coordinator_config]\n")
83+
f.write(f'host = "127.0.0.1:{coordinator_port}"\n')
84+
f.write(f'noise_key = "{coordinator_noise_key.hex()}"\n')
8685

8786
f.write("[scripts_config]\n")
8887
f.write(f'deposit_descriptor = "{deposit_desc}"\n')

0 commit comments

Comments
 (0)