diff --git a/src/cpp/scaler/object_storage/object_storage_server.cpp b/src/cpp/scaler/object_storage/object_storage_server.cpp index eac86c9e3..71fe4a38e 100644 --- a/src/cpp/scaler/object_storage/object_storage_server.cpp +++ b/src/cpp/scaler/object_storage/object_storage_server.cpp @@ -1,6 +1,7 @@ #include "scaler/object_storage/object_storage_server.h" #include +#include #include #include #include @@ -14,9 +15,32 @@ namespace scaler { namespace object_storage { +// Global atomic flag to indicate termination request +static std::atomic sigRequestStop {false}; + +// Signal handler for SIGTERM +extern "C" void handleSigTerm(int signum) +{ + sigRequestStop = true; +} + +// Function to install the signal handler +void setupSignalHandling() +{ + struct sigaction sa {}; + sa.sa_handler = handleSigTerm; + sigemptyset(&sa.sa_mask); + sa.sa_flags = 0; + + if (sigaction(SIGTERM, &sa, nullptr) == -1) { + perror("sigaction"); + } +} + ObjectStorageServer::ObjectStorageServer() { initServerReadyFds(); + setupSignalHandling(); } ObjectStorageServer::~ObjectStorageServer() @@ -138,7 +162,7 @@ void ObjectStorageServer::processRequests(std::function running) auto maybeMessageFuture = ymq::futureRecvMessage(_ioSocket); while (maybeMessageFuture.wait_for(100ms) == std::future_status::timeout) { - if (!running()) { + if (!running() || sigRequestStop) { _logger.log(scaler::ymq::Logger::LoggingLevel::info, "ObjectStorageServer: stopped by user"); pendingRequests.clear(); return; diff --git a/src/cpp/scaler/object_storage/pymod_object_storage_server.cpp b/src/cpp/scaler/object_storage/pymod_object_storage_server.cpp index aab42ed4a..308b33968 100644 --- a/src/cpp/scaler/object_storage/pymod_object_storage_server.cpp +++ b/src/cpp/scaler/object_storage/pymod_object_storage_server.cpp @@ -54,19 +54,23 @@ static PyObject* PyObjectStorageServerRun(PyObject* self, PyObject* args) logging_paths.push_back(PyUnicode_AsUTF8(path_obj)); } - auto running = []() -> bool { + int res {}; + auto 
running = [&] -> bool { AcquireGIL gil; (void)gil; - return PyErr_CheckSignals() == 0; + res = PyErr_CheckSignals(); + return res == 0; }; ((PyObjectStorageServer*)self) ->server.run( addr, std::to_string(port), identity, log_level, log_format, std::move(logging_paths), std::move(running)); - // TODO: Ideally, run should return a bool and we return failure with nullptr. - return nullptr; - // Py_RETURN_NONE; + if (!res) { + Py_RETURN_NONE; + } else { + return nullptr; + } } static PyObject* PyObjectStorageServerWaitUntilReady(PyObject* self, [[maybe_unused]] PyObject* args) diff --git a/src/cpp/scaler/ymq/io_context.cpp b/src/cpp/scaler/ymq/io_context.cpp index 355a25eda..42c547602 100644 --- a/src/cpp/scaler/ymq/io_context.cpp +++ b/src/cpp/scaler/ymq/io_context.cpp @@ -64,10 +64,23 @@ void IOContext::removeIOSocket(std::shared_ptr& socket) noexcept std::promise promise; auto future = promise.get_future(); + // TODO: `count` and `maxCount` are needed as a safety net so that + // we don't wait forever on querying numOfConnections. + // If the remote end is using YMQ as its internal communication tool, then + // we don't need this safety net. This is because YMQ closes a connection + // if the remote end shuts down write. This results in an event on the local end. + // If the remote end does not close the connection upon local end shutdown + // write, the local end will never get any event for remote socket close, + // and therefore the connection will stay alive in the system. 
+ // This needs to be revisited, we have opened an issue, the issue link is: + // https://github.com/finos/opengris-scaler/issues/445 + int count = 0; auto waitToRemoveIOSocket = [&](const auto& self) -> void { + constexpr static int maxCount = 8; rawSocket->_eventLoopThread->_eventLoop.executeNow([&] { rawSocket->_eventLoopThread->_eventLoop.executeLater([&] { - if (rawSocket->numOfConnections()) { + if (rawSocket->numOfConnections() && count < maxCount) { + ++count; self(self); return; } diff --git a/src/cpp/scaler/ymq/io_socket.cpp b/src/cpp/scaler/ymq/io_socket.cpp index 02b418db1..ac2d31135 100644 --- a/src/cpp/scaler/ymq/io_socket.cpp +++ b/src/cpp/scaler/ymq/io_socket.cpp @@ -144,7 +144,7 @@ void IOSocket::connectTo(SocketAddress addr, ConnectReturnCallback onConnectRetu _tcpClient->onCreated(); } else if (addr.nativeHandleType() == SocketAddress::Type::IPC) { - if (_domainClient) { + if (_ipcClient) { unrecoverableError({ Error::ErrorCode::MultipleConnectToNotSupported, "Originated from", @@ -152,9 +152,9 @@ void IOSocket::connectTo(SocketAddress addr, ConnectReturnCallback onConnectRetu }); } - _domainClient.emplace( + _ipcClient.emplace( _eventLoopThread.get(), this->identity(), std::move(addr), std::move(callback), maxRetryTimes); - _domainClient->onCreated(); + _ipcClient->onCreated(); } else { std::unreachable(); // current protocol supports only tcp and icp @@ -171,35 +171,34 @@ void IOSocket::connectTo( void IOSocket::bindTo(std::string netOrDomainAddr, BindReturnCallback onBindReturn) noexcept { - _eventLoopThread->_eventLoop.executeNow( - [this, netOrDomainAddr = std::move(netOrDomainAddr), callback = std::move(onBindReturn)] mutable { - assert(netOrDomainAddr.size()); - const auto socketAddress = stringToSocketAddress(netOrDomainAddr); - - if (socketAddress.nativeHandleType() == SocketAddress::Type::TCP) { - if (_tcpServer) { - callback(std::unexpected {Error::ErrorCode::MultipleBindToNotSupported}); - return; - } + 
_eventLoopThread->_eventLoop.executeNow([this, + netOrDomainAddr = std::move(netOrDomainAddr), + callback = std::move(onBindReturn)] mutable { + assert(netOrDomainAddr.size()); + const auto socketAddress = stringToSocketAddress(netOrDomainAddr); + + if (socketAddress.nativeHandleType() == SocketAddress::Type::TCP) { + if (_tcpServer) { + callback(std::unexpected {Error::ErrorCode::MultipleBindToNotSupported}); + return; + } - _tcpServer.emplace( - _eventLoopThread.get(), this->identity(), std::move(socketAddress), std::move(callback)); - _tcpServer->onCreated(); + _tcpServer.emplace(_eventLoopThread.get(), this->identity(), std::move(socketAddress), std::move(callback)); + _tcpServer->onCreated(); - } else if (socketAddress.nativeHandleType() == SocketAddress::Type::IPC) { - if (_domainServer) { - callback(std::unexpected {Error::ErrorCode::MultipleBindToNotSupported}); - return; - } + } else if (socketAddress.nativeHandleType() == SocketAddress::Type::IPC) { + if (_ipcServer) { + callback(std::unexpected {Error::ErrorCode::MultipleBindToNotSupported}); + return; + } - _domainServer.emplace( - _eventLoopThread.get(), this->identity(), std::move(socketAddress), std::move(callback)); - _domainServer->onCreated(); + _ipcServer.emplace(_eventLoopThread.get(), this->identity(), std::move(socketAddress), std::move(callback)); + _ipcServer->onCreated(); - } else { - std::unreachable(); // current protocol supports only tcp and icp - } - }); + } else { + std::unreachable(); // current protocol supports only tcp and ipc + } + }); } void IOSocket::closeConnection(Identity remoteSocketIdentity) noexcept @@ -213,11 +212,24 @@ void IOSocket::closeConnection(Identity remoteSocketIdentity) noexcept }); } -// TODO: The function should be separated into onConnectionAborted, onConnectionDisconnected, -// and probably onConnectionAbortedBeforeEstablished(?) 
+void IOSocket::onConnectorMaxedOutRetry() noexcept +{ + assert(_unestablishedConnection.size()); + assert(IOSocketType::Connector == this->_socketType); + _connectorDisconnected = true; + fillPendingRecvMessagesWithErr(Error::ErrorCode::ConnectorSocketClosedByRemoteEnd); + auto& connPtr = _unestablishedConnection.back(); + _eventLoopThread->_eventLoop.executeLater([conn = std::move(connPtr)]() {}); + _unestablishedConnection.pop_back(); +} + void IOSocket::onConnectionDisconnected(MessageConnection* conn, bool keepInBook) noexcept { if (!conn->_remoteIOSocketIdentity) { + if (IOSocketType::Connector == this->_socketType) { + _connectorDisconnected = true; + fillPendingRecvMessagesWithErr(Error::ErrorCode::ConnectorSocketClosedByRemoteEnd); + } auto connIt = std::ranges::find_if(_unestablishedConnection, [&](const auto& x) { return x.get() == conn; }); assert(connIt != _unestablishedConnection.end()); _eventLoopThread->_eventLoop.executeLater([conn = std::move(*connIt)] {}); @@ -330,11 +342,14 @@ void IOSocket::onConnectionCreated( _unestablishedConnection.back()->onCreated(); } -void IOSocket::removeConnectedStreamClient() noexcept +void IOSocket::removeConnectedStreamClient(const StreamClient* client) noexcept { - if (this->_tcpClient && this->_tcpClient->_connected) { + if (this->_tcpClient && &(*this->_tcpClient) == client) { this->_tcpClient.reset(); } + if (this->_ipcClient && &(*this->_ipcClient) == client) { + this->_ipcClient.reset(); + } } void IOSocket::requestStop() noexcept @@ -352,11 +367,11 @@ void IOSocket::requestStop() noexcept _tcpClient->disconnect(); } - if (_domainClient) { - _domainClient->disconnect(); + if (_ipcClient) { + _ipcClient->disconnect(); } - if (_domainServer) { - _domainServer->disconnect(); + if (_ipcServer) { + _ipcServer->disconnect(); } } diff --git a/src/cpp/scaler/ymq/io_socket.h b/src/cpp/scaler/ymq/io_socket.h index 29647af9b..780f4e30c 100644 --- a/src/cpp/scaler/ymq/io_socket.h +++ b/src/cpp/scaler/ymq/io_socket.h @@ 
-48,13 +48,12 @@ class IOSocket { // NOTE: BELOW FOUR FUNCTIONS ARE USERSPACE API void sendMessage(Message message, SendMessageCallback onMessageSent) noexcept; void recvMessage(RecvMessageCallback onRecvMessage) noexcept; - void bindTo(std::string netOrDomainAddr, BindReturnCallback onBindReturn) noexcept; void connectTo( - std::string netOrDomainAddr, ConnectReturnCallback onConnectReturn, size_t maxRetryTimes = 8) noexcept; + std::string netOrDomainAddr, ConnectReturnCallback onConnectReturn, size_t maxRetryTimes = 4) noexcept; // NOTE: BELOW ONE ARE NOT OFFICIAL USERSPACE API. USE WITH CAUTION. - void connectTo(SocketAddress addr, ConnectReturnCallback onConnectReturn, size_t maxRetryTimes = 8) noexcept; + void connectTo(SocketAddress addr, ConnectReturnCallback onConnectReturn, size_t maxRetryTimes = 4) noexcept; void closeConnection(Identity remoteSocketIdentity) noexcept; @@ -70,6 +69,9 @@ class IOSocket { // From Connection Class only void onConnectionIdentityReceived(MessageConnection* conn) noexcept; + // From CONNECTOR only + void onConnectorMaxedOutRetry() noexcept; + // NOTE: These two functions are called respectively by sendMessage and server/client. // Notice that in the each case only the needed information are passed in; so it's less // likely the user passed in combinations that does not make sense. These two calls are @@ -78,8 +80,8 @@ class IOSocket { void onConnectionCreated( int fd, SocketAddress localAddr, SocketAddress remoteAddr, bool responsibleForRetry) noexcept; - // From TCPClient class only - void removeConnectedStreamClient() noexcept; + // From StreamClient class only + void removeConnectedStreamClient(const StreamClient* client) noexcept; void requestStop() noexcept; @@ -100,9 +102,9 @@ class IOSocket { // NOTE: Owning one TCPServer means the user cannot bindTo multiple addresses. 
std::optional _tcpServer; - // NOTE: User may choose to bind to one IP address + one UDS address - std::optional _domainServer; - std::optional _domainClient; + // NOTE: User may choose to bind to one IPv4 address + one IPC address + std::optional _ipcServer; + std::optional _ipcClient; // Remote identity to connection map std::map> _identityToConnection; diff --git a/src/cpp/scaler/ymq/pymod_ymq/exception.h b/src/cpp/scaler/ymq/pymod_ymq/exception.h index d942f90e4..4b2932b5a 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/exception.h +++ b/src/cpp/scaler/ymq/pymod_ymq/exception.h @@ -48,10 +48,9 @@ static int YMQException_init(YMQException* self, PyObject* args, PyObject* kwds) static void YMQException_dealloc(YMQException* self) { - self->ob_base.ob_type->tp_base->tp_dealloc((PyObject*)self); - // we still need to release the reference to the heap type auto* tp = Py_TYPE(self); + self->ob_base.ob_type->tp_base->tp_dealloc((PyObject*)self); Py_DECREF(tp); } diff --git a/src/cpp/scaler/ymq/stream_client.cpp b/src/cpp/scaler/ymq/stream_client.cpp index de9514e36..c91a2abdb 100644 --- a/src/cpp/scaler/ymq/stream_client.cpp +++ b/src/cpp/scaler/ymq/stream_client.cpp @@ -33,7 +33,7 @@ void StreamClient::onCreated() _rawClient.zeroNativeHandle(); _connected = true; - _eventLoopThread->_eventLoop.executeLater([sock] { sock->removeConnectedStreamClient(); }); + _eventLoopThread->_eventLoop.executeLater([sock, this] { sock->removeConnectedStreamClient(this); }); if (_retryTimes == 0) { _onConnectReturn({}); @@ -105,14 +105,19 @@ void StreamClient::onWrite() _rawClient.zeroNativeHandle(); _connected = true; - _eventLoopThread->_eventLoop.executeLater([sock] { sock->removeConnectedStreamClient(); }); + _eventLoopThread->_eventLoop.executeLater([sock, this] { sock->removeConnectedStreamClient(this); }); } void StreamClient::retry() { if (_retryTimes > _maxRetryTimes) { _logger.log(Logger::LoggingLevel::error, "Retried times has reached maximum: ", _maxRetryTimes); - // exit(1); + 
disconnect(); + + const std::string id = this->_localIOSocketIdentity; + auto sock = this->_eventLoopThread->_identityToIOSocket.at(id); + sock->onConnectorMaxedOutRetry(); + _eventLoopThread->_eventLoop.executeLater([sock, this] { sock->removeConnectedStreamClient(this); }); return; } diff --git a/src/scaler/client/agent/client_agent.py b/src/scaler/client/agent/client_agent.py index 8eae128b4..476f01640 100644 --- a/src/scaler/client/agent/client_agent.py +++ b/src/scaler/client/agent/client_agent.py @@ -16,6 +16,8 @@ from scaler.config.types.zmq import ZMQConfig from scaler.io.async_connector import ZMQAsyncConnector from scaler.io.mixins import AsyncConnector +from scaler.io.utility import create_async_connector +from scaler.io.ymq.ymq import YMQException from scaler.protocol.python.common import ObjectStorageAddress from scaler.protocol.python.message import ( ClientDisconnect, @@ -78,8 +80,9 @@ def __init__( callback=self.__on_receive_from_client, identity=None, ) - self._connector_external: AsyncConnector = ZMQAsyncConnector( - context=zmq.asyncio.Context.shadow(self._context), + + self._connector_external: AsyncConnector = create_async_connector( + zmq.asyncio.Context.shadow(self._context), name="client_agent_external", socket_type=zmq.DEALER, address=self._scheduler_address, @@ -194,7 +197,11 @@ async def __get_loops(self): finally: self._stop_event.set() # always set the stop event before setting futures' exceptions - await self._object_manager.clear_all_objects(clear_serializer=True) + if not isinstance(exception, YMQException): + try: + await self._object_manager.clear_all_objects(clear_serializer=True) + except YMQException: # Above call triggers YMQ, which may raise + pass self._connector_external.destroy() self._connector_internal.destroy() @@ -211,8 +218,8 @@ async def __get_loops(self): elif isinstance(exception, (ClientQuitException, ClientShutdownException)): logging.info("ClientAgent: client quitting") 
self._future_manager.set_all_futures_with_exception(exception) - elif isinstance(exception, TimeoutError): + elif isinstance(exception, (TimeoutError, YMQException)): logging.error(f"ClientAgent: client timeout when connecting to {self._scheduler_address.to_address()}") - self._future_manager.set_all_futures_with_exception(exception) + self._future_manager.set_all_futures_with_exception(TimeoutError()) else: raise exception diff --git a/src/scaler/io/async_object_storage_connector.py b/src/scaler/io/async_object_storage_connector.py index 9742989a7..659060803 100644 --- a/src/scaler/io/async_object_storage_connector.py +++ b/src/scaler/io/async_object_storage_connector.py @@ -65,15 +65,13 @@ async def wait_until_connected(self): def is_connected(self) -> bool: return self._connected_event.is_set() - async def destroy(self): + def destroy(self): if not self.is_connected(): return if not self._writer.is_closing: self._writer.close() - await self._writer.wait_closed() - @property def reader(self) -> Optional[asyncio.StreamReader]: return self._reader diff --git a/src/scaler/io/mixins.py b/src/scaler/io/mixins.py index 17b137278..eb9ba1aee 100644 --- a/src/scaler/io/mixins.py +++ b/src/scaler/io/mixins.py @@ -92,7 +92,7 @@ def is_connected(self) -> bool: raise NotImplementedError() @abc.abstractmethod - async def destroy(self): + def destroy(self): raise NotImplementedError() @property diff --git a/src/scaler/io/utility.py b/src/scaler/io/utility.py index 682909c41..4c18c9527 100644 --- a/src/scaler/io/utility.py +++ b/src/scaler/io/utility.py @@ -2,10 +2,12 @@ import os from typing import List, Optional -from scaler.config.defaults import CAPNP_DATA_SIZE_LIMIT, CAPNP_MESSAGE_SIZE_LIMIT +import zmq.asyncio + +from scaler.config.defaults import CAPNP_DATA_SIZE_LIMIT, CAPNP_MESSAGE_SIZE_LIMIT, SCALER_NETWORK_BACKEND from scaler.config.types.network_backend import NetworkBackend from scaler.io.async_object_storage_connector import PyAsyncObjectStorageConnector -from 
scaler.io.mixins import AsyncObjectStorageConnector, SyncObjectStorageConnector +from scaler.io.mixins import AsyncBinder, AsyncConnector, AsyncObjectStorageConnector, SyncObjectStorageConnector from scaler.io.sync_object_storage_connector import PySyncObjectStorageConnector from scaler.protocol.capnp._python import _message # noqa from scaler.protocol.python.message import PROTOCOL @@ -18,11 +20,44 @@ def get_scaler_network_backend_from_env(): - backend_str = os.environ.get("SCALER_NETWORK_BACKEND", "tcp_zmq") # Default to tcp_zmq - try: - return NetworkBackend[backend_str] - except KeyError: - return None + backend_str = os.environ.get("SCALER_NETWORK_BACKEND") # None when unset; falls back to defaults.SCALER_NETWORK_BACKEND + if backend_str is None: + return SCALER_NETWORK_BACKEND + return NetworkBackend[backend_str] + + +def create_async_binder(ctx: zmq.asyncio.Context, *args, **kwargs) -> AsyncBinder: + connector_type = get_scaler_network_backend_from_env() + if connector_type == NetworkBackend.ymq: + from scaler.io.ymq_async_binder import YMQAsyncBinder + + return YMQAsyncBinder(*args, **kwargs) + elif connector_type == NetworkBackend.tcp_zmq: + from scaler.io.async_binder import ZMQAsyncBinder + + return ZMQAsyncBinder(context=ctx, *args, **kwargs) # type: ignore[misc] + else: + raise ValueError( + f"Invalid SCALER_NETWORK_BACKEND value." 
f"Expected one of: {[e.name for e in NetworkBackend]}" + ) + + +def create_async_connector(ctx: zmq.asyncio.Context, *args, **kwargs) -> AsyncConnector: + connector_type = get_scaler_network_backend_from_env() + if connector_type == NetworkBackend.ymq: + from scaler.io.ymq import ymq + from scaler.io.ymq_async_connector import YMQAsyncConnector + + kwargs["socket_type"] = ymq.IOSocketType.Connector + return YMQAsyncConnector(*args, **kwargs) + elif connector_type == NetworkBackend.tcp_zmq: + from scaler.io.async_connector import ZMQAsyncConnector + + return ZMQAsyncConnector(context=ctx, *args, **kwargs) # type: ignore[misc] + else: + raise ValueError( + f"Invalid SCALER_NETWORK_BACKEND value." f"Expected one of: {[e.name for e in NetworkBackend]}" + ) def create_async_object_storage_connector(*args, **kwargs) -> AsyncObjectStorageConnector: diff --git a/src/scaler/io/ymq_async_binder.py b/src/scaler/io/ymq_async_binder.py new file mode 100644 index 000000000..cf61e9c5f --- /dev/null +++ b/src/scaler/io/ymq_async_binder.py @@ -0,0 +1,65 @@ +import logging +import os +import uuid +from collections import defaultdict +from typing import Awaitable, Callable, Dict, Optional + +from scaler.config.types.zmq import ZMQConfig +from scaler.io.mixins import AsyncBinder +from scaler.io.utility import deserialize, serialize +from scaler.io.ymq import ymq +from scaler.protocol.python.mixins import Message +from scaler.protocol.python.status import BinderStatus + + +class YMQAsyncBinder(AsyncBinder): + def __init__(self, name: str, address: ZMQConfig, identity: Optional[bytes] = None): + self._address = address + + if identity is None: + identity = f"{os.getpid()}|{name}|{uuid.uuid4()}".encode() + self._identity = identity + + self._context = ymq.IOContext() + self._socket = self._context.createIOSocket_sync(self.identity.decode(), ymq.IOSocketType.Binder) + self._socket.bind_sync(self._address.to_address()) + + self._callback: Optional[Callable[[bytes, Message], 
Awaitable[None]]] = None + + self._received: Dict[str, int] = defaultdict(lambda: 0) + self._sent: Dict[str, int] = defaultdict(lambda: 0) + + @property + def identity(self): + return self._identity + + def destroy(self): + self._socket = None + self._context = None + + def register(self, callback: Callable[[bytes, Message], Awaitable[None]]): + self._callback = callback + + async def routine(self): + ymqmsg = await self._socket.recv() + + message: Optional[Message] = deserialize(ymqmsg.payload.data) + if message is None: + logging.error(f"received unknown message from {ymqmsg.address.data!r}: {ymqmsg.address.data!r}") + return + + self.__count_received(message.__class__.__name__) + await self._callback(ymqmsg.address.data, message) + + async def send(self, to: bytes, message: Message): + self.__count_sent(message.__class__.__name__) + await self._socket.send(ymq.Message(address=to, payload=serialize(message))) + + def get_status(self) -> BinderStatus: + return BinderStatus.new_msg(received=self._received, sent=self._sent) + + def __count_received(self, message_type: str): + self._received[message_type] += 1 + + def __count_sent(self, message_type: str): + self._sent[message_type] += 1 diff --git a/src/scaler/io/ymq_async_connector.py b/src/scaler/io/ymq_async_connector.py new file mode 100644 index 000000000..f34892edf --- /dev/null +++ b/src/scaler/io/ymq_async_connector.py @@ -0,0 +1,84 @@ +import logging +import os +import uuid +from typing import Awaitable, Callable, Literal, Optional + +from scaler.config.types.zmq import ZMQConfig +from scaler.io.mixins import AsyncConnector +from scaler.io.utility import deserialize, serialize +from scaler.io.ymq import ymq +from scaler.protocol.python.mixins import Message + + +class YMQAsyncConnector(AsyncConnector): + def __init__( + self, + name: str, + socket_type: ymq.IOSocketType, + address: ZMQConfig, + bind_or_connect: Literal["bind", "connect"], + callback: Optional[Callable[[Message], Awaitable[None]]], + 
identity: Optional[bytes], + ): + self._address = address + + self._context = ymq.IOContext() + + if identity is None: + identity = f"{os.getpid()}|{name}|{uuid.uuid4().bytes.hex()}".encode() + self._identity = identity + + self._socket = self._context.createIOSocket_sync(self.identity.decode(), socket_type) + + if bind_or_connect == "bind": + self._socket.bind_sync(self.address) + elif bind_or_connect == "connect": + self._socket.connect_sync(self.address) + else: + raise TypeError("bind_or_connect has to be 'bind' or 'connect'") + + self._callback: Optional[Callable[[Message], Awaitable[None]]] = callback + + def destroy(self): + self._socket = None + self._context = None + + @property + def identity(self) -> bytes: + return self._identity + + @property + def socket(self) -> ymq.IOSocket: + return self._socket + + @property + def address(self) -> str: + return self._address.to_address() + + async def routine(self): + if self._callback is None: + return + + message: Optional[Message] = await self.receive() + if message is None: + return + + await self._callback(message) + + async def receive(self) -> Optional[Message]: + if self._context is None: + return None + + if self._socket is None: + return None + + msg = await self._socket.recv() + result: Optional[Message] = deserialize(msg.payload.data) + if result is None: + logging.error(f"received unknown message: {msg.payload.data!r}") + return None + + return result + + async def send(self, message: Message): + await self._socket.send(ymq.Message(address=b"", payload=serialize(message))) diff --git a/src/scaler/io/ymq_async_object_storage_connector.py b/src/scaler/io/ymq_async_object_storage_connector.py index 5c2aead24..1854e905b 100644 --- a/src/scaler/io/ymq_async_object_storage_connector.py +++ b/src/scaler/io/ymq_async_object_storage_connector.py @@ -33,6 +33,7 @@ def __del__(self): if not self.is_connected(): return self._io_socket = None + self._io_context = None async def connect(self, host: str, port: int): 
self._host = host @@ -49,10 +50,11 @@ async def wait_until_connected(self): def is_connected(self) -> bool: return self._connected_event.is_set() - async def destroy(self): + def destroy(self): if not self.is_connected(): return self._io_socket = None + self._io_context = None @property def address(self) -> str: diff --git a/src/scaler/io/ymq_sync_object_storage_connector.py b/src/scaler/io/ymq_sync_object_storage_connector.py index 045106881..b5e31a7ef 100644 --- a/src/scaler/io/ymq_sync_object_storage_connector.py +++ b/src/scaler/io/ymq_sync_object_storage_connector.py @@ -37,6 +37,7 @@ def destroy(self): with self._socket_lock: if self._io_socket is not None: self._io_socket = None + self._io_context = None @property def address(self) -> str: diff --git a/src/scaler/scheduler/scheduler.py b/src/scaler/scheduler/scheduler.py index e49de922d..9e4c768c0 100644 --- a/src/scaler/scheduler/scheduler.py +++ b/src/scaler/scheduler/scheduler.py @@ -7,10 +7,10 @@ from scaler.config.defaults import CLEANUP_INTERVAL_SECONDS, STATUS_REPORT_INTERVAL_SECONDS from scaler.config.section.scheduler import SchedulerConfig from scaler.config.types.zmq import ZMQConfig, ZMQType -from scaler.io.async_binder import ZMQAsyncBinder from scaler.io.async_connector import ZMQAsyncConnector from scaler.io.mixins import AsyncBinder, AsyncConnector, AsyncObjectStorageConnector -from scaler.io.utility import create_async_object_storage_connector +from scaler.io.utility import create_async_binder, create_async_object_storage_connector +from scaler.io.ymq.ymq import YMQException from scaler.protocol.python.common import ObjectStorageAddress from scaler.protocol.python.message import ( ClientDisconnect, @@ -37,7 +37,7 @@ from scaler.scheduler.controllers.task_controller import VanillaTaskController from scaler.scheduler.controllers.worker_controller import VanillaWorkerController from scaler.utility.event_loop import create_async_loop_routine -from scaler.utility.exceptions import 
ClientShutdownException +from scaler.utility.exceptions import ClientShutdownException, ObjectStorageException from scaler.utility.identifiers import ClientID, WorkerID @@ -71,9 +71,10 @@ def __init__(self, config: SchedulerConfig): self._context = zmq.asyncio.Context(io_threads=config.worker_io_threads) - self._binder: AsyncBinder = ZMQAsyncBinder( - context=self._context, name="scheduler", address=config.scheduler_address + self._binder: AsyncBinder = create_async_binder( + self._context, name="scheduler", address=config.scheduler_address ) + logging.info(f"{self.__class__.__name__}: listen to scheduler address {config.scheduler_address}") self._connector_storage: AsyncObjectStorageConnector = create_async_object_storage_connector() @@ -240,9 +241,14 @@ async def get_loops(self): except ClientShutdownException as e: logging.info(f"{self.__class__.__name__}: {e}") pass + except YMQException: + pass + except ObjectStorageException: + pass self._binder.destroy() self._binder_monitor.destroy() + self._connector_storage.destroy() @functools.wraps(Scheduler) diff --git a/src/scaler/worker/agent/processor/processor.py b/src/scaler/worker/agent/processor/processor.py index ef327aae1..135be0d4f 100644 --- a/src/scaler/worker/agent/processor/processor.py +++ b/src/scaler/worker/agent/processor/processor.py @@ -19,6 +19,7 @@ from scaler.protocol.python.common import ObjectMetadata, TaskResultType from scaler.protocol.python.message import ObjectInstruction, ProcessorInitialized, Task, TaskLog, TaskResult from scaler.protocol.python.mixins import Message +from scaler.utility.exceptions import ObjectStorageException from scaler.utility.identifiers import ClientID, ObjectID, TaskID from scaler.utility.logging.utility import setup_logger from scaler.utility.metadata.task_flags import retrieve_task_flags_from_task @@ -124,6 +125,7 @@ def __register_signals(self): def __interrupt(self, *args): self._connector_agent.destroy() # interrupts any blocking socket. 
+ self._connector_storage.destroy() def __suspend(self, *args): assert self._resume_event is not None @@ -149,6 +151,9 @@ def __run_forever(self): if e.errno != zmq.ENOTSOCK: # ignore if socket got closed raise + except ObjectStorageException: + pass + except (KeyboardInterrupt, InterruptedError): pass @@ -160,6 +165,7 @@ def __run_forever(self): self._connector_agent.destroy() self._object_cache.join() + self._connector_storage.destroy() def __on_connector_receive(self, message: Message): if isinstance(message, ObjectInstruction): diff --git a/src/scaler/worker/worker.py b/src/scaler/worker/worker.py index 52300567d..2c79bbc6c 100644 --- a/src/scaler/worker/worker.py +++ b/src/scaler/worker/worker.py @@ -10,16 +10,21 @@ import zmq.asyncio from scaler.config.defaults import PROFILING_INTERVAL_SECONDS +from scaler.config.types.network_backend import NetworkBackend from scaler.config.types.object_storage_server import ObjectStorageAddressConfig from scaler.config.types.zmq import ZMQConfig, ZMQType from scaler.io.async_binder import ZMQAsyncBinder -from scaler.io.async_connector import ZMQAsyncConnector from scaler.io.mixins import AsyncBinder, AsyncConnector, AsyncObjectStorageConnector -from scaler.io.utility import create_async_object_storage_connector +from scaler.io.utility import ( + create_async_connector, + create_async_object_storage_connector, + get_scaler_network_backend_from_env, +) from scaler.io.ymq import ymq from scaler.protocol.python.message import ( ClientDisconnect, DisconnectRequest, + DisconnectResponse, ObjectInstruction, ProcessorInitialized, Task, @@ -30,7 +35,7 @@ ) from scaler.protocol.python.mixins import Message from scaler.utility.event_loop import create_async_loop_routine, register_event_loop -from scaler.utility.exceptions import ClientShutdownException +from scaler.utility.exceptions import ClientShutdownException, ObjectStorageException from scaler.utility.identifiers import ProcessorID, WorkerID from scaler.utility.logging.utility 
import setup_logger from scaler.worker.agent.heartbeat_manager import VanillaHeartbeatManager @@ -109,8 +114,9 @@ def __initialize(self): register_event_loop(self._event_loop) self._context = zmq.asyncio.Context() - self._connector_external = ZMQAsyncConnector( - context=self._context, + + self._connector_external = create_async_connector( + self._context, name=self.name, socket_type=zmq.DEALER, address=self._address, @@ -193,6 +199,11 @@ async def __on_receive_external(self, message: Message): logging.error(f"Worker received invalid ClientDisconnect type, ignoring {message=}") return + if isinstance(message, DisconnectResponse): + logging.error("Worker initiated DisconnectRequest got replied") + self._task.cancel() + return + raise TypeError(f"Unknown {message=}") async def __on_receive_internal(self, processor_id_bytes: bytes, message: Message): @@ -235,6 +246,9 @@ async def __get_loops(self): except asyncio.CancelledError: pass + except ObjectStorageException: + pass + # TODO: Should the object storage connector catch this error? 
except ymq.YMQException as e: if e.code == ymq.ErrorCode.ConnectorSocketClosedByRemoteEnd: @@ -246,11 +260,13 @@ async def __get_loops(self): except Exception as e: logging.exception(f"{self.identity!r}: failed with unhandled exception:\n{e}") - await self._connector_external.send(DisconnectRequest.new_msg(self.identity)) + if get_scaler_network_backend_from_env() == NetworkBackend.tcp_zmq: + await self._connector_external.send(DisconnectRequest.new_msg(self.identity)) self._connector_external.destroy() self._processor_manager.destroy("quit") self._binder_internal.destroy() + self._connector_storage.destroy() os.remove(self._address_path_internal) logging.info(f"{self.identity!r}: quit") @@ -259,7 +275,14 @@ def __run_forever(self): self._loop.run_until_complete(self._task) def __register_signal(self): - self._loop.add_signal_handler(signal.SIGINT, self.__destroy) + backend = get_scaler_network_backend_from_env() + if backend == NetworkBackend.tcp_zmq: + self._loop.add_signal_handler(signal.SIGINT, self.__destroy) + elif backend == NetworkBackend.ymq: + self._loop.add_signal_handler(signal.SIGINT, lambda: asyncio.ensure_future(self.__graceful_shutdown())) + + async def __graceful_shutdown(self): + await self._connector_external.send(DisconnectRequest.new_msg(self.identity)) def __destroy(self): self._task.cancel()