diff --git a/deploy_vllm_and_run_eval.sh b/deploy_vllm_and_run_eval.sh
new file mode 100755
index 0000000..f633743
--- /dev/null
+++ b/deploy_vllm_and_run_eval.sh
@@ -0,0 +1,77 @@
+#!/bin/bash
+
+export PYTHONPATH="$(pwd):$PYTHONPATH"
+model_name="microsoft/phi-4"
+exp_config="IFEval_PIPELINE"
+current_datetime=$(date +"%Y-%m-%d-%H:%M:%S")
+log_dir="logs/deploy_vllm_and_run_eval/$current_datetime"
+mkdir -p "$log_dir"
+
+# vLLM args
+num_servers=4
+tensor_parallel_size=1
+pipeline_parallel_size=1
+base_port=8000
+gpus_per_port=$((tensor_parallel_size * pipeline_parallel_size))
+
+# Add any additional args accepted by vllm serve here
+VLLM_ARGS="\
+    --tensor-parallel-size=${tensor_parallel_size} \
+    --pipeline-parallel-size=${pipeline_parallel_size} \
+    --gpu-memory-utilization=0.9 \
+"
+
+# Start servers
+echo "Spinning up servers..."
+for (( i = 0; i < $num_servers; i++ )); do
+    port=$((base_port + i))
+    first_gpu=$((i * gpus_per_port))
+    last_gpu=$((first_gpu + gpus_per_port - 1))
+    devices=$(seq -s, $first_gpu $last_gpu)
+    CUDA_VISIBLE_DEVICES=${devices} vllm serve ${model_name} "$@" --port ${port} ${VLLM_ARGS} >> "$log_dir/${port}.log" 2>&1 &
+done
+
+# Wait for servers to come online
+while true; do
+
+    servers_online=0
+    for (( i = 0; i < $num_servers; i++ )); do
+        port=$((base_port + i))
+        url="http://0.0.0.0:${port}/health"
+        response=$(curl -s -o /dev/null -w "%{http_code}" "$url")
+
+        if [ "$response" -eq 200 ]; then
+            servers_online=$((servers_online + 1))
+        fi
+    done
+
+    if [ $servers_online -eq $num_servers ]; then
+        echo "All servers are online."
+        break
+    else
+        echo "Waiting for $((num_servers - servers_online)) more servers to come online..."
+    fi
+
+    sleep 10
+done
+
+# Call Eureka to initiate evals
+ports=$(seq -s ' ' $base_port $((base_port + num_servers - 1)))
+EUREKA_ARGS="\
+    --model_config=${model_name} \
+    --exp_config=${exp_config} \
+    --local_vllm \
+    --ports ${ports} \
+"
+echo "Starting evals..."
+python main.py ${EUREKA_ARGS} >> "$log_dir/out.log" 2>&1
+
+# Shut down servers
+echo "Shutting down vLLM servers..."
+for (( i = 0; i < $num_servers; i++ )); do
+    port=$((base_port + i))
+    logfile="$log_dir/${port}.log"
+    pid=$(grep "Started server process" "$logfile" | grep -o '[0-9]\+' | tail -n 1)  # uvicorn logs "Started server process [<PID>]"
+    echo "Shutting down server on port ${port} (PID ${pid})"
+    kill -INT "$pid"
+done
\ No newline at end of file
diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py
index 532ab94..1956feb 100644
--- a/eureka_ml_insights/configs/model_configs.py
+++ b/eureka_ml_insights/configs/model_configs.py
@@ -12,6 +12,7 @@
     LlamaServerlessAzureRestEndpointModel,
     LLaVAHuggingFaceModel,
     LLaVAModel,
+    LocalVLLMModel,
     Phi4HFModel,
     MistralServerlessAzureRestEndpointModel,
     DeepseekR1ServerlessAzureRestEndpointModel,
@@ -297,6 +298,27 @@
     },
 )
 
+# Local vLLM Models
+# Adapt these to your local deployments, or provide enough info for a new vLLM deployment.
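+# These configs can be selected by name on the command line, e.g.:
+#   python main.py --exp_config IFEval_PIPELINE --model_config PHI4_LOCAL_CONFIG --local_vllm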
+PHI4_LOCAL_CONFIG = ModelConfig(
+    LocalVLLMModel,
+    {
+        # this name must match the name/path used by the vLLM deployment
+        "model_name": "microsoft/phi-4",
+        # specify ports in case the model is already deployed
+        "ports": ["8002", "8003"],
+    },
+)
+QWQ32B_LOCAL_CONFIG = ModelConfig(
+    LocalVLLMModel,
+    {
+        # this name must match the name/path used by the vLLM deployment
+        "model_name": "Qwen/QwQ-32B",
+        # deployment args such as this one are passed through to the vllm serve command
+        "tensor_parallel_size": 2,
+    },
+)
+
 # DeepSeek R1 Endpoints on Azure
 DEEPSEEK_R1_CONFIG = ModelConfig(
     DeepseekR1ServerlessAzureRestEndpointModel,
@@ -311,4 +333,4 @@
         # the timeout parameter is passed to urllib.request.urlopen(request, timeout=self.timeout) in ServerlessAzureRestEndpointModel
         "timeout": 600,
     },
-)
\ No newline at end of file
+)
diff --git a/eureka_ml_insights/models/__init__.py b/eureka_ml_insights/models/__init__.py
index 2aa914c..931bc3d 100644
--- a/eureka_ml_insights/models/__init__.py
+++ b/eureka_ml_insights/models/__init__.py
@@ -11,6 +11,7 @@
     LlamaServerlessAzureRestEndpointModel,
     LLaVAHuggingFaceModel,
     LLaVAModel,
+    LocalVLLMModel,
     MistralServerlessAzureRestEndpointModel,
     DeepseekR1ServerlessAzureRestEndpointModel,
     Phi3HFModel,
@@ -38,6 +39,7 @@
     LlamaServerlessAzureRestEndpointModel,
     DeepseekR1ServerlessAzureRestEndpointModel,
     LLaVAModel,
+    LocalVLLMModel,
     RestEndpointModel,
     TestModel,
     VLLMModel,
diff --git a/eureka_ml_insights/models/models.py b/eureka_ml_insights/models/models.py
index 97d0321..43d5a6e 100644
--- a/eureka_ml_insights/models/models.py
+++ b/eureka_ml_insights/models/models.py
@@ -2,6 +2,9 @@
 
 import json
 import logging
+import random
+import requests
+import threading
 import time
 import urllib.request
 from abc import ABC, abstractmethod
@@ -130,6 +133,7 @@ def generate(self, query_text, *args, **kwargs):
         model_output = None
         is_valid = False
         response_time = None
+        n_output_tokens = None
 
         while attempts < self.num_retries:
             try:
@@ -138,6 +142,7 @@ def generate(self, query_text, *args, **kwargs):
                 response_dict.update(model_response)
                 model_output = model_response["model_output"]
                 response_time = model_response["response_time"]
+                n_output_tokens = model_response.get("n_output_tokens", None)
 
                 if self.chat_mode:
                     previous_messages = self.update_chat_history(query_text, model_output, *args, **kwargs)
@@ -157,7 +162,7 @@ def generate(self, query_text, *args, **kwargs):
                 "is_valid": is_valid,
                 "model_output": model_output,
                 "response_time": response_time,
-                "n_output_tokens": self.count_tokens(model_output, is_valid),
+                # Prefer the token count reported by the server; fall back to counting locally.
+                "n_output_tokens": n_output_tokens or self.count_tokens(model_output, is_valid),
             }
         )
         if self.chat_mode:
@@ -474,7 +479,10 @@ def get_response(self, request):
             "response_time": response_time,
         }
         if "usage" in openai_response:
-            response_dict.update({"usage": openai_response["usage"]})
+            usage = openai_response["usage"]
+            response_dict.update({"usage": usage})
+            if isinstance(usage, dict) and "completion_tokens" in usage:
+                response_dict.update({"n_output_tokens": usage["completion_tokens"]})
         return response_dict
 
 
@@ -1266,6 +1274,232 @@ def create_request(self, text_prompt, system_message=None):
             ]
         else:
             return [{"role": "user", "content": text_prompt}]
+
+
+class _LocalVLLMDeploymentHandler:
+    """Handles the deployment of, and connections to, local vLLM servers."""
+
+    # We chose not to use a dataclass here so that we keep the option of accepting
+    # arbitrary kwargs and passing them through to the vLLM deployment script.
+
+    # Used to store references to the servers' logs, since those contain the PIDs needed for shutdown.
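+    # This is deliberately a class attribute rather than an instance attribute, so
+    # that the shutdown_servers classmethod can find the logs of every server
+    # deployed during the run, across all handler instances.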
+    logs = []
+
+    def __init__(
+        self,
+        model_name: str = None,
+        num_servers: int = 1,
+        trust_remote_code: bool = False,
+        tensor_parallel_size: int = 1,
+        pipeline_parallel_size: int = 1,
+        dtype: str = "auto",
+        quantization: str = None,
+        seed: int = 0,
+        gpu_memory_utilization: float = 0.9,
+        cpu_offload_gb: float = 0,
+        ports: list = None,
+    ):
+        if not model_name:
+            raise ValueError("LocalVLLM model_name must be specified.")
+        self.model_name = model_name
+        self.num_servers = num_servers
+        self.trust_remote_code = trust_remote_code
+        self.tensor_parallel_size = tensor_parallel_size
+        self.pipeline_parallel_size = pipeline_parallel_size
+        self.dtype = dtype
+        self.quantization = quantization
+        self.seed = seed
+        self.gpu_memory_utilization = gpu_memory_utilization
+        self.cpu_offload_gb = cpu_offload_gb
+
+        self.ports = ports
+        self.session = requests.Session()
+        self.clients = self._get_clients()
+
+    def _get_clients(self):
+        """Get clients to access vLLM servers, by checking for running servers and deploying if necessary."""
+        from openai import OpenAI as OpenAIClient
+
+        # If the user passes ports, check whether servers are running at those ports and populate clients accordingly.
+        if self.ports:
+            healthy_server_urls = ["http://0.0.0.0:" + port + "/v1" for port in self.get_healthy_ports()]
+            if len(healthy_server_urls) > 0:
+                logging.info(f"Found {len(healthy_server_urls)} healthy servers.")
+                return [OpenAIClient(base_url=url, api_key="none") for url in healthy_server_urls]
+
+        # Even if the user doesn't pass ports, we can check whether there happen to be deployed
+        # servers at the default ports. There is no guarantee that they are hosting the correct model.
+        self.ports = [str(8000 + i) for i in range(self.num_servers)]
+        healthy_server_urls = ["http://0.0.0.0:" + port + "/v1" for port in self.get_healthy_ports()]
+        if len(healthy_server_urls) == self.num_servers:
+            logging.info(f"Found {len(healthy_server_urls)} healthy servers.")
+            return [OpenAIClient(base_url=url, api_key="none") for url in healthy_server_urls]
+
+        # If that didn't work, deploy new servers and wait for them to come online.
+        self.deploy_servers()
+        server_start_time = time.time()
+        while time.time() - server_start_time < 600:
+            time.sleep(10)
+            healthy_ports = self.get_healthy_ports()
+            if len(healthy_ports) == self.num_servers:
+                logging.info(f"All {self.num_servers} servers are online.")
+                healthy_server_urls = ["http://0.0.0.0:" + port + "/v1" for port in healthy_ports]
+                return [OpenAIClient(base_url=url, api_key="none") for url in healthy_server_urls]
+            else:
+                logging.info(f"Waiting for {self.num_servers - len(healthy_ports)} more servers to come online.")
+
+        # The timeout elapsed before all servers came online.
+        raise RuntimeError(f"Failed to start all {self.num_servers} servers within 600 seconds.")
+
+    def get_healthy_ports(self) -> list[str]:
+        """Return the ports among self.ports at which a server responds on its /health endpoint."""
+
+        healthy_ports = []
+        for port in self.ports:
+            try:
+                response = self.session.get("http://0.0.0.0:" + port + "/health")
+                if response.status_code == 200:
+                    healthy_ports.append(port)
+            except requests.exceptions.RequestException:
+                pass
+        return healthy_ports
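+
+    # The deployment methods below always assign ports 8000, 8001, ..., mirroring
+    # the defaults that _get_clients falls back to when no healthy servers are
+    # found at the user-provided ports.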
+    def deploy_servers(self):
+        """Deploy vLLM servers in background threads using the specified parameters."""
+
+        logging.info(f"No vLLM servers are running. Starting {self.num_servers} new servers at ports {self.ports}.")
+        import datetime
+        import os
+
+        gpus_per_port = self.pipeline_parallel_size * self.tensor_parallel_size
+        date = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f")
+        log_dir = os.path.join("logs", "local_vllm_deployment_logs", f"{date}")
+        os.makedirs(log_dir, exist_ok=True)
+
+        for index in range(self.num_servers):
+            port = 8000 + index
+            log_file = os.path.join(log_dir, f"{port}.log")
+            self.logs.append(log_file)
+            # Pass the loop variables as thread args rather than closing over them in a
+            # lambda, so that each thread sees the values from its own iteration.
+            background_thread = threading.Thread(
+                target=self.deploy_server, args=(index, gpus_per_port, log_file), daemon=True
+            )
+            background_thread.start()
+
+    def deploy_server(self, index: int, gpus_per_port: int, log_file: str):
+        """Deploy a single vLLM server on gpus_per_port GPUs, starting at GPU index * gpus_per_port."""
+
+        import subprocess
+
+        port = 8000 + index
+        first_gpu = index * gpus_per_port
+        last_gpu = first_gpu + gpus_per_port - 1
+        devices = ",".join(str(gpu_num) for gpu_num in range(first_gpu, last_gpu + 1))
+
+        command = [
+            "CUDA_VISIBLE_DEVICES=" + devices,
+            "vllm serve",
+            self.model_name,
+            "--port", str(port),
+            "--tensor-parallel-size", str(self.tensor_parallel_size),
+            "--pipeline-parallel-size", str(self.pipeline_parallel_size),
+            "--dtype", self.dtype,
+            "--seed", str(self.seed),
+            "--gpu-memory-utilization", str(self.gpu_memory_utilization),
+            "--cpu-offload-gb", str(self.cpu_offload_gb),
+        ]
+        if self.quantization:
+            command.append("--quantization")
+            command.append(self.quantization)
+        if self.trust_remote_code:
+            command.append("--trust-remote-code")
+        command = " ".join(command)
+        logging.info(f"Running command: {command}")
+        with open(log_file, "w") as log_writer:
+            subprocess.run(command, shell=True, stdout=log_writer, stderr=log_writer)
+
+    @classmethod
+    def shutdown_servers(cls):
+        """Shut down all vLLM servers deployed during this run."""
+
+        import os
+        import re
+        import signal
+
+        for log_file in cls.logs:
+            with open(log_file, "r") as f:
+                for line in f:
+                    # uvicorn logs a line of the form "Started server process [<PID>]".
+                    if "Started server process" in line:
+                        match = re.search(r"Started server process \[(\d+)\]", line)
+                        if match:
+                            pid = int(match.group(1))
+                            logging.info(f"Shutting down server with PID {pid}.")
+                            os.kill(pid, signal.SIGINT)
+                        break
+
+
+local_vllm_model_lock = threading.Lock()
+local_vllm_deployment_handlers: dict[str, _LocalVLLMDeploymentHandler] = {}
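+
+
+# Handlers are cached per model name so that concurrent LocalVLLMModel instances
+# sharing a model reuse one deployment; _get_local_vllm_deployment_handler below
+# guards creation with local_vllm_model_lock so each handler is created only once.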
+@dataclass
+class LocalVLLMModel(OpenAICommonRequestResponseMixIn, EndpointModel):
+    """This class is used for vLLM servers running locally.
+
+    In case the servers are already deployed, specify the model_name and the
+    ports at which the servers are hosted. Otherwise, instantiation will
+    initiate a deployment with any deployment parameters specified.
+    """
+
+    model_name: str = None
+
+    # Deployment parameters
+    num_servers: int = 1
+    trust_remote_code: bool = False
+    tensor_parallel_size: int = 1
+    pipeline_parallel_size: int = 1
+    dtype: str = "auto"
+    quantization: str = None
+    seed: int = 0
+    gpu_memory_utilization: float = 0.9
+    cpu_offload_gb: float = 0
+
+    # Deployment handler
+    ports: list = None
+    handler: _LocalVLLMDeploymentHandler = None
+
+    # Inference parameters
+    temperature: float = 0.01
+    top_p: float = 0.95
+    top_k: int = -1
+    max_tokens: int = 2000
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
+
+    def __post_init__(self):
+        if not self.model_name:
+            raise ValueError("LocalVLLM model_name must be specified.")
+        self.handler = self._get_local_vllm_deployment_handler()
+
+    @property
+    def client(self):
+        # Spread requests across the deployed servers.
+        return random.choice(self.handler.clients)
+
+    def _get_local_vllm_deployment_handler(self):
+        if self.model_name not in local_vllm_deployment_handlers:
+            with local_vllm_model_lock:
+                if self.model_name not in local_vllm_deployment_handlers:
+                    local_vllm_deployment_handlers[self.model_name] = _LocalVLLMDeploymentHandler(
+                        model_name=self.model_name,
+                        num_servers=self.num_servers,
+                        trust_remote_code=self.trust_remote_code,
+                        tensor_parallel_size=self.tensor_parallel_size,
+                        pipeline_parallel_size=self.pipeline_parallel_size,
+                        dtype=self.dtype,
+                        quantization=self.quantization,
+                        seed=self.seed,
+                        gpu_memory_utilization=self.gpu_memory_utilization,
+                        cpu_offload_gb=self.cpu_offload_gb,
+                        ports=self.ports,
+                    )
+
+        return local_vllm_deployment_handlers[self.model_name]
+
+    def handle_request_error(self, e):
+        return False
 
 
 @dataclass
diff --git a/main.py b/main.py
index 433dda3..fae95a1 100755
--- a/main.py
+++ b/main.py
@@ -21,6 +21,9 @@
     parser.add_argument(
        "--resume_from", type=str, help="The path to the inference_result.jsonl to resume from.", default=None
    )
+    parser.add_argument("--local_vllm", action="store_true", help="Deploy and/or use a local vLLM server for inference.")
+    parser.add_argument("--ports", type=str, nargs="*", help="Ports at which the vLLM servers are already hosted.", default=None)
+    parser.add_argument("--num_servers", type=int, help="Number of servers to deploy.", default=None)
 
     init_args = {}
 
     # catch any unknown arguments
@@ -38,7 +41,30 @@
         logging.info(f"Unknown arguments: {unknown_args} will be sent as is to the experiment config class.")
     experiment_config_class = args.exp_config
-    if args.model_config:
+
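+    # When --local_vllm is set, deployment parameters given on the command line
+    # (--ports, --num_servers) override the corresponding entries in the named
+    # model config, since these can vary from run to run.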
+    if args.local_vllm and args.model_config:
+        from eureka_ml_insights.configs.config import ModelConfig
+        from eureka_ml_insights.models import LocalVLLMModel
+        try:
+            model_config = getattr(model_configs, args.model_config)
+            if isinstance(model_config, ModelConfig):
+                if args.ports:
+                    model_config.init_args["ports"] = args.ports
+                if args.num_servers:
+                    model_config.init_args["num_servers"] = args.num_servers
+                init_args["model_config"] = model_config
+        except AttributeError:
+            # If there's no such config, treat --model_config as a model name/path and create one.
+            init_args["model_config"] = ModelConfig(
+                LocalVLLMModel,
+                {
+                    "model_name": args.model_config,
+                    "ports": args.ports,
+                    "num_servers": args.num_servers or 1,
+                },
+            )
+
+    elif args.model_config:
         try:
             init_args["model_config"] = getattr(model_configs, args.model_config)
         except AttributeError:
@@ -55,3 +81,7 @@
     logging.info(f"Saving experiment logs in {pipeline_config.log_dir}.")
     pipeline = Pipeline(pipeline_config.component_configs, pipeline_config.log_dir)
     pipeline.run()
+
+    if args.local_vllm:
+        from eureka_ml_insights.models.models import _LocalVLLMDeploymentHandler
+        _LocalVLLMDeploymentHandler.shutdown_servers()
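--
Example usage (illustrative; adjust model, config, and port values to your setup):

  # Deploy four phi-4 servers, run IFEval against them, then shut them down:
  ./deploy_vllm_and_run_eval.sh

  # Reuse servers already running on ports 8002/8003 via a named config:
  python main.py --exp_config IFEval_PIPELINE --model_config PHI4_LOCAL_CONFIG --local_vllm --ports 8002 8003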