|
16 | 16 | from pprint import pformat
|
17 | 17 | import time
|
18 | 18 | from dask.distributed import wait, default_client
|
| 19 | +import logging |
| 20 | +from distributed.diagnostics.plugin import WorkerPlugin, SchedulerPlugin |
| 21 | +from distributed.scheduler import Scheduler |
19 | 22 | from dask import persist
|
20 | 23 | from dask.distributed import Client
|
21 | 24 | from dask.base import is_dask_collection
|
|
27 | 30 | import numpy as np
|
28 | 31 |
|
29 | 32 |
|
class GracefullyRetireWorkers(WorkerPlugin):
    """Retire workers after a task error once a worker leaves the cluster.

    ``transition`` watches task state changes and sets ``state`` to ``-1``
    as soon as a task finishes in ``'error'`` or ``'erred'``.  When
    ``remove_worker`` is later called while ``state`` is ``-1``, the
    departure is logged and, the first time only, the scheduler is asked to
    retire workers.

    NOTE(review): ``remove_worker`` receives a ``scheduler`` argument, which
    looks like a scheduler-side plugin hook, while this class derives from
    ``WorkerPlugin``.  Confirm that the hook is actually invoked (and sees
    the same ``state``) for a plugin registered this way.

    Parameters
    ----------
    logger
        Object exposing ``debug`` and ``critical`` logging methods (for
        example the ``logging`` module or a ``logging.Logger``).
    """

    def __init__(self, logger):
        self.logger = logger
        # Number of times worker retirement has been triggered by this plugin.
        self.count = 0
        # Not read or written anywhere else in this class.
        self.key = None
        # 1 until an errored task transition is seen, then -1.
        self.state = 1

    async def remove_worker(self, scheduler, worker: str, *, stimulus_id, **kwargs):
        """Log a departing worker and trigger retirement after an error.

        Parameters
        ----------
        scheduler
            Object whose awaitable ``retire_workers()`` is called.
        worker : str
            Identifier of the worker that left; only used in log output.
        stimulus_id
            Accepted for hook compatibility; unused here.
        **kwargs
            Accepted for hook compatibility; unused here.
        """
        self.logger.debug(
            "a worker is leaving the cluster and state = %s count = %s",
            self.state,
            self.count,
        )
        if self.state != -1:
            # No errored task has been observed: nothing to do.
            return

        self.logger.critical(" Worker %s left the cluster", worker)
        if self.count != 0:
            # Retirement was already triggered by an earlier call.
            return

        self.logger.critical(" An error occured: retiring all workers")
        # Bumped before the await so that any remove_worker call arriving
        # while retirement is in flight takes the early return above.
        self.count += 1
        await scheduler.retire_workers()

    def setup(self, worker):
        """Keep a reference to the worker this plugin is attached to."""
        self.worker = worker

    def transition(self, key, start, finish, *args, **kwargs):
        """Record that a task ended in an error state.

        Parameters
        ----------
        key
            Task key; unused here.
        start
            State the task is leaving; unused here.
        finish
            State the task is entering; ``'error'`` or ``'erred'`` sets
            ``state`` to ``-1``.
        """
        if finish in ("error", "erred"):
            self.logger.debug("transition = %s", finish)
            self.state = -1
| 57 | + |
| 58 | + |
30 | 59 | def start_dask_client(
|
31 | 60 | protocol=None,
|
32 | 61 | rmm_async=False,
|
@@ -157,6 +186,9 @@ def start_dask_client(
|
157 | 186 | num_workers = len(dask_worker_devices.split(","))
|
158 | 187 |
|
159 | 188 | client.wait_for_workers(num_workers)
|
| 189 | + |
| 190 | + s_plugin = GracefullyRetireWorkers(logging) |
| 191 | + client.register_plugin(s_plugin) |
160 | 192 | # Add a reference to tempdir_object to the client to prevent it from
|
161 | 193 | # being deleted when this function returns. This will be deleted in
|
162 | 194 | # stop_dask_client()
|
|
0 commit comments