1616import copy
1717import pickle
1818import warnings
19+ from typing import Union
1920
2021from ..run_engine import Dispatcher , DocumentNames
2122
2223
24+ def _normalize_address (inp : Union [str , tuple , int ]):
25+ if isinstance (inp , str ):
26+ if "://" in inp :
27+ protocol , _ , rest_str = inp .partition ("://" )
28+ else :
29+ protocol = "tcp"
30+ rest_str = inp
31+ elif isinstance (inp , tuple ):
32+ if inp [0 ] in ["tcp" , "ipc" ]:
33+ protocol , * rest = inp
34+ else :
35+ protocol = "tcp"
36+ rest = list (inp )
37+ if protocol == "tcp" :
38+ if len (rest ) == 2 :
39+ rest_str = ":" .join (str (r ) for r in rest )
40+ else :
41+ (rest_str ,) = rest
42+ else :
43+ (rest_str ,) = rest
44+ elif isinstance (inp , int ):
45+ protocol = "tcp"
46+ rest_str = f"localhost:{ inp } "
47+
48+ else :
49+ raise TypeError (f"Input expected to be str or tuple, not { type (inp )} " )
50+
51+ return f"{ protocol } ://{ rest_str } "
52+
53+
2354class Bluesky0MQDecodeError (Exception ):
2455 """Custom exception class for things that go wrong reading message from wire."""
2556
@@ -73,20 +104,20 @@ def __init__(self, address, *, prefix=b"", RE=None, zmq=None, serializer=pickle.
73104 raise ValueError (f"prefix { prefix !r} may not contain b' '" )
74105 if zmq is None :
75106 import zmq
76- if isinstance (address , str ):
77- address = address .split (":" , maxsplit = 1 )
78- self .address = (address [0 ], int (address [1 ]))
107+
108+ self .address = _normalize_address (address )
79109 self .RE = RE
80- url = "tcp://%s:%d" % self . address
110+
81111 self ._prefix = bytes (prefix )
82112 self ._context = zmq .Context ()
83113 self ._socket = self ._context .socket (zmq .PUB )
84- self ._socket .connect (url )
114+ self ._socket .connect (self . address )
85115 if RE :
86116 self ._subscription_token = RE .subscribe (self )
87117 self ._serializer = serializer
88118
89119 def __call__ (self , name , doc ):
120+ print (f"{ name = } \n { doc } \n " )
90121 doc = copy .deepcopy (doc )
91122 message = b" " .join ([self ._prefix , name .encode (), self ._serializer (doc )])
92123 self ._socket .send (message )
@@ -102,23 +133,40 @@ class Proxy:
102133 """
103134 Start a 0MQ proxy on the local host.
104135
136+ The addresses can be specified flexibly. It is best to use
137+ a domain_socket (available on unix):
138+
139+ - ``'icp:///tmp/domain_socket'``
140+ - ``('ipc', '/tmp/domain_socket')``
141+
142+ tcp sockets are also supported:
143+
144+ - ``'tcp://localhost:6557'``
145+ - ``6657`` (implicitly binds to ``'tcp://localhost:6657'``
146+ - ``('tcp', 'localhost', 6657)``
147+ - ``('localhost', 6657)``
148+
105149 Parameters
106150 ----------
107- in_port : int, optional
108- Port that RunEngines should broadcast to. If None, a random port is
109- used.
110- out_port : int, optional
111- Port that subscribers should subscribe to. If None, a random port is
112- used.
151+ in_address : str or tuple or int, optional
152+ Address that RunEngines should broadcast to.
153+
154+ If None, a random tcp port on all interfaces is used.
155+
156+ out_address : str or tuple or int, optional
157+ Address that subscribers should subscribe to.
158+
159+ If None, a random tcp port on all interfaces is used.
160+
113161 zmq : object, optional
114162 By default, the 'zmq' module is imported and used. Anything else
115163 mocking its interface is accepted.
116164
117165 Attributes
118166 ----------
119- in_port : int
167+ in_port : int or str
120168 Port that RunEngines should broadcast to.
121- out_port : int
169+ out_port : int or str
122170 Port that subscribers should subscribe to.
123171 closed : boolean
124172 True if the Proxy has already been started and subsequently
@@ -146,7 +194,7 @@ class Proxy:
146194 >>> proxy.start() # runs until interrupted
147195 """
148196
149- def __init__ (self , in_port = None , out_port = None , * , zmq = None ):
197+ def __init__ (self , in_address = None , out_address = None , * , zmq = None ):
150198 if zmq is None :
151199 import zmq
152200 self .zmq = zmq
@@ -155,19 +203,22 @@ def __init__(self, in_port=None, out_port=None, *, zmq=None):
155203 context = zmq .Context (1 )
156204 # Socket facing clients
157205 frontend = context .socket (zmq .SUB )
158- if in_port is None :
206+ if in_address is None :
159207 in_port = frontend .bind_to_random_port ("tcp://*" )
160208 else :
161- frontend .bind ("tcp://*:%d" % in_port )
209+ in_address = _normalize_address (in_address )
210+ in_port = frontend .bind (in_address )
162211
163212 frontend .setsockopt_string (zmq .SUBSCRIBE , "" )
164213
165214 # Socket facing services
166215 backend = context .socket (zmq .PUB )
167- if out_port is None :
216+ if out_address is None :
168217 out_port = backend .bind_to_random_port ("tcp://*" )
169218 else :
170- backend .bind ("tcp://*:%d" % out_port )
219+ out_address = _normalize_address (out_address )
220+ out_port = backend .bind (out_address )
221+
171222 except BaseException :
172223 # Clean up whichever components we have defined so far.
173224 try :
@@ -257,10 +308,8 @@ def __init__(
257308 import zmq
258309 if zmq_asyncio is None :
259310 import zmq .asyncio as zmq_asyncio
260- if isinstance (address , str ):
261- address = address .split (":" , maxsplit = 1 )
262311 self ._deserializer = deserializer
263- self .address = (address [ 0 ], int ( address [ 1 ]) )
312+ self .address = _normalize_address (address )
264313
265314 if loop is None :
266315 loop = asyncio .new_event_loop ()
@@ -274,8 +323,7 @@ def __finish_setup():
274323 self ._context = zmq_asyncio .Context ()
275324 self ._socket = self ._context .socket (zmq .SUB )
276325
277- url = "tcp://%s:%d" % self .address
278- self ._socket .connect (url )
326+ self ._socket .connect (self .address )
279327 self ._socket .setsockopt_string (zmq .SUBSCRIBE , "" )
280328
281329 self .__factory = __finish_setup
@@ -332,6 +380,7 @@ async def _poll(self):
332380 f"\n \n { e } "
333381 )
334382 continue
383+ print (f"{ name = } \n { doc } " )
335384 self .loop .call_soon (self .process , DocumentNames [name ], doc )
336385
337386 def start (self ):
0 commit comments