Skip to content

Commit 07dfe74

Browse files
committed
fixed multi-context calls
1 parent 87f8797 commit 07dfe74

File tree

2 files changed

+30
-21
lines changed

2 files changed

+30
-21
lines changed

libs/infinity_emb/infinity_emb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from infinity_emb.log_handler import logger # noqa: E402
1313
from infinity_emb.sync_engine import SyncEngineArray # noqa: E402
1414

15-
#__version__: str = importlib.metadata.version("infinity_emb")
15+
__version__: str = importlib.metadata.version("infinity_emb")
1616

1717
__all__ = [
1818
"__version__",

libs/infinity_emb/infinity_emb/engine.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(
5252
self._engine_args = EngineArgs(**kwargs)
5353

5454
self.running = False
55-
self._running_mutex: Lock = Lock()
55+
self._running_lock: Optional[Lock] = None
56+
self._running_counter: int = 0
5657
self._model_replicas, self._min_inference_t, self._max_inference_t = select_model(
5758
self._engine_args
5859
)
@@ -81,27 +82,35 @@ def __str__(self) -> str:
8182

8283
async def astart(self):
8384
"""startup engine"""
84-
await self._running_mutex.acquire()
85-
if not self.running:
86-
self.running = True
87-
self._batch_handler = BatchHandler(
88-
max_batch_size=self._engine_args.batch_size,
89-
model_replicas=self._model_replicas,
90-
# batch_delay=self._min_inference_t / 2,
91-
vector_disk_cache_path=self._engine_args.vector_disk_cache_path,
92-
verbose=logger.level <= 10,
93-
lengths_via_tokenize=self._engine_args.lengths_via_tokenize,
94-
)
95-
await self._batch_handler.spawn()
96-
97-
async def astop(self):
85+
if self._running_lock is None:
86+
self._running_lock = Lock()
87+
async with self._running_lock:
88+
# Counting the number of launches (when using multiple context managers asynchronously)
89+
self._running_counter += 1
90+
if not self.running:
91+
self.running = True
92+
self._batch_handler = BatchHandler(
93+
max_batch_size=self._engine_args.batch_size,
94+
model_replicas=self._model_replicas,
95+
# batch_delay=self._min_inference_t / 2,
96+
vector_disk_cache_path=self._engine_args.vector_disk_cache_path,
97+
verbose=logger.level <= 10,
98+
lengths_via_tokenize=self._engine_args.lengths_via_tokenize,
99+
)
100+
await self._batch_handler.spawn()
101+
102+
async def astop(self, *, force: bool = False):
98103
"""stop engine"""
99-
if not self._running_mutex.locked():
104+
if self._running_lock is None:
100105
return
101-
if self.running:
102-
self.running = False
103-
await self._batch_handler.shutdown()
104-
self._running_mutex.release()
106+
async with self._running_lock:
107+
if force:
108+
self._running_counter = 0
109+
if self._running_counter > 0:
110+
self._running_counter -= 1
111+
if self.running and self._running_counter == 0:
112+
self.running = False
113+
await self._batch_handler.shutdown()
105114

106115
async def __aenter__(self):
107116
await self.astart()

0 commit comments

Comments
 (0)