Skip to content

Commit 7e4c595

Browse files
committed
ENH: extend proxy to bind or connect for either port
1 parent a75224b commit 7e4c595

File tree

2 files changed

+147
-50
lines changed

2 files changed

+147
-50
lines changed

src/bluesky/callbacks/zmq.py

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -235,52 +235,79 @@ class Proxy:
235235

236236
@staticmethod
237237
def configure_server_socket(
238-
ctx, sock_type, address: Union[str, tuple, int, None], curve: Union[ServerCurve, None], zmq, auth_class
238+
ctx, sock_type, address: Union[str, tuple, int, None], curve: Union[ServerCurve, ClientCurve, None], zmq, auth_class, bind: bool = True
239239
):
240240
socket = ctx.socket(sock_type)
241241
norm_address = _normalize_address(address)
242-
logger.debug(f"Creating socket of type {sock_type} for address {norm_address}")
242+
logger.debug(f"Creating socket of type {sock_type} for address {norm_address}, bind={bind}")
243243
random_port = False
244244
if norm_address.startswith("tcp"):
245245
if ":" not in norm_address[6:]:
246246
random_port = True
247+
247248
if curve is not None:
248-
logger.debug(f"Configuring CURVE security with secret_path={curve.secret_path}")
249-
# build authenticator
250-
auth = auth_class(ctx)
251-
auth.start()
252-
logger.debug("Started ZMQ authenticator")
253-
if curve.allow is not None:
254-
auth.allow(*curve.allow)
255-
logger.debug(f"Configured IP address allowlist: {curve.allow}")
256-
257-
# Tell the authenticator how to handle CURVE requests
258-
if curve.client_public_keys is None:
259-
# accept any client that knows the public key
260-
auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)
261-
logger.debug("Configured CURVE to allow any client with valid public key")
249+
if bind:
250+
# Server mode - expect ServerCurve
251+
if not isinstance(curve, ServerCurve):
252+
raise TypeError("When bind=True, curve must be a ServerCurve instance")
253+
logger.debug(f"Configuring CURVE server security with secret_path={curve.secret_path}")
254+
# build authenticator
255+
auth = auth_class(ctx)
256+
auth.start()
257+
logger.debug("Started ZMQ authenticator")
258+
if curve.allow is not None:
259+
auth.allow(*curve.allow)
260+
logger.debug(f"Configured IP address allowlist: {curve.allow}")
261+
262+
# Tell the authenticator how to handle CURVE requests
263+
if curve.client_public_keys is None:
264+
# accept any client that knows the public key
265+
auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)
266+
logger.debug("Configured CURVE to allow any client with valid public key")
267+
else:
268+
auth.configure_curve(domain="*", location=curve.client_public_keys)
269+
logger.debug(f"Configured CURVE client public keys from: {curve.client_public_keys}")
270+
271+
# get public and private keys from the certificate
272+
server_public, server_secret = zmq.auth.load_certificate(curve.secret_path)
273+
# attach them to the socket
274+
socket.setsockopt(zmq.CURVE_PUBLICKEY, server_public)
275+
socket.setsockopt(zmq.CURVE_SECRETKEY, server_secret)
276+
socket.setsockopt(zmq.CURVE_SERVER, True)
277+
logger.debug("Applied CURVE keys and enabled CURVE server mode")
278+
else:
279+
# Client mode - expect ClientCurve
280+
if not isinstance(curve, ClientCurve):
281+
raise TypeError("When bind=False, curve must be a ClientCurve instance")
282+
logger.debug(f"Configuring CURVE client security with secret_path={curve.secret_path}")
283+
284+
# Load the client cert pair
285+
client_public, client_secret = zmq.auth.load_certificate(curve.secret_path)
286+
socket.setsockopt(zmq.CURVE_PUBLICKEY, client_public)
287+
if client_secret is None:
288+
raise ValueError("The client secret key could not be found.")
289+
socket.setsockopt(zmq.CURVE_SECRETKEY, client_secret)
290+
291+
# Load the server public key and register with the socket
292+
server_key, _ = zmq.auth.load_certificate(curve.server_public_key)
293+
socket.setsockopt(zmq.CURVE_SERVERKEY, server_key)
294+
logger.debug("Applied CURVE client keys and server public key")
295+
296+
if bind:
297+
if random_port:
298+
port = socket.bind_to_random_port(norm_address)
299+
logger.debug(f"Bound to random port: {port}")
262300
else:
263-
auth.configure_curve(domain="*", location=curve.client_public_keys)
264-
logger.debug(f"Configured CURVE client public keys from: {curve.client_public_keys}")
265-
266-
# get public and private keys from the certificate
267-
server_public, server_secret = zmq.auth.load_certificate(curve.secret_path)
268-
# attach them to the
269-
socket.setsockopt(zmq.CURVE_PUBLICKEY, server_public)
270-
socket.setsockopt(zmq.CURVE_SECRETKEY, server_secret)
271-
socket.setsockopt(zmq.CURVE_SERVER, True)
272-
logger.debug("Applied CURVE keys and enabled CURVE server mode")
273-
274-
if random_port:
275-
port = socket.bind_to_random_port(norm_address)
276-
logger.debug(f"Bound to random port: {port}")
301+
port = socket.bind(norm_address)
302+
logger.debug(f"Bound to address: {norm_address}")
277303
else:
278-
port = socket.bind(norm_address)
279-
logger.debug(f"Bound to address: {norm_address}")
304+
socket.connect(norm_address)
305+
port = norm_address
306+
logger.debug(f"Connected to address: {norm_address}")
280307

281308
return socket, port
282309

283-
def __init__(self, in_address=None, out_address=None, *, zmq=None, in_curve=None, out_curve=None):
310+
def __init__(self, in_address=None, out_address=None, *, zmq=None, in_curve=None, out_curve=None, in_bind=True, out_bind=True):
284311
if zmq is None:
285312
import zmq
286313
self.zmq = zmq
@@ -291,12 +318,12 @@ def __init__(self, in_address=None, out_address=None, *, zmq=None, in_curve=None
291318
from zmq.auth.thread import ThreadAuthenticator
292319

293320
frontend, in_port = self.configure_server_socket(
294-
context, zmq.SUB, in_address, in_curve, zmq, ThreadAuthenticator
321+
context, zmq.SUB, in_address, in_curve, zmq, ThreadAuthenticator, bind=in_bind
295322
)
296323
frontend.setsockopt_string(zmq.SUBSCRIBE, "")
297324

298325
backend, out_port = self.configure_server_socket(
299-
context, zmq.PUB, out_address, out_curve, zmq, ThreadAuthenticator
326+
context, zmq.PUB, out_address, out_curve, zmq, ThreadAuthenticator, bind=out_bind
300327
)
301328

302329
except BaseException:
@@ -315,8 +342,8 @@ def __init__(self, in_address=None, out_address=None, *, zmq=None, in_curve=None
315342
...
316343
raise
317344
else:
318-
self.in_port = in_port.addr if hasattr(in_port, "addr") else _normalize_address(in_port)
319-
self.out_port = out_port.addr if hasattr(out_port, "addr") else _normalize_address(out_port)
345+
self.in_port = in_port.addr if hasattr(in_port, "addr") else _normalize_address(in_port) if in_bind else in_port
346+
self.out_port = out_port.addr if hasattr(out_port, "addr") else _normalize_address(out_port) if out_bind else out_port
320347
self._frontend = frontend
321348
self._backend = backend
322349
self._context = context

src/bluesky/commandline/zmq_proxy.py

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,37 @@ def main():
4444
parser = argparse.ArgumentParser(description=DESC)
4545
parser.add_argument("--in-address", help="port that RunEngines should broadcast to")
4646
parser.add_argument("--out-address", help="port that subscribers should subscribe to")
47-
# CURVE security options for input socket
47+
48+
# Socket mode options
49+
parser.add_argument("--in-mode", choices=["bind", "connect"], default="bind", help="Input socket mode: bind (server) or connect (client)")
50+
parser.add_argument("--out-mode", choices=["bind", "connect"], default="bind", help="Output socket mode: bind (server) or connect (client)")
51+
52+
# CURVE security options for input socket (server mode)
4853
parser.add_argument("--in-curve-secret", type=str, help="Path to CURVE server secret key for input socket")
4954
parser.add_argument(
5055
"--in-curve-client-keys", type=str, help="Path to folder of client public keys for input socket"
5156
)
5257
parser.add_argument(
5358
"--in-curve-allow", type=str, nargs="*", help="Set of IP addresses to allow for input socket"
5459
)
55-
# CURVE security options for output socket
60+
61+
# CURVE security options for input socket (client mode)
62+
parser.add_argument("--in-client-secret", type=str, help="Path to client secret key for input socket")
63+
parser.add_argument("--in-server-public", type=str, help="Path to server public key for input socket")
64+
65+
# CURVE security options for output socket (server mode)
5666
parser.add_argument("--out-curve-secret", type=str, help="Path to CURVE server secret key for output socket")
5767
parser.add_argument(
5868
"--out-curve-client-keys", type=str, help="Path to folder of client public keys for output socket"
5969
)
6070
parser.add_argument(
6171
"--out-curve-allow", type=str, nargs="*", help="Set of IP addresses to allow for output socket"
6272
)
73+
74+
# CURVE security options for output socket (client mode)
75+
parser.add_argument("--out-client-secret", type=str, help="Path to client secret key for output socket")
76+
parser.add_argument("--out-server-public", type=str, help="Path to server public key for output socket")
77+
6378
parser.add_argument(
6479
"--verbose",
6580
"-v",
@@ -69,6 +84,29 @@ def main():
6984
parser.add_argument("--logfile", type=str, help="Redirect logging output to a file on disk.")
7085
args = parser.parse_args()
7186

87+
in_bind = args.in_mode == "bind"
88+
out_bind = args.out_mode == "bind"
89+
90+
# Validate CURVE configuration consistency for input
91+
if in_bind:
92+
# Server mode - check for client mode flags
93+
if args.in_client_secret or args.in_server_public:
94+
raise ValueError("Cannot use client CURVE options (--in-client-secret, --in-server-public) when input is in bind mode")
95+
else:
96+
# Client mode - check for server mode flags
97+
if args.in_curve_secret or args.in_curve_client_keys or args.in_curve_allow:
98+
raise ValueError("Cannot use server CURVE options (--in-curve-secret, --in-curve-client-keys, --in-curve-allow) when input is in connect mode")
99+
100+
# Validate CURVE configuration consistency for output
101+
if out_bind:
102+
# Server mode - check for client mode flags
103+
if args.out_client_secret or args.out_server_public:
104+
raise ValueError("Cannot use client CURVE options (--out-client-secret, --out-server-public) when output is in bind mode")
105+
else:
106+
# Client mode - check for server mode flags
107+
if args.out_curve_secret or args.out_curve_client_keys or args.out_curve_allow:
108+
raise ValueError("Cannot use server CURVE options (--out-curve-secret, --out-curve-client-keys, --out-curve-allow) when output is in connect mode")
109+
72110
# Helper to build ServerCurve or None
73111
def build_server_curve(
74112
secret: str | None, client_keys: str | None, allow: list[str] | None
@@ -82,8 +120,26 @@ def build_server_curve(
82120
allow_set = set(allow) if allow else None
83121
return ServerCurve(secret_path=secret_path, client_public_keys=client_public_keys, allow=allow_set)
84122

85-
in_curve = build_server_curve(args.in_curve_secret, args.in_curve_client_keys, args.in_curve_allow)
86-
out_curve = build_server_curve(args.out_curve_secret, args.out_curve_client_keys, args.out_curve_allow)
123+
# Helper to build ClientCurve or None
124+
def build_client_curve(
125+
secret: str | None, server_public: str | None
126+
) -> ClientCurve | None:
127+
if secret is None and server_public is None:
128+
return None
129+
if secret is None or server_public is None:
130+
raise ValueError("Both client secret and server public key must be provided for CURVE client mode")
131+
return ClientCurve(secret_path=Path(secret), server_public_key=Path(server_public))
132+
133+
# Build CURVE configurations based on mode
134+
if in_bind:
135+
in_curve = build_server_curve(args.in_curve_secret, args.in_curve_client_keys, args.in_curve_allow)
136+
else:
137+
in_curve = build_client_curve(args.in_client_secret, args.in_server_public)
138+
139+
if out_bind:
140+
out_curve = build_server_curve(args.out_curve_secret, args.out_curve_client_keys, args.out_curve_allow)
141+
else:
142+
out_curve = build_client_curve(args.out_client_secret, args.out_server_public)
87143

88144
# Configure logging BEFORE creating the proxy so we capture socket configuration debug messages
89145
if args.verbose:
@@ -109,19 +165,33 @@ def build_server_curve(
109165
out_address = int(args.out_address)
110166
except (ValueError, TypeError):
111167
out_address = args.out_address
112-
proxy = Proxy(in_address, out_address, in_curve=in_curve, out_curve=out_curve)
168+
proxy = Proxy(in_address, out_address, in_curve=in_curve, out_curve=out_curve, in_bind=in_bind, out_bind=out_bind)
113169
print("Receiving on address %s; publishing to address %s." % (proxy.in_port, proxy.out_port))
114170
if args.verbose:
115171
# Set daemon to kill all threads upon IPython exit
116-
if out_curve is None:
117-
# We would need client certificates setup to connect to the output port
118-
client_curve = None
172+
dispatcher_address = None
173+
client_curve = None
174+
175+
if out_bind:
176+
# Output is bound - we can connect to it
177+
dispatcher_address = proxy.out_port
178+
if out_curve is None:
179+
client_curve = None
180+
else:
181+
# this looks funny, but the secret file also contains the public key
182+
# this bets that the public key for the server is in the folder of public keys
183+
# it will accept and that we can route to the output port on an allowed ip
184+
client_curve = ClientCurve(out_curve.secret_path, out_curve.secret_path)
185+
elif not in_bind:
186+
# Output is connect and input is connect - connect to same source as input
187+
dispatcher_address = in_address
188+
client_curve = in_curve # Use the same curve config as input
119189
else:
120-
# this looks funny, but the secret file also contains the public key
121-
# this bets that the public key for the server is in the folder of public keys
122-
# it will accept and that we can route to the output port on an allowed ip
123-
client_curve = ClientCurve(out_curve.secret_path, out_curve.secret_path)
124-
threading.Thread(target=start_dispatcher, args=(proxy.out_port, client_curve), daemon=True).start()
190+
# Output is connect and input is bind - nowhere to connect dispatcher
191+
print("WARNING: Cannot subscribe dispatcher when output is in connect mode and input is in bind mode")
192+
193+
if dispatcher_address is not None:
194+
threading.Thread(target=start_dispatcher, args=(dispatcher_address, client_curve), daemon=True).start()
125195

126196

127197
print("Use Ctrl+C to exit.")

0 commit comments

Comments
 (0)