Skip to content

Commit 5a8f889

Browse files
mrakitintacaswell
authored andcommitted
ENH: extend the address normalization to allow domain sockets
1 parent 7c59ad7 commit 5a8f889

File tree

3 files changed

+123
-48
lines changed

3 files changed

+123
-48
lines changed

src/bluesky/callbacks/zmq.py

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,41 @@
1616
import copy
1717
import pickle
1818
import warnings
19+
from typing import Union
1920

2021
from ..run_engine import Dispatcher, DocumentNames
2122

2223

24+
def _normalize_address(inp: Union[str, tuple, int]):
25+
if isinstance(inp, str):
26+
if "://" in inp:
27+
protocol, _, rest_str = inp.partition("://")
28+
else:
29+
protocol = "tcp"
30+
rest_str = inp
31+
elif isinstance(inp, tuple):
32+
if inp[0] in ["tcp", "ipc"]:
33+
protocol, *rest = inp
34+
else:
35+
protocol = "tcp"
36+
rest = list(inp)
37+
if protocol == "tcp":
38+
if len(rest) == 2:
39+
rest_str = ":".join(str(r) for r in rest)
40+
else:
41+
(rest_str,) = rest
42+
else:
43+
(rest_str,) = rest
44+
elif isinstance(inp, int):
45+
protocol = "tcp"
46+
rest_str = f"localhost:{inp}"
47+
48+
else:
49+
raise TypeError(f"Input expected to be str or tuple, not {type(inp)}")
50+
51+
return f"{protocol}://{rest_str}"
52+
53+
2354
class Bluesky0MQDecodeError(Exception):
2455
"""Custom exception class for things that go wrong reading message from wire."""
2556

@@ -73,20 +104,20 @@ def __init__(self, address, *, prefix=b"", RE=None, zmq=None, serializer=pickle.
73104
raise ValueError(f"prefix {prefix!r} may not contain b' '")
74105
if zmq is None:
75106
import zmq
76-
if isinstance(address, str):
77-
address = address.split(":", maxsplit=1)
78-
self.address = (address[0], int(address[1]))
107+
108+
self.address = _normalize_address(address)
79109
self.RE = RE
80-
url = "tcp://%s:%d" % self.address
110+
81111
self._prefix = bytes(prefix)
82112
self._context = zmq.Context()
83113
self._socket = self._context.socket(zmq.PUB)
84-
self._socket.connect(url)
114+
self._socket.connect(self.address)
85115
if RE:
86116
self._subscription_token = RE.subscribe(self)
87117
self._serializer = serializer
88118

89119
def __call__(self, name, doc):
120+
print(f"{name = }\n{doc}\n")
90121
doc = copy.deepcopy(doc)
91122
message = b" ".join([self._prefix, name.encode(), self._serializer(doc)])
92123
self._socket.send(message)
@@ -102,23 +133,40 @@ class Proxy:
102133
"""
103134
Start a 0MQ proxy on the local host.
104135
136+
The addresses can be specified flexibly. It is best to use
137+
a domain_socket (available on unix):
138+
139+
- ``'icp:///tmp/domain_socket'``
140+
- ``('ipc', '/tmp/domain_socket')``
141+
142+
tcp sockets are also supported:
143+
144+
- ``'tcp://localhost:6557'``
145+
- ``6657`` (implicitly binds to ``'tcp://localhost:6657'``
146+
- ``('tcp', 'localhost', 6657)``
147+
- ``('localhost', 6657)``
148+
105149
Parameters
106150
----------
107-
in_port : int, optional
108-
Port that RunEngines should broadcast to. If None, a random port is
109-
used.
110-
out_port : int, optional
111-
Port that subscribers should subscribe to. If None, a random port is
112-
used.
151+
in_address : str or tuple or int, optional
152+
Address that RunEngines should broadcast to.
153+
154+
If None, a random tcp port on all interfaces is used.
155+
156+
out_address : str or tuple or int, optional
157+
Address that subscribers should subscribe to.
158+
159+
If None, a random tcp port on all interfaces is used.
160+
113161
zmq : object, optional
114162
By default, the 'zmq' module is imported and used. Anything else
115163
mocking its interface is accepted.
116164
117165
Attributes
118166
----------
119-
in_port : int
167+
in_port : int or str
120168
Port that RunEngines should broadcast to.
121-
out_port : int
169+
out_port : int or str
122170
Port that subscribers should subscribe to.
123171
closed : boolean
124172
True if the Proxy has already been started and subsequently
@@ -146,7 +194,7 @@ class Proxy:
146194
>>> proxy.start() # runs until interrupted
147195
"""
148196

149-
def __init__(self, in_port=None, out_port=None, *, zmq=None):
197+
def __init__(self, in_address=None, out_address=None, *, zmq=None):
150198
if zmq is None:
151199
import zmq
152200
self.zmq = zmq
@@ -155,19 +203,22 @@ def __init__(self, in_port=None, out_port=None, *, zmq=None):
155203
context = zmq.Context(1)
156204
# Socket facing clients
157205
frontend = context.socket(zmq.SUB)
158-
if in_port is None:
206+
if in_address is None:
159207
in_port = frontend.bind_to_random_port("tcp://*")
160208
else:
161-
frontend.bind("tcp://*:%d" % in_port)
209+
in_address = _normalize_address(in_address)
210+
in_port = frontend.bind(in_address)
162211

163212
frontend.setsockopt_string(zmq.SUBSCRIBE, "")
164213

165214
# Socket facing services
166215
backend = context.socket(zmq.PUB)
167-
if out_port is None:
216+
if out_address is None:
168217
out_port = backend.bind_to_random_port("tcp://*")
169218
else:
170-
backend.bind("tcp://*:%d" % out_port)
219+
out_address = _normalize_address(out_address)
220+
out_port = backend.bind(out_address)
221+
171222
except BaseException:
172223
# Clean up whichever components we have defined so far.
173224
try:
@@ -257,10 +308,8 @@ def __init__(
257308
import zmq
258309
if zmq_asyncio is None:
259310
import zmq.asyncio as zmq_asyncio
260-
if isinstance(address, str):
261-
address = address.split(":", maxsplit=1)
262311
self._deserializer = deserializer
263-
self.address = (address[0], int(address[1]))
312+
self.address = _normalize_address(address)
264313

265314
if loop is None:
266315
loop = asyncio.new_event_loop()
@@ -274,8 +323,7 @@ def __finish_setup():
274323
self._context = zmq_asyncio.Context()
275324
self._socket = self._context.socket(zmq.SUB)
276325

277-
url = "tcp://%s:%d" % self.address
278-
self._socket.connect(url)
326+
self._socket.connect(self.address)
279327
self._socket.setsockopt_string(zmq.SUBSCRIBE, "")
280328

281329
self.__factory = __finish_setup
@@ -332,6 +380,7 @@ async def _poll(self):
332380
f"\n\n{e}"
333381
)
334382
continue
383+
print(f"{name = }\n{doc}")
335384
self.loop.call_soon(self.process, DocumentNames[name], doc)
336385

337386
def start(self):

src/bluesky/commandline/zmq_proxy.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
logger = logging.getLogger("bluesky")
88

99

10-
def start_dispatcher(host, port, logfile=None):
10+
def start_dispatcher(out_address, logfile=None):
1111
"""The dispatcher function
1212
Parameters
1313
----------
1414
logfile : string
1515
string come from user command. ex --logfile=temp.log
1616
logfile will be "temp.log". logfile could be empty.
1717
"""
18-
dispatcher = RemoteDispatcher((host, port))
18+
dispatcher = RemoteDispatcher(out_address)
1919
if logfile is not None:
2020
raise ValueError(
2121
"Parameter 'logfile' is deprecated and will be removed in future releases. "
@@ -41,8 +41,8 @@ def log_writer(name, doc):
4141
def main():
4242
DESC = "Start a 0MQ proxy for publishing bluesky documents over a network."
4343
parser = argparse.ArgumentParser(description=DESC)
44-
parser.add_argument("in_port", type=int, nargs=1, help="port that RunEngines should broadcast to")
45-
parser.add_argument("out_port", type=int, nargs=1, help="port that subscribers should subscribe to")
44+
parser.add_argument("--in-address", help="port that RunEngines should broadcast to")
45+
parser.add_argument("--out-address", help="port that subscribers should subscribe to")
4646
parser.add_argument(
4747
"--verbose",
4848
"-v",
@@ -51,8 +51,6 @@ def main():
5151
)
5252
parser.add_argument("--logfile", type=str, help="Redirect logging output to a file on disk.")
5353
args = parser.parse_args()
54-
in_port = args.in_port[0]
55-
out_port = args.out_port[0]
5654

5755
if args.verbose:
5856
from bluesky.log import config_bluesky_logging
@@ -64,11 +62,11 @@ def main():
6462
else:
6563
config_bluesky_logging(level=level)
6664
# Set daemon to kill all threads upon IPython exit
67-
threading.Thread(target=start_dispatcher, args=("localhost", out_port), daemon=True).start()
65+
threading.Thread(target=start_dispatcher, args=(args.out_address), daemon=True).start()
6866

6967
print("Connecting...")
70-
proxy = Proxy(in_port, out_port)
71-
print("Receiving on port %d; publishing to port %d." % (in_port, out_port))
68+
proxy = Proxy(args.in_address, args.out_address)
69+
print("Receiving on address %s; publishing to address %s." % (args.in_address, args.out_address))
7270
print("Use Ctrl+C to exit.")
7371
try:
7472
proxy.start()

src/bluesky/tests/test_zmq.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from event_model import sanitize_doc
1212

1313
from bluesky import Msg
14-
from bluesky.callbacks.zmq import Proxy, Publisher, RemoteDispatcher
14+
from bluesky.callbacks.zmq import Proxy, Publisher, RemoteDispatcher, _normalize_address
1515
from bluesky.plans import count
1616
from bluesky.tests import uses_os_kill_sigint
1717

@@ -126,21 +126,6 @@ def delayed_sigint(delay):
126126
gc.collect()
127127

128128

129-
@pytest.mark.parametrize("host", ["localhost:5555", ("localhost", 5555)])
130-
def test_zmq_RD_ports_spec(host):
131-
# test that two ways of specifying address are equivalent
132-
d = RemoteDispatcher(host)
133-
assert d.address == ("localhost", 5555)
134-
assert d._socket is None
135-
assert d._context is None
136-
assert not d.closed
137-
d.stop()
138-
assert d._socket is None
139-
assert d._context is None
140-
assert d.closed
141-
del d
142-
143-
144129
def test_zmq_no_RE(RE):
145130
# COMPONENT 1
146131
# Run a 0MQ proxy on a separate process.
@@ -345,3 +330,46 @@ def local_cb(name, doc):
345330
ra = sanitize_doc(remote_accumulator)
346331
la = sanitize_doc(local_accumulator)
347332
assert ra == la
333+
334+
335+
@pytest.mark.parametrize(
336+
"host",
337+
["localhost:5555", ("localhost", 5555)],
338+
)
339+
def test_zmq_RD_ports_spec(host):
340+
# test that two ways of specifying address are equivalent
341+
d = RemoteDispatcher(host)
342+
assert d.address == "tcp://localhost:5555"
343+
assert d._socket is None
344+
assert d._context is None
345+
assert not d.closed
346+
d.stop()
347+
assert d._socket is None
348+
assert d._context is None
349+
assert d.closed
350+
del d
351+
352+
353+
@pytest.mark.parametrize(
354+
"address",
355+
[
356+
("localhost", "tcp://localhost"),
357+
("localhost:9", "tcp://localhost:9"),
358+
("remote.host", "tcp://remote.host"),
359+
("remote.host:9", "tcp://remote.host:9"),
360+
("tcp://remote.host", "tcp://remote.host"),
361+
("tcp://localhost", "tcp://localhost"),
362+
("tcp://localhost:9", "tcp://localhost:9"),
363+
("tcp://remote.host:9", "tcp://remote.host:9"),
364+
("ipc:///tmp/path", "ipc:///tmp/path"),
365+
(("localhost",), "tcp://localhost"),
366+
(("localhost", 9), "tcp://localhost:9"),
367+
(("ipc", "/tmp/path"), "ipc:///tmp/path"),
368+
(("tcp", "localhost"), "tcp://localhost"),
369+
(("tcp", "localhost", 9), "tcp://localhost:9"),
370+
(("tcp", "localhost", "9"), "tcp://localhost:9"),
371+
],
372+
)
373+
def test_address_normaliaztion(address):
374+
inp, outp = address
375+
assert _normalize_address(inp) == outp

0 commit comments

Comments
 (0)