diff --git a/src/scaler/protocol/capnp/message.capnp b/src/scaler/protocol/capnp/message.capnp index cd6c61443..13794cf43 100644 --- a/src/scaler/protocol/capnp/message.capnp +++ b/src/scaler/protocol/capnp/message.capnp @@ -105,6 +105,10 @@ struct DisconnectResponse { worker @0 :Data; } +struct WorkerDisconnectNotification { + worker @0 :Data; +} + struct ClientDisconnect { disconnectType @0 :DisconnectType; @@ -214,5 +218,6 @@ struct Message { informationRequest @23 :InformationRequest; informationResponse @24 :InformationResponse; + workerDisconnectNotification @25 :WorkerDisconnectNotification; } } diff --git a/src/scaler/protocol/python/message.py b/src/scaler/protocol/python/message.py index 8539ffde4..80e905906 100644 --- a/src/scaler/protocol/python/message.py +++ b/src/scaler/protocol/python/message.py @@ -424,6 +424,19 @@ def new_msg(worker: WorkerID) -> "DisconnectResponse": return DisconnectResponse(_message.DisconnectResponse(worker=bytes(worker))) +class WorkerDisconnectNotification(Message): + def __init__(self, msg): + super().__init__(msg) + + @property + def worker(self) -> WorkerID: + return WorkerID(self._msg.worker) + + @staticmethod + def new_msg(worker: WorkerID) -> "WorkerDisconnectNotification": + return WorkerDisconnectNotification(_message.WorkerDisconnectNotification(worker=bytes(worker))) + + class ClientDisconnect(Message): class DisconnectType(enum.Enum): Disconnect = _message.ClientDisconnect.DisconnectType.disconnect @@ -735,6 +748,7 @@ def workers(self) -> Dict[WorkerID, WorkerHeartbeat]: "workerHeartbeatEcho": WorkerHeartbeatEcho, "disconnectRequest": DisconnectRequest, "disconnectResponse": DisconnectResponse, + "workerDisconnectNotification": WorkerDisconnectNotification, "stateClient": StateClient, "stateObject": StateObject, "stateBalanceAdvice": StateBalanceAdvice, diff --git a/src/scaler/scheduler/controllers/mixins.py b/src/scaler/scheduler/controllers/mixins.py index b95ac1a29..f962fb86e 100644 --- a/src/scaler/scheduler/controllers/mixins.py +++ b/src/scaler/scheduler/controllers/mixins.py @@ -13,6 +13,7 @@ TaskCancel, TaskCancelConfirm, TaskResult, + WorkerDisconnectNotification, WorkerHeartbeat, ) from scaler.utility.identifiers import ClientID, ObjectID, TaskID, WorkerID @@ -175,6 +176,10 @@ async def on_client_shutdown(self, client_id: ClientID): async def on_disconnect(self, worker_id: WorkerID, request: DisconnectRequest): raise NotImplementedError() + @abc.abstractmethod + async def on_disconnect_notification(self, worker_id: WorkerID, notification: WorkerDisconnectNotification): + raise NotImplementedError() + @abc.abstractmethod def has_available_worker(self) -> bool: raise NotImplementedError() diff --git a/src/scaler/scheduler/controllers/worker_controller.py b/src/scaler/scheduler/controllers/worker_controller.py index 33d1d3cf7..046383598 100644 --- a/src/scaler/scheduler/controllers/worker_controller.py +++ b/src/scaler/scheduler/controllers/worker_controller.py @@ -11,6 +11,7 @@ StateWorker, Task, TaskCancel, + WorkerDisconnectNotification, WorkerHeartbeat, WorkerHeartbeatEcho, ) @@ -80,6 +81,9 @@ async def on_disconnect(self, worker_id: WorkerID, request: DisconnectRequest): await self.__disconnect_worker(request.worker) await self._binder.send(worker_id, DisconnectResponse.new_msg(request.worker)) + async def on_disconnect_notification(self, worker_id: WorkerID, notification: WorkerDisconnectNotification): + await self.__disconnect_worker(notification.worker) + async def routine(self): await self.__clean_workers() diff --git a/src/scaler/scheduler/scheduler.py b/src/scaler/scheduler/scheduler.py index defd6bf41..2daf91c14 100644 --- a/src/scaler/scheduler/scheduler.py +++ b/src/scaler/scheduler/scheduler.py @@ -24,6 +24,7 @@ TaskCancelConfirm, TaskLog, TaskResult, + WorkerDisconnectNotification, WorkerHeartbeat, ) from scaler.protocol.python.mixins import Message @@ -205,6 +206,10 @@ async def on_receive_message(self, source: bytes, message: Message): await self._worker_controller.on_disconnect(WorkerID(source), message) return + if isinstance(message, WorkerDisconnectNotification): + await self._worker_controller.on_disconnect_notification(WorkerID(source), message) + return + # ===================================================================================== # object manager if isinstance(message, ObjectInstruction): diff --git a/src/scaler/worker/worker.py b/src/scaler/worker/worker.py index c96e8f087..cb424e1f8 100644 --- a/src/scaler/worker/worker.py +++ b/src/scaler/worker/worker.py @@ -10,20 +10,14 @@ import zmq.asyncio from scaler.config.defaults import PROFILING_INTERVAL_SECONDS -from scaler.config.types.network_backend import NetworkBackend from scaler.config.types.object_storage_server import ObjectStorageAddressConfig from scaler.config.types.zmq import ZMQConfig, ZMQType from scaler.io.async_binder import ZMQAsyncBinder from scaler.io.mixins import AsyncBinder, AsyncConnector, AsyncObjectStorageConnector -from scaler.io.utility import ( - create_async_connector, - create_async_object_storage_connector, - get_scaler_network_backend_from_env, -) +from scaler.io.utility import create_async_connector, create_async_object_storage_connector from scaler.io.ymq import ymq from scaler.protocol.python.message import ( ClientDisconnect, - DisconnectRequest, DisconnectResponse, ObjectInstruction, ProcessorInitialized, @@ -31,6 +25,7 @@ TaskCancel, TaskLog, TaskResult, + WorkerDisconnectNotification, WorkerHeartbeatEcho, ) from scaler.protocol.python.mixins import Message @@ -269,8 +264,16 @@ async def __get_loops(self): except Exception as e: logging.exception(f"{self.identity!r}: failed with unhandled exception:\n{e}") - if get_scaler_network_backend_from_env() == NetworkBackend.tcp_zmq: - await self._connector_external.send(DisconnectRequest.new_msg(self.identity)) + try: + await self._connector_external.send(WorkerDisconnectNotification.new_msg(self.identity)) + except ymq.YMQException as e: + + # this means that the scheduler shut down before we could send our notification + # we don't consider this to be an error + if e.code == ymq.ErrorCode.ConnectorSocketClosedByRemoteEnd: + pass + else: + raise self._connector_external.destroy() self._processor_manager.destroy("quit") @@ -281,16 +284,8 @@ async def __get_loops(self): logging.info(f"{self.identity!r}: quit") def __register_signal(self): - backend = get_scaler_network_backend_from_env() - if backend == NetworkBackend.tcp_zmq: - self._loop.add_signal_handler(signal.SIGINT, self.__destroy) - self._loop.add_signal_handler(signal.SIGTERM, self.__destroy) - elif backend == NetworkBackend.ymq: - self._loop.add_signal_handler(signal.SIGINT, lambda: asyncio.ensure_future(self.__graceful_shutdown())) - self._loop.add_signal_handler(signal.SIGTERM, lambda: asyncio.ensure_future(self.__graceful_shutdown())) - - async def __graceful_shutdown(self): - await self._connector_external.send(DisconnectRequest.new_msg(self.identity)) + self._loop.add_signal_handler(signal.SIGINT, self.__destroy) + self._loop.add_signal_handler(signal.SIGTERM, self.__destroy) def __destroy(self): self._task.cancel()