
Commit 839580c

Michael Harrison committed
added lock/dict to prevent multiple deployments of same model_name
1 parent 574b72d commit 839580c
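
This commit applies the double-checked locking pattern: the handler cache is consulted once without the lock (the cheap common case) and re-checked after acquiring the lock, so two threads racing to construct a model with the same model_name cannot both trigger a deployment. A minimal, self-contained sketch of the pattern, independent of this repository (the names _handlers, _handlers_lock, and _deploy are illustrative):

    import threading

    _handlers: dict[str, object] = {}
    _handlers_lock = threading.Lock()

    def _deploy(key: str) -> object:
        # Stand-in for expensive one-time work, e.g. starting a server.
        return object()

    def get_handler(key: str) -> object:
        # Fast path: no locking once the handler exists.
        if key not in _handlers:
            with _handlers_lock:
                # Re-check under the lock: another thread may have
                # created the entry while this one waited.
                if key not in _handlers:
                    _handlers[key] = _deploy(key)
        return _handlers[key]

The re-check matters because the fast-path test and the insertion are not atomic as a pair: without it, two threads that both miss the cache would each run _deploy.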

File tree: 1 file changed (+32 −14 lines)


eureka_ml_insights/models/models.py

@@ -4,6 +4,7 @@
 import logging
 import random
 import requests
+import threading
 import time
 import urllib.request
 from abc import ABC, abstractmethod
@@ -1343,10 +1344,19 @@ def deploy_server(self, index: int, gpus_per_port: int, log_dir: str):
         with open(log_file, 'w') as log_writer:
             subprocess.run(command, shell=True, stdout=log_writer, stderr=log_writer)
 
+
+local_vllm_model_lock = threading.Lock()
+local_vllm_deployment_handlers: dict[str, _LocalVLLMDeploymentHandler] = {}
+
 
 @dataclass
 class LocalVLLMModel(Model, OpenAICommonRequestResponseMixIn):
-    """This class is used when you have multiple vLLM servers running locally."""
+    """This class is used for vLLM servers running locally.
+
+    If the servers are already deployed, specify the
+    model_name and the ports at which the servers are hosted.
+    Otherwise, instantiating will initiate a deployment with
+    any deployment parameters specified."""
 
     model_name: str = None
 
@@ -1374,19 +1384,27 @@ class LocalVLLMModel(Model, OpenAICommonRequestResponseMixIn):
     def __post_init__(self):
         if not self.model_name:
            raise ValueError("LocalVLLM model_name must be specified.")
-        self.handler = _LocalVLLMDeploymentHandler(
-            model_name=self.model_name,
-            num_servers=self.num_servers,
-            trust_remote_code=self.trust_remote_code,
-            pipeline_parallel_size=self.pipeline_parallel_size,
-            tensor_parallel_size=self.tensor_parallel_size,
-            dtype=self.dtype,
-            quantization=self.quantization,
-            seed=self.seed,
-            gpu_memory_utilization=self.gpu_memory_utilization,
-            cpu_offload_gb=self.cpu_offload_gb,
-            ports=self.ports,
-        )
+        self.handler = self._get_local_vllm_deployment_handler()
+
+    def _get_local_vllm_deployment_handler(self):
+        if self.model_name not in local_vllm_deployment_handlers:
+            with local_vllm_model_lock:
+                if self.model_name not in local_vllm_deployment_handlers:
+                    local_vllm_deployment_handlers[self.model_name] = _LocalVLLMDeploymentHandler(
+                        model_name=self.model_name,
+                        num_servers=self.num_servers,
+                        trust_remote_code=self.trust_remote_code,
+                        pipeline_parallel_size=self.pipeline_parallel_size,
+                        tensor_parallel_size=self.tensor_parallel_size,
+                        dtype=self.dtype,
+                        quantization=self.quantization,
+                        seed=self.seed,
+                        gpu_memory_utilization=self.gpu_memory_utilization,
+                        cpu_offload_gb=self.cpu_offload_gb,
+                        ports=self.ports,
+                    )
+        return local_vllm_deployment_handlers[self.model_name]
 
     def _generate(self, request):
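With this change, constructing several LocalVLLMModel instances for the same model_name reuses one cached _LocalVLLMDeploymentHandler instead of deploying the servers again. A usage sketch (the import path follows the file location above; the model name and ports value are illustrative, not taken from the repository):

    from eureka_ml_insights.models.models import LocalVLLMModel

    # Both instances resolve to the same cached handler, so the local
    # vLLM servers for this model are deployed at most once.
    m1 = LocalVLLMModel(model_name="my-org/my-model", ports=[8000])
    m2 = LocalVLLMModel(model_name="my-org/my-model", ports=[8000])
    assert m1.handler is m2.handler

Note that the cache is keyed by model_name alone, so a second instance with different deployment parameters (dtype, tensor_parallel_size, etc.) silently reuses the first deployment.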
