Skip to content

Commit 550325d

Browse files
sgurunatpre-commit-ci[bot]lvliang-intel
authored
vLLM support for DocSum (opea-project#885)
* Add model parameter for DocSumGateway in gateway.py file Signed-off-by: sgurunat <[email protected]> * Add langchain vllm support for DocSum along with authentication support for vllm endpoints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updated docker_compose_llm.yaml and README file with vLLM information Signed-off-by: sgurunat <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updated docsum-vllm Dockerfile into llm-compose-cd.yaml under github workflows Signed-off-by: sgurunat <[email protected]> * Updated llm-compose.yaml file to include vllm sumarization docker build Signed-off-by: sgurunat <[email protected]> --------- Signed-off-by: sgurunat <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: lvliang-intel <[email protected]>
1 parent f5c60f1 commit 550325d

File tree

10 files changed

+334
-0
lines changed

10 files changed

+334
-0
lines changed

.github/workflows/docker/compose/llms-compose.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ services:
5858
build:
5959
dockerfile: comps/llms/text-generation/predictionguard/Dockerfile
6060
image: ${REGISTRY:-opea}/llm-textgen-predictionguard:${TAG:-latest}
61+
llm-docsum-vllm:
62+
build:
63+
dockerfile: comps/llms/summarization/vllm/langchain/Dockerfile
64+
image: ${REGISTRY:-opea}/llm-docsum-vllm:${TAG:-latest}
6165
llm-faqgen-vllm:
6266
build:
6367
dockerfile: comps/llms/faq-generation/vllm/langchain/Dockerfile

comps/cores/mega/gateway.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,8 @@ async def handle_request(self, request: Request):
433433
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
434434
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
435435
streaming=stream_opt,
436+
language=chat_request.language if chat_request.language else "auto",
437+
model=chat_request.model if chat_request.model else None,
436438
)
437439
result_dict, runtime_graph = await self.megaservice.schedule(
438440
initial_inputs={data["type"]: prompt}, llm_parameters=parameters
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
FROM python:3.11-slim
5+
6+
ARG ARCH="cpu"
7+
8+
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
9+
libgl1-mesa-glx \
10+
libjemalloc-dev
11+
12+
RUN useradd -m -s /bin/bash user && \
13+
mkdir -p /home/user && \
14+
chown -R user /home/user/
15+
16+
USER user
17+
18+
COPY comps /home/user/comps
19+
20+
RUN pip install --no-cache-dir --upgrade pip setuptools && \
21+
if [ ${ARCH} = "cpu" ]; then pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu; fi && \
22+
pip install --no-cache-dir -r /home/user/comps/llms/summarization/vllm/langchain/requirements.txt
23+
24+
ENV PYTHONPATH=$PYTHONPATH:/home/user
25+
26+
WORKDIR /home/user/comps/llms/summarization/vllm/langchain
27+
28+
ENTRYPOINT ["bash", "entrypoint.sh"]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Document Summary vLLM Microservice
2+
3+
This microservice leverages LangChain to implement summarization strategies and facilitate LLM inference using vLLM.
4+
[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use library for LLM inference and serving, it delivers state-of-the-art serving throughput with a set of advanced features such as PagedAttention, Continuous batching and etc.. Besides GPUs, vLLM already supported [Intel CPUs](https://www.intel.com/content/www/us/en/products/overview.html) and [Gaudi accelerators](https://habana.ai/products).
5+
6+
## 🚀1. Start Microservice with Python 🐍 (Option 1)
7+
8+
To start the LLM microservice, you need to install python packages first.
9+
10+
### 1.1 Install Requirements
11+
12+
```bash
13+
pip install -r requirements.txt
14+
```
15+
16+
### 1.2 Start LLM Service
17+
18+
```bash
19+
export HF_TOKEN=${your_hf_api_token}
20+
export LLM_MODEL_ID=${your_hf_llm_model}
21+
docker run -p 8008:80 -v ./data:/data --name llm-docsum-vllm --shm-size 1g opea/vllm:hpu --model-id ${LLM_MODEL_ID}
22+
```
23+
24+
### 1.3 Verify the vLLM Service
25+
26+
```bash
27+
curl http://${your_ip}:8008/v1/chat/completions \
28+
-X POST \
29+
-H "Content-Type: application/json" \
30+
-d '{"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "user", "content": "What is Deep Learning? "}]}'
31+
```
32+
33+
### 1.4 Start LLM Service with Python Script
34+
35+
```bash
36+
export vLLM_ENDPOINT="http://${your_ip}:8008"
37+
python llm.py
38+
```
39+
40+
## 🚀2. Start Microservice with Docker 🐳 (Option 2)
41+
42+
If you start an LLM microservice with docker, the `docker_compose_llm.yaml` file will automatically start a vLLM/vLLM service with docker.
43+
44+
To setup or build the vLLM image follow the instructions provided in [vLLM Gaudi](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/text-generation/vllm/langchain#22-vllm-on-gaudi)
45+
46+
### 2.1 Setup Environment Variables
47+
48+
In order to start vLLM and LLM services, you need to setup the following environment variables first.
49+
50+
```bash
51+
export HF_TOKEN=${your_hf_api_token}
52+
export vLLM_ENDPOINT="http://${your_ip}:8008"
53+
export LLM_MODEL_ID=${your_hf_llm_model}
54+
```
55+
56+
### 2.2 Build Docker Image
57+
58+
```bash
59+
cd ../../../../../
60+
docker build -t opea/llm-docsum-vllm:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/llms/summarization/vllm/langchain/Dockerfile .
61+
```
62+
63+
To start a docker container, you have two options:
64+
65+
- A. Run Docker with CLI
66+
- B. Run Docker with Docker Compose
67+
68+
You can choose one as needed.
69+
70+
### 2.3 Run Docker with CLI (Option A)
71+
72+
```bash
73+
docker run -d --name="llm-docsum-vllm-server" -p 9000:9000 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e vLLM_ENDPOINT=$vLLM_ENDPOINT -e HF_TOKEN=$HF_TOKEN opea/llm-docsum-vllm:latest
74+
```
75+
76+
### 2.4 Run Docker with Docker Compose (Option B)
77+
78+
```bash
79+
docker compose -f docker_compose_llm.yaml up -d
80+
```
81+
82+
## 🚀3. Consume LLM Service
83+
84+
### 3.1 Check Service Status
85+
86+
```bash
87+
curl http://${your_ip}:9000/v1/health_check\
88+
-X GET \
89+
-H 'Content-Type: application/json'
90+
```
91+
92+
### 3.2 Consume LLM Service
93+
94+
```bash
95+
# Enable streaming to receive a streaming response. By default, this is set to True.
96+
curl http://${your_ip}:9000/v1/chat/docsum \
97+
-X POST \
98+
-d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en"}' \
99+
-H 'Content-Type: application/json'
100+
101+
# Disable streaming to receive a non-streaming response.
102+
curl http://${your_ip}:9000/v1/chat/docsum \
103+
-X POST \
104+
-d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "streaming":false}' \
105+
-H 'Content-Type: application/json'
106+
107+
# Use Chinese mode. By default, language is set to "en"
108+
curl http://${your_ip}:9000/v1/chat/docsum \
109+
-X POST \
110+
-d '{"query":"2024年9月26日,北京——今日,英特尔正式发布英特尔® 至强® 6性能核处理器(代号Granite Rapids),为AI、数据分析、科学计算等计算密集型业务提供卓越性能。", "max_tokens":32, "language":"zh", "streaming":false}' \
111+
-H 'Content-Type: application/json'
112+
```
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
version: "3.8"
5+
6+
services:
7+
vllm-service:
8+
image: opea/vllm:hpu
9+
container_name: vllm-gaudi-server
10+
ports:
11+
- "8008:80"
12+
volumes:
13+
- "./data:/data"
14+
environment:
15+
no_proxy: ${no_proxy}
16+
http_proxy: ${http_proxy}
17+
https_proxy: ${https_proxy}
18+
HF_TOKEN: ${HF_TOKEN}
19+
HABANA_VISIBLE_DEVICES: all
20+
OMPI_MCA_btl_vader_single_copy_mechanism: none
21+
LLM_MODEL_ID: ${LLM_MODEL_ID}
22+
runtime: habana
23+
cap_add:
24+
- SYS_NICE
25+
ipc: host
26+
command: --enforce-eager --model $LLM_MODEL_ID --tensor-parallel-size 1 --host 0.0.0.0 --port 80
27+
llm:
28+
image: opea/llm-docsum-vllm:latest
29+
container_name: llm-docsum-vllm-server
30+
ports:
31+
- "9000:9000"
32+
ipc: host
33+
environment:
34+
no_proxy: ${no_proxy}
35+
http_proxy: ${http_proxy}
36+
https_proxy: ${https_proxy}
37+
vLLM_ENDPOINT: ${vLLM_ENDPOINT}
38+
HUGGINGFACEHUB_API_TOKEN: ${HF_TOKEN}
39+
LLM_MODEL_ID: ${LLM_MODEL_ID}
40+
restart: unless-stopped
41+
42+
networks:
43+
default:
44+
driver: bridge
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/usr/bin/env bash
2+
3+
# Copyright (C) 2024 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
pip --no-cache-dir install -r requirements-runtime.txt
7+
8+
python llm.py
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
6+
from fastapi.responses import StreamingResponse
7+
from langchain.chains.summarize import load_summarize_chain
8+
from langchain.docstore.document import Document
9+
from langchain.prompts import PromptTemplate
10+
from langchain.text_splitter import CharacterTextSplitter
11+
from langchain_community.llms import VLLMOpenAI
12+
13+
from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
14+
from comps.cores.mega.utils import get_access_token
15+
16+
logger = CustomLogger("llm_docsum")
17+
logflag = os.getenv("LOGFLAG", False)
18+
19+
# Environment variables
20+
TOKEN_URL = os.getenv("TOKEN_URL")
21+
CLIENTID = os.getenv("CLIENTID")
22+
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
23+
MODEL_ID = os.getenv("LLM_MODEL_ID", None)
24+
25+
templ_en = """Write a concise summary of the following:
26+
"{text}"
27+
CONCISE SUMMARY:"""
28+
29+
templ_zh = """请简要概括以下内容:
30+
"{text}"
31+
概况:"""
32+
33+
34+
def post_process_text(text: str):
35+
if text == " ":
36+
return "data: @#$\n\n"
37+
if text == "\n":
38+
return "data: <br/>\n\n"
39+
if text.isspace():
40+
return None
41+
new_text = text.replace(" ", "@#$")
42+
return f"data: {new_text}\n\n"
43+
44+
45+
@register_microservice(
46+
name="opea_service@llm_docsum",
47+
service_type=ServiceType.LLM,
48+
endpoint="/v1/chat/docsum",
49+
host="0.0.0.0",
50+
port=9000,
51+
)
52+
async def llm_generate(input: LLMParamsDoc):
53+
if logflag:
54+
logger.info(input)
55+
if input.language in ["en", "auto"]:
56+
templ = templ_en
57+
elif input.language in ["zh"]:
58+
templ = templ_zh
59+
else:
60+
raise NotImplementedError('Please specify the input language in "en", "zh", "auto"')
61+
62+
PROMPT = PromptTemplate.from_template(templ)
63+
64+
if logflag:
65+
logger.info("After prompting:")
66+
logger.info(PROMPT)
67+
68+
access_token = (
69+
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
70+
)
71+
headers = {}
72+
if access_token:
73+
headers = {"Authorization": f"Bearer {access_token}"}
74+
llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8080")
75+
model = input.model if input.model else os.getenv("LLM_MODEL_ID")
76+
llm = VLLMOpenAI(
77+
openai_api_key="EMPTY",
78+
openai_api_base=llm_endpoint + "/v1",
79+
model_name=model,
80+
default_headers=headers,
81+
max_tokens=input.max_tokens,
82+
top_p=input.top_p,
83+
streaming=input.streaming,
84+
temperature=input.temperature,
85+
presence_penalty=input.repetition_penalty,
86+
)
87+
llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
88+
texts = text_splitter.split_text(input.query)
89+
90+
# Create multiple documents
91+
docs = [Document(page_content=t) for t in texts]
92+
93+
if input.streaming:
94+
95+
async def stream_generator():
96+
from langserve.serialization import WellKnownLCSerializer
97+
98+
_serializer = WellKnownLCSerializer()
99+
async for chunk in llm_chain.astream_log(docs):
100+
data = _serializer.dumps({"ops": chunk.ops}).decode("utf-8")
101+
if logflag:
102+
logger.info(data)
103+
yield f"data: {data}\n\n"
104+
yield "data: [DONE]\n\n"
105+
106+
return StreamingResponse(stream_generator(), media_type="text/event-stream")
107+
else:
108+
response = await llm_chain.ainvoke(docs)
109+
response = response["output_text"]
110+
if logflag:
111+
logger.info(response)
112+
return GeneratedDoc(text=response, prompt=input.query)
113+
114+
115+
if __name__ == "__main__":
116+
# Split text
117+
text_splitter = CharacterTextSplitter()
118+
opea_microservices["opea_service@llm_docsum"].start()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
langserve
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
docarray[full]
2+
fastapi
3+
huggingface_hub
4+
langchain #==0.1.12
5+
langchain-huggingface
6+
langchain-openai
7+
langchain_community
8+
langchainhub
9+
opentelemetry-api
10+
opentelemetry-exporter-otlp
11+
opentelemetry-sdk
12+
prometheus-fastapi-instrumentator
13+
shortuuid
14+
transformers
15+
uvicorn

0 commit comments

Comments
 (0)