Skip to content

Commit 2b6fc04

Browse files
mrakitinTST Operator
authored andcommitted
ZMQ classed updates to support IPC protocol
1 parent c2b578a commit 2b6fc04

File tree

1 file changed

+47
-21
lines changed

1 file changed

+47
-21
lines changed

src/bluesky/callbacks/zmq.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
import asyncio
1616
import copy
1717
import pickle
18+
import uuid
1819
import warnings
20+
from enum import Enum
1921

2022
from ..run_engine import Dispatcher, DocumentNames
2123

@@ -26,6 +28,11 @@ class Bluesky0MQDecodeError(Exception):
2628
...
2729

2830

31+
class Protocols(Enum):
32+
TCP = "tcp"
33+
IPC = "ipc"
34+
35+
2936
class Publisher:
3037
"""
3138
A callback that publishes documents to a 0MQ proxy.
@@ -58,7 +65,7 @@ class Publisher:
5865
>>> RE.subscribe(publisher)
5966
"""
6067

61-
def __init__(self, address, *, prefix=b"", RE=None, zmq=None, serializer=pickle.dumps):
68+
def __init__(self, address, *, prefix=b"", RE=None, zmq=None, serializer=pickle.dumps, protocol=Protocols.TCP):
6269
if RE is not None:
6370
warnings.warn( # noqa: B028
6471
"The RE argument to Publisher is deprecated and "
@@ -73,11 +80,15 @@ def __init__(self, address, *, prefix=b"", RE=None, zmq=None, serializer=pickle.
7380
raise ValueError(f"prefix {prefix!r} may not contain b' '")
7481
if zmq is None:
7582
import zmq
76-
if isinstance(address, str):
77-
address = address.split(":", maxsplit=1)
78-
self.address = (address[0], int(address[1]))
83+
84+
if isinstance(address, tuple):
85+
url = f"{protocol.value}://{address[0]}:{address[1]}"
86+
else:
87+
url = f"{protocol.value}://address"
88+
89+
self.address = url
7990
self.RE = RE
80-
url = "tcp://%s:%d" % self.address
91+
8192
self._prefix = bytes(prefix)
8293
self._context = zmq.Context()
8394
self._socket = self._context.socket(zmq.PUB)
@@ -116,9 +127,9 @@ class Proxy:
116127
117128
Attributes
118129
----------
119-
in_port : int
130+
in_port : int or str
120131
Port that RunEngines should broadcast to.
121-
out_port : int
132+
out_port : int or str
122133
Port that subscribers should subscribe to.
123134
closed : boolean
124135
True if the Proxy has already been started and subsequently
@@ -146,28 +157,41 @@ class Proxy:
146157
>>> proxy.start() # runs until interrupted
147158
"""
148159

149-
def __init__(self, in_port=None, out_port=None, *, zmq=None):
160+
def __init__(self, in_port=None, out_port=None, *, zmq=None, protocol=Protocols.TCP):
150161
if zmq is None:
151162
import zmq
152163
self.zmq = zmq
153164
self.closed = False
165+
self.input_path = None
166+
self.output_path = None
154167
try:
155168
context = zmq.Context(1)
156169
# Socket facing clients
157170
frontend = context.socket(zmq.SUB)
158-
if in_port is None:
159-
in_port = frontend.bind_to_random_port("tcp://*")
160-
else:
161-
frontend.bind("tcp://*:%d" % in_port)
171+
if protocol is Protocols.TCP:
172+
if in_port is None:
173+
in_port = frontend.bind_to_random_port(f"{protocol.value}://*")
174+
else:
175+
frontend.bind(f"{protocol.value}://*:{in_port}")
176+
elif protocol is Protocols.IPC:
177+
if in_port is None:
178+
in_port = f"/tmp/{uuid.uuid4()}"
179+
frontend.bind(f"{protocol.value}://{in_port}")
162180

163181
frontend.setsockopt_string(zmq.SUBSCRIBE, "")
164182

165183
# Socket facing services
166184
backend = context.socket(zmq.PUB)
167-
if out_port is None:
168-
out_port = backend.bind_to_random_port("tcp://*")
169-
else:
170-
backend.bind("tcp://*:%d" % out_port)
185+
if protocol is Protocols.TCP:
186+
if out_port is None:
187+
out_port = backend.bind_to_random_port(f"{protocol.value}://*")
188+
else:
189+
backend.bind(f"{protocol.value}://*:{out_port}")
190+
elif protocol is Protocols.IPC:
191+
if out_port is None:
192+
out_port = f"/tmp/{uuid.uuid4()}"
193+
frontend.bind(f"{protocol.value}://{out_port}")
194+
171195
except BaseException:
172196
# Clean up whichever components we have defined so far.
173197
try:
@@ -247,6 +271,7 @@ def __init__(
247271
zmq_asyncio=None,
248272
deserializer=pickle.loads,
249273
strict=False,
274+
protocol=Protocols.TCP,
250275
):
251276
if isinstance(prefix, str):
252277
raise ValueError("prefix must be bytes, not string")
@@ -257,10 +282,12 @@ def __init__(
257282
import zmq
258283
if zmq_asyncio is None:
259284
import zmq.asyncio as zmq_asyncio
260-
if isinstance(address, str):
261-
address = address.split(":", maxsplit=1)
285+
if isinstance(address, tuple):
286+
url = f"{protocol.value}://{address[0]}:{address[1]}"
287+
else:
288+
url = f"{protocol.value}://address"
262289
self._deserializer = deserializer
263-
self.address = (address[0], int(address[1]))
290+
self.address = url
264291

265292
if loop is None:
266293
loop = asyncio.new_event_loop()
@@ -274,8 +301,7 @@ def __finish_setup():
274301
self._context = zmq_asyncio.Context()
275302
self._socket = self._context.socket(zmq.SUB)
276303

277-
url = "tcp://%s:%d" % self.address
278-
self._socket.connect(url)
304+
self._socket.connect(self.address)
279305
self._socket.setsockopt_string(zmq.SUBSCRIBE, "")
280306

281307
self.__factory = __finish_setup

0 commit comments

Comments
 (0)