1717import threading
1818from concurrent .futures import ThreadPoolExecutor
1919from datetime import timedelta
20- from typing import Dict , Optional
20+ from typing import Dict , Optional , Set
2121
2222from aws_advanced_python_wrapper .allowed_and_blocked_hosts import \
2323 AllowedAndBlockedHosts
@@ -37,6 +37,10 @@ def __init__(self) -> None:
3737 self ._storage_service : Optional [StorageService ] = None
3838 self ._monitor_service : Optional [MonitorService ] = None
3939 self ._thread_pools : Dict [str , ThreadPoolExecutor ] = {}
40+ # Some service pools must be drained BEFORE the monitor service is released and connections are closed.
41+ # This prevents worker threads, such as the topology util threads, from continuing to use connections
42+ # after they are closed, causing segfaults.
43+ self ._drain_first_pools : Set [str ] = set ()
4044 self ._lock = threading .Lock ()
4145
4246 def _ensure_initialized (self ) -> None :
@@ -63,19 +67,24 @@ def monitor_service(self) -> MonitorService:
6367 self ._ensure_initialized ()
6468 return self ._monitor_service # type: ignore
6569
def get_thread_pool(self, name: str, max_workers: Optional[int] = None, drain_first: bool = False) -> ThreadPoolExecutor:
    """Return the executor registered under ``name``, creating it on first use.

    :param name: registry key for the pool; also used as the thread name
        prefix when a new executor is created.
    :param max_workers: worker limit passed to ``ThreadPoolExecutor`` when the
        pool is created. It is ignored when a pool with this name already exists.
    :param drain_first: when true, record ``name`` in ``_drain_first_pools`` so
        the pool is treated as one that must be drained early on shutdown.
    :return: the executor stored under ``name``.
    """
    existing = self._thread_pools.get(name)
    if existing is None:
        with self._lock:
            # Check again while holding the lock: the pool may have been
            # registered between the unlocked lookup above and this point.
            if name not in self._thread_pools:
                self._thread_pools[name] = ThreadPoolExecutor(
                    max_workers=max_workers,
                    thread_name_prefix=name,
                )
            if drain_first:
                self._drain_first_pools.add(name)
            return self._thread_pools[name]

    # Pool already registered: only the drain-first flag may need recording.
    if drain_first:
        self._drain_first_pools.add(name)
    return existing
7583
7684 def release_thread_pool (self , name : str , wait : bool = True ) -> bool :
7785 with self ._lock :
7886 pool = self ._thread_pools .pop (name , None )
87+ self ._drain_first_pools .discard (name )
7988 if pool is not None :
8089 try :
8190 pool .shutdown (wait = wait )
@@ -85,6 +94,18 @@ def release_thread_pool(self, name: str, wait: bool = True) -> bool:
8594 return False
8695
8796 def release_resources (self ) -> None :
97+ # Some thread pools need to be drained before the monitor service is shut down.
98+ # This prevents segfaults when monitor services shut down and close all the active monitoring connections.
99+ with self ._lock :
100+ drain_names = list (self ._drain_first_pools )
101+ for name in drain_names :
102+ pool = self ._thread_pools .get (name )
103+ if pool is not None :
104+ try :
105+ pool .shutdown (wait = True )
106+ except Exception as e :
107+ logger .debug ("CoreServices.ErrorShuttingDownPool" , name , e )
108+
88109 if self ._monitor_service is not None :
89110 try :
90111 self ._monitor_service .release_resources ()
@@ -114,6 +135,7 @@ def release_resources(self) -> None:
114135 except Exception as e :
115136 logger .debug ("CoreServices.ErrorShuttingDownPool" , name , e )
116137 self ._thread_pools .clear ()
138+ self ._drain_first_pools .clear ()
117139
118140
119141_instance = _ServicesContainer ()
@@ -132,8 +154,8 @@ def get_monitor_service() -> MonitorService:
132154 return _instance .monitor_service
133155
134156
def get_thread_pool(name: str, max_workers: Optional[int] = None, drain_first: bool = False) -> ThreadPoolExecutor:
    """Fetch (or lazily create) the named thread pool from the shared container.

    Thin module-level wrapper that forwards every argument unchanged to the
    ``get_thread_pool`` method of the module's ``_instance`` container.

    :param name: registry key of the pool.
    :param max_workers: worker limit forwarded to the container.
    :param drain_first: forwarded flag marking the pool as drain-first.
    :return: the executor returned by the container.
    """
    return _instance.get_thread_pool(
        name,
        max_workers=max_workers,
        drain_first=drain_first,
    )
137159
138160
139161def release_thread_pool (name : str , wait : bool = True ) -> bool :
0 commit comments