@@ -52,7 +52,8 @@ def __init__(
5252 self ._engine_args = EngineArgs (** kwargs )
5353
5454 self .running = False
55- self ._running_mutex : Lock = Lock ()
55+ self ._running_lock : Optional [Lock ] = None
56+ self ._running_counter : int = 0
5657 self ._model_replicas , self ._min_inference_t , self ._max_inference_t = select_model (
5758 self ._engine_args
5859 )
@@ -81,27 +82,35 @@ def __str__(self) -> str:
8182
8283 async def astart (self ):
8384 """startup engine"""
84- await self ._running_mutex .acquire ()
85- if not self .running :
86- self .running = True
87- self ._batch_handler = BatchHandler (
88- max_batch_size = self ._engine_args .batch_size ,
89- model_replicas = self ._model_replicas ,
90- # batch_delay=self._min_inference_t / 2,
91- vector_disk_cache_path = self ._engine_args .vector_disk_cache_path ,
92- verbose = logger .level <= 10 ,
93- lengths_via_tokenize = self ._engine_args .lengths_via_tokenize ,
94- )
95- await self ._batch_handler .spawn ()
96-
97- async def astop (self ):
85+ if self ._running_lock is None :
86+ self ._running_lock = Lock ()
87+ async with self ._running_lock :
88+ # Counting the number of launches (when using multiple context managers asynchronously)
89+ self ._running_counter += 1
90+ if not self .running :
91+ self .running = True
92+ self ._batch_handler = BatchHandler (
93+ max_batch_size = self ._engine_args .batch_size ,
94+ model_replicas = self ._model_replicas ,
95+ # batch_delay=self._min_inference_t / 2,
96+ vector_disk_cache_path = self ._engine_args .vector_disk_cache_path ,
97+ verbose = logger .level <= 10 ,
98+ lengths_via_tokenize = self ._engine_args .lengths_via_tokenize ,
99+ )
100+ await self ._batch_handler .spawn ()
101+
102+ async def astop (self , * , force : bool = False ):
98103 """stop engine"""
99- if not self ._running_mutex . locked () :
104+ if self ._running_lock is None :
100105 return
101- if self .running :
102- self .running = False
103- await self ._batch_handler .shutdown ()
104- self ._running_mutex .release ()
106+ async with self ._running_lock :
107+ if force :
108+ self ._running_counter = 0
109+ if self ._running_counter > 0 :
110+ self ._running_counter -= 1
111+ if self .running and self ._running_counter == 0 :
112+ self .running = False
113+ await self ._batch_handler .shutdown ()
105114
106115 async def __aenter__ (self ):
107116 await self .astart ()
0 commit comments