@@ -235,52 +235,79 @@ class Proxy:
235235
236236 @staticmethod
237237 def configure_server_socket (
238- ctx , sock_type , address : Union [str , tuple , int , None ], curve : Union [ServerCurve , None ], zmq , auth_class
238+ ctx , sock_type , address : Union [str , tuple , int , None ], curve : Union [ServerCurve , ClientCurve , None ], zmq , auth_class , bind : bool = True
239239 ):
240240 socket = ctx .socket (sock_type )
241241 norm_address = _normalize_address (address )
242- logger .debug (f"Creating socket of type { sock_type } for address { norm_address } " )
242+ logger .debug (f"Creating socket of type { sock_type } for address { norm_address } , bind= { bind } " )
243243 random_port = False
244244 if norm_address .startswith ("tcp" ):
245245 if ":" not in norm_address [6 :]:
246246 random_port = True
247+
247248 if curve is not None :
248- logger .debug (f"Configuring CURVE security with secret_path={ curve .secret_path } " )
249- # build authenticator
250- auth = auth_class (ctx )
251- auth .start ()
252- logger .debug ("Started ZMQ authenticator" )
253- if curve .allow is not None :
254- auth .allow (* curve .allow )
255- logger .debug (f"Configured IP address allowlist: { curve .allow } " )
256-
257- # Tell the authenticator how to handle CURVE requests
258- if curve .client_public_keys is None :
259- # accept any client that knows the public key
260- auth .configure_curve (domain = "*" , location = zmq .auth .CURVE_ALLOW_ANY )
261- logger .debug ("Configured CURVE to allow any client with valid public key" )
249+ if bind :
250+ # Server mode - expect ServerCurve
251+ if not isinstance (curve , ServerCurve ):
252+ raise TypeError ("When bind=True, curve must be a ServerCurve instance" )
253+ logger .debug (f"Configuring CURVE server security with secret_path={ curve .secret_path } " )
254+ # build authenticator
255+ auth = auth_class (ctx )
256+ auth .start ()
257+ logger .debug ("Started ZMQ authenticator" )
258+ if curve .allow is not None :
259+ auth .allow (* curve .allow )
260+ logger .debug (f"Configured IP address allowlist: { curve .allow } " )
261+
262+ # Tell the authenticator how to handle CURVE requests
263+ if curve .client_public_keys is None :
264+ # accept any client that knows the public key
265+ auth .configure_curve (domain = "*" , location = zmq .auth .CURVE_ALLOW_ANY )
266+ logger .debug ("Configured CURVE to allow any client with valid public key" )
267+ else :
268+ auth .configure_curve (domain = "*" , location = curve .client_public_keys )
269+ logger .debug (f"Configured CURVE client public keys from: { curve .client_public_keys } " )
270+
271+ # get public and private keys from the certificate
272+ server_public , server_secret = zmq .auth .load_certificate (curve .secret_path )
273+ # attach them to the socket
274+ socket .setsockopt (zmq .CURVE_PUBLICKEY , server_public )
275+ socket .setsockopt (zmq .CURVE_SECRETKEY , server_secret )
276+ socket .setsockopt (zmq .CURVE_SERVER , True )
277+ logger .debug ("Applied CURVE keys and enabled CURVE server mode" )
278+ else :
279+ # Client mode - expect ClientCurve
280+ if not isinstance (curve , ClientCurve ):
281+ raise TypeError ("When bind=False, curve must be a ClientCurve instance" )
282+ logger .debug (f"Configuring CURVE client security with secret_path={ curve .secret_path } " )
283+
284+ # Load the client cert pair
285+ client_public , client_secret = zmq .auth .load_certificate (curve .secret_path )
286+ socket .setsockopt (zmq .CURVE_PUBLICKEY , client_public )
287+ if client_secret is None :
288+ raise ValueError ("The client secret key could not be found." )
289+ socket .setsockopt (zmq .CURVE_SECRETKEY , client_secret )
290+
291+ # Load the server public key and register with the socket
292+ server_key , _ = zmq .auth .load_certificate (curve .server_public_key )
293+ socket .setsockopt (zmq .CURVE_SERVERKEY , server_key )
294+ logger .debug ("Applied CURVE client keys and server public key" )
295+
296+ if bind :
297+ if random_port :
298+ port = socket .bind_to_random_port (norm_address )
299+ logger .debug (f"Bound to random port: { port } " )
262300 else :
263- auth .configure_curve (domain = "*" , location = curve .client_public_keys )
264- logger .debug (f"Configured CURVE client public keys from: { curve .client_public_keys } " )
265-
266- # get public and private keys from the certificate
267- server_public , server_secret = zmq .auth .load_certificate (curve .secret_path )
268- # attach them to the
269- socket .setsockopt (zmq .CURVE_PUBLICKEY , server_public )
270- socket .setsockopt (zmq .CURVE_SECRETKEY , server_secret )
271- socket .setsockopt (zmq .CURVE_SERVER , True )
272- logger .debug ("Applied CURVE keys and enabled CURVE server mode" )
273-
274- if random_port :
275- port = socket .bind_to_random_port (norm_address )
276- logger .debug (f"Bound to random port: { port } " )
301+ port = socket .bind (norm_address )
302+ logger .debug (f"Bound to address: { norm_address } " )
277303 else :
278- port = socket .bind (norm_address )
279- logger .debug (f"Bound to address: { norm_address } " )
304+ socket .connect (norm_address )
305+ port = norm_address
306+ logger .debug (f"Connected to address: { norm_address } " )
280307
281308 return socket , port
282309
283- def __init__ (self , in_address = None , out_address = None , * , zmq = None , in_curve = None , out_curve = None ):
310+ def __init__ (self , in_address = None , out_address = None , * , zmq = None , in_curve = None , out_curve = None , in_bind = True , out_bind = True ):
284311 if zmq is None :
285312 import zmq
286313 self .zmq = zmq
@@ -291,12 +318,12 @@ def __init__(self, in_address=None, out_address=None, *, zmq=None, in_curve=None
291318 from zmq .auth .thread import ThreadAuthenticator
292319
293320 frontend , in_port = self .configure_server_socket (
294- context , zmq .SUB , in_address , in_curve , zmq , ThreadAuthenticator
321+ context , zmq .SUB , in_address , in_curve , zmq , ThreadAuthenticator , bind = in_bind
295322 )
296323 frontend .setsockopt_string (zmq .SUBSCRIBE , "" )
297324
298325 backend , out_port = self .configure_server_socket (
299- context , zmq .PUB , out_address , out_curve , zmq , ThreadAuthenticator
326+ context , zmq .PUB , out_address , out_curve , zmq , ThreadAuthenticator , bind = out_bind
300327 )
301328
302329 except BaseException :
@@ -315,8 +342,8 @@ def __init__(self, in_address=None, out_address=None, *, zmq=None, in_curve=None
315342 ...
316343 raise
317344 else :
318- self .in_port = in_port .addr if hasattr (in_port , "addr" ) else _normalize_address (in_port )
319- self .out_port = out_port .addr if hasattr (out_port , "addr" ) else _normalize_address (out_port )
345+ self .in_port = in_port .addr if hasattr (in_port , "addr" ) else _normalize_address (in_port ) if in_bind else in_port
346+ self .out_port = out_port .addr if hasattr (out_port , "addr" ) else _normalize_address (out_port ) if out_bind else out_port
320347 self ._frontend = frontend
321348 self ._backend = backend
322349 self ._context = context
0 commit comments