diff --git a/src/scaler/scheduler/controllers/mixins.py b/src/scaler/scheduler/controllers/mixins.py index b95ac1a29..a7fbf465f 100644 --- a/src/scaler/scheduler/controllers/mixins.py +++ b/src/scaler/scheduler/controllers/mixins.py @@ -156,7 +156,7 @@ def acquire_worker(self, task: Task) -> Optional[WorkerID]: raise NotImplementedError() @abc.abstractmethod - async def on_task_cancel(self, task_cancel: TaskCancel) -> bytes: + async def on_task_cancel(self, task_cancel: TaskCancel) -> WorkerID: raise NotImplementedError() @abc.abstractmethod diff --git a/src/scaler/scheduler/controllers/task_controller.py b/src/scaler/scheduler/controllers/task_controller.py index 2245fbdb2..830c8913f 100644 --- a/src/scaler/scheduler/controllers/task_controller.py +++ b/src/scaler/scheduler/controllers/task_controller.py @@ -289,7 +289,8 @@ async def __state_failed_worker_died( async def __send_task_cancel_to_worker(self, task_cancel: TaskCancel): worker = await self._worker_controller.on_task_cancel(task_cancel) - if not worker: + assert isinstance(worker, WorkerID) + if not worker.is_valid(): logging.error(f"{task_cancel.task_id!r}: cannot find task in worker to cancel") await self.__routing( task_cancel.task_id, @@ -300,7 +301,7 @@ async def __send_task_cancel_to_worker(self, task_cancel: TaskCancel): ) return - await self._binder.send(worker, TaskCancel.new_msg(task_cancel.task_id)) + await self._binder.send(worker, task_cancel) await self.__send_monitor(task_cancel.task_id, b"") async def __send_task_result_to_client(self, task_result: TaskResult): diff --git a/src/scaler/scheduler/controllers/worker_controller.py b/src/scaler/scheduler/controllers/worker_controller.py index 33d1d3cf7..3fbb67f63 100644 --- a/src/scaler/scheduler/controllers/worker_controller.py +++ b/src/scaler/scheduler/controllers/worker_controller.py @@ -43,13 +43,12 @@ def register(self, binder: AsyncBinder, binder_monitor: AsyncConnector, task_con def acquire_worker(self, task: Task) -> WorkerID: return self._scaler_policy.assign_task(task) - async def on_task_cancel(self, task_cancel: TaskCancel): + async def on_task_cancel(self, task_cancel: TaskCancel) -> WorkerID: worker = self._scaler_policy.remove_task(task_cancel.task_id) if not worker.is_valid(): logging.error(f"cannot find task_id={task_cancel.task_id.hex()} in task workers") - return - await self._binder.send(worker, task_cancel) + return worker async def on_task_done(self, task_id: TaskID) -> WorkerID: worker = self._scaler_policy.remove_task(task_id) diff --git a/src/scaler/version.txt b/src/scaler/version.txt index 092afa15d..511a76e6f 100644 --- a/src/scaler/version.txt +++ b/src/scaler/version.txt @@ -1 +1 @@ -1.17.0 +1.17.1