11import os
2+ import signal
23import tempfile
34import time
45import datetime
1415from geokube .core .datacube import DataCube
1516from geokube .core .dataset import Dataset
1617from geokube .core .field import Field
18+ from pika .exceptions import ChannelClosed
1719
1820from datastore .datastore import Datastore
1921from workflow import Workflow
@@ -277,6 +279,8 @@ def __init__(self, broker, store_path, dask_cluster_opts):
277279 self ._channel = broker_conn .channel ()
278280 self ._db = DBManager ()
279281 self .dask_cluster_opts = dask_cluster_opts
282+ self .to_exit = False
283+ self .processing = False
280284
281285 def create_dask_cluster (self , dask_cluster_opts : dict = None ):
282286 if dask_cluster_opts is None :
@@ -307,21 +311,21 @@ def create_dask_cluster(self, dask_cluster_opts: dict = None):
307311
308312 def maybe_restart_cluster (self , status : RequestStatus ):
# Inspect the request outcome and the Dask cluster state, then close,
# restart, or recreate the cluster as needed. Returns nothing.
# NOTE(review): this listing is a diff with indentation stripped; a line
# marked `-` is the old version of the `+` line that follows it.
# Timeout: log, close the current cluster, and build a new one.
# NOTE(review): unlike the final branch, this one does not consult
# `self.to_exit` -- confirm that is intended.
309313 if status is RequestStatus .TIMEOUT :
310- self ._LOG .info ("recreating the cluster due to timeout" )
314+ self ._LOG .info ("recreating the cluster due to timeout" , extra = { "track_id" : "N/A" } )
311315 self ._dask_client .cluster .close ()
312316 self .create_dask_cluster ()
# Failed cluster: run `self._nanny.restart()` to completion via asyncio.run.
# Any Exception from the restart is logged; the "closing the cluster" log
# and the close() call presumably belong to the except branch -- the
# nesting cannot be confirmed from this view, verify against the file.
313317 if self ._dask_client .cluster .status is Status .failed :
314- self ._LOG .info ("attempt to restart the cluster..." )
318+ self ._LOG .info ("attempt to restart the cluster..." , extra = { "track_id" : "N/A" } )
315319 try :
316320 asyncio .run (self ._nanny .restart ())
317321 except Exception as err :
318322 self ._LOG .error (
319- "couldn't restart the cluster due to an error: %s" , err
323+ "couldn't restart the cluster due to an error: %s" , err , extra = { "track_id" : "N/A" }
320324 )
321- self ._LOG .info ("closing the cluster" )
325+ self ._LOG .info ("closing the cluster" , extra = { "track_id" : "N/A" } )
322326 self ._dask_client .cluster .close ()
# Closed cluster: recreate it. The `+` version skips the recreation when
# `self.to_exit` is set.
323- if self ._dask_client .cluster .status is Status .closed :
324- self ._LOG .info ("recreating the cluster" )
327+ if self ._dask_client .cluster .status is Status .closed and not self . to_exit :
328+ self ._LOG .info ("recreating the cluster" , extra = { "track_id" : "N/A" } )
325329 self .create_dask_cluster ()
326330
327331 def ack_message (self , channel , delivery_tag ):
@@ -330,9 +334,13 @@ def ack_message(self, channel, delivery_tag):
330334 """
331335 if channel .is_open :
332336 channel .basic_ack (delivery_tag )
337+ self .processing = False
338+ if self .to_exit :
339+ channel .stop_consuming ()
333340 else :
334341 self ._LOG .info (
335- "cannot acknowledge the message. channel is closed!"
342+ "cannot acknowledge the message. channel is closed!" ,
343+ extra = {"track_id" : "N/A" },
336344 )
337345 pass
338346
@@ -391,6 +399,7 @@ def retry_until_timeout(
391399 return location_path , status , fail_reason
392400
393401 def handle_message (self , connection , channel , delivery_tag , body ):
402+ self .processing = True
394403 message : Message = Message (body )
395404 self ._LOG .debug (
396405 "executing query: `%s`" ,
@@ -468,8 +477,21 @@ def subscribe(self, etype):
468477 )
469478
def listen(self):
    """Consume messages from the broker channel until shutdown is requested.

    ``self._channel.start_consuming()`` is re-entered for as long as
    ``self.to_exit`` is false. If the channel is closed underneath the
    consumer, the Dask client is shut down and the process terminates
    with exit status 0.

    Raises:
        SystemExit: with code 0 after a ``ChannelClosed`` error.
    """
    log_extra = {"track_id": "N/A"}
    while not self.to_exit:
        try:
            self._channel.start_consuming()
        except ChannelClosed:
            self._LOG.error("Channel closed exiting...", extra=log_extra)
            self._LOG.info("Shutting down Dask...", extra=log_extra)
            self._dask_client.shutdown()
            self._LOG.info("Exiting...", extra=log_extra)
            # The bare ``exit`` builtin is injected by the ``site`` module
            # for interactive use and can be missing (``python -S``, frozen
            # builds). Raising SystemExit has the same effect and is
            # always available.
            raise SystemExit(0)
def stop_listening(self, signo, frame):
    """Signal handler that requests shutdown of the consumer.

    Logs the received signal number and sets ``self.to_exit``. Consuming
    is stopped right away only when ``self.processing`` is false;
    otherwise the channel is left untouched here.

    Args:
        signo: number of the received signal.
        frame: current stack frame passed by the ``signal`` module (unused).
    """
    self._LOG.info(f'received signal {signo} :', extra={"track_id": "N/A"})
    self.to_exit = True
    if self.processing:
        return
    self._channel.stop_consuming()
473495
474496 def get_size (self , location_path ):
475497 if location_path and os .path .exists (location_path ):
@@ -503,5 +525,9 @@ def get_size(self, location_path):
503525
504526 executor .subscribe (etype )
505527
528+ print ('registering signal handlers' )
529+ signal .signal (signal .SIGTERM , executor .stop_listening )
530+ signal .signal (signal .SIGINT , executor .stop_listening )
531+
506532 print ("waiting for requests ..." )
507533 executor .listen ()
0 commit comments