Skip to content

Commit f2100bd

Browse files
committed
adding graceful termination to executor
1 parent bf51577 commit f2100bd

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

executor/app/main.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import signal
23
import tempfile
34
import time
45
import datetime
@@ -14,6 +15,7 @@
1415
from geokube.core.datacube import DataCube
1516
from geokube.core.dataset import Dataset
1617
from geokube.core.field import Field
18+
from pika.exceptions import ChannelClosed
1719

1820
from datastore.datastore import Datastore
1921
from workflow import Workflow
@@ -277,6 +279,8 @@ def __init__(self, broker, store_path, dask_cluster_opts):
277279
self._channel = broker_conn.channel()
278280
self._db = DBManager()
279281
self.dask_cluster_opts = dask_cluster_opts
282+
self.to_exit = False
283+
self.processing = False
280284

281285
def create_dask_cluster(self, dask_cluster_opts: dict = None):
282286
if dask_cluster_opts is None:
@@ -307,21 +311,21 @@ def create_dask_cluster(self, dask_cluster_opts: dict = None):
307311

308312
def maybe_restart_cluster(self, status: RequestStatus):
309313
if status is RequestStatus.TIMEOUT:
310-
self._LOG.info("recreating the cluster due to timeout")
314+
self._LOG.info("recreating the cluster due to timeout", extra={"track_id": "N/A"})
311315
self._dask_client.cluster.close()
312316
self.create_dask_cluster()
313317
if self._dask_client.cluster.status is Status.failed:
314-
self._LOG.info("attempt to restart the cluster...")
318+
self._LOG.info("attempt to restart the cluster...", extra={"track_id": "N/A"})
315319
try:
316320
asyncio.run(self._nanny.restart())
317321
except Exception as err:
318322
self._LOG.error(
319-
"couldn't restart the cluster due to an error: %s", err
323+
"couldn't restart the cluster due to an error: %s", err, extra={"track_id": "N/A"}
320324
)
321-
self._LOG.info("closing the cluster")
325+
self._LOG.info("closing the cluster", extra={"track_id": "N/A"})
322326
self._dask_client.cluster.close()
323-
if self._dask_client.cluster.status is Status.closed:
324-
self._LOG.info("recreating the cluster")
327+
if self._dask_client.cluster.status is Status.closed and not self.to_exit:
328+
self._LOG.info("recreating the cluster", extra={"track_id": "N/A"})
325329
self.create_dask_cluster()
326330

327331
def ack_message(self, channel, delivery_tag):
@@ -330,9 +334,13 @@ def ack_message(self, channel, delivery_tag):
330334
"""
331335
if channel.is_open:
332336
channel.basic_ack(delivery_tag)
337+
self.processing = False
338+
if self.to_exit:
339+
channel.stop_consuming()
333340
else:
334341
self._LOG.info(
335-
"cannot acknowledge the message. channel is closed!"
342+
"cannot acknowledge the message. channel is closed!",
343+
extra={"track_id": "N/A"},
336344
)
337345
pass
338346

@@ -391,6 +399,7 @@ def retry_until_timeout(
391399
return location_path, status, fail_reason
392400

393401
def handle_message(self, connection, channel, delivery_tag, body):
402+
self.processing = True
394403
message: Message = Message(body)
395404
self._LOG.debug(
396405
"executing query: `%s`",
@@ -468,8 +477,21 @@ def subscribe(self, etype):
468477
)
469478

470479
def listen(self):
471-
while True:
472-
self._channel.start_consuming()
480+
while not self.to_exit:
481+
try:
482+
self._channel.start_consuming()
483+
except ChannelClosed as cc:
484+
self._LOG.error("Channel closed exiting...", extra={"track_id": "N/A"})
485+
self._LOG.info(f'Shutting down Dask...', extra={"track_id": "N/A"})
486+
self._dask_client.shutdown()
487+
self._LOG.info(f'Exiting...', extra={"track_id": "N/A"})
488+
exit(0)
489+
490+
def stop_listening(self, signo, frame):
491+
self._LOG.info(f'received signal {signo}:', extra={"track_id": "N/A"})
492+
self.to_exit = True
493+
if not self.processing:
494+
self._channel.stop_consuming()
473495

474496
def get_size(self, location_path):
475497
if location_path and os.path.exists(location_path):
@@ -503,5 +525,9 @@ def get_size(self, location_path):
503525

504526
executor.subscribe(etype)
505527

528+
print('registering signal handlers')
529+
signal.signal(signal.SIGTERM, executor.stop_listening)
530+
signal.signal(signal.SIGINT, executor.stop_listening)
531+
506532
print("waiting for requests ...")
507533
executor.listen()

0 commit comments

Comments
 (0)