Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/scaler/scheduler/controllers/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def acquire_worker(self, task: Task) -> Optional[WorkerID]:
raise NotImplementedError()

@abc.abstractmethod
async def on_task_cancel(self, task_cancel: TaskCancel) -> bytes:
async def on_task_cancel(self, task_cancel: TaskCancel) -> WorkerID:
raise NotImplementedError()

@abc.abstractmethod
Expand Down
5 changes: 3 additions & 2 deletions src/scaler/scheduler/controllers/task_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ async def __state_failed_worker_died(

async def __send_task_cancel_to_worker(self, task_cancel: TaskCancel):
worker = await self._worker_controller.on_task_cancel(task_cancel)
if not worker:
assert isinstance(worker, WorkerID)
if not worker.is_valid():
logging.error(f"{task_cancel.task_id!r}: cannot find task in worker to cancel")
await self.__routing(
task_cancel.task_id,
Expand All @@ -300,7 +301,7 @@ async def __send_task_cancel_to_worker(self, task_cancel: TaskCancel):
)
return

await self._binder.send(worker, TaskCancel.new_msg(task_cancel.task_id))
await self._binder.send(worker, task_cancel)
await self.__send_monitor(task_cancel.task_id, b"")

async def __send_task_result_to_client(self, task_result: TaskResult):
Expand Down
5 changes: 2 additions & 3 deletions src/scaler/scheduler/controllers/worker_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ def register(self, binder: AsyncBinder, binder_monitor: AsyncConnector, task_con
def acquire_worker(self, task: Task) -> WorkerID:
return self._scaler_policy.assign_task(task)

async def on_task_cancel(self, task_cancel: TaskCancel):
async def on_task_cancel(self, task_cancel: TaskCancel) -> WorkerID:
worker = self._scaler_policy.remove_task(task_cancel.task_id)
if not worker.is_valid():
logging.error(f"cannot find task_id={task_cancel.task_id.hex()} in task workers")
return

await self._binder.send(worker, task_cancel)
return worker

async def on_task_done(self, task_id: TaskID) -> WorkerID:
worker = self._scaler_policy.remove_task(task_id)
Expand Down
2 changes: 1 addition & 1 deletion src/scaler/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.17.0
1.17.1
Loading