| 1 | +--- |
| 2 | +title: Custom LLM API |
| 3 | +date: |
| 4 | + created: 2023-09-11 |
| 5 | +authors: [SWHL] |
| 6 | +categories: |
| 7 | + - General |
| 8 | +comments: true |
| 9 | +--- |
| 10 | + |
| 11 | +### Introduction |
| 12 | + |
| 13 | +{{% alert context="info" %}}The LLM part of this project is independent. Users can add and configure the LLM interface they need under **knowledge_qa_llm/llm**.{{% /alert %}} |
| 14 | + |
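| 15 | +The steps below use adding support for the InternLM-7B model as an example. They assume your local machine meets the requirements for deploying the LLM and running inference. |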
| 16 | + |
| 17 | +### Steps |
| 18 | + |
| 19 | +#### 1. Deploy the LLM locally |
| 20 | + |
| 21 | +For download instructions, see [internlm-7b](https://huggingface.co/internlm/internlm-7b) on Hugging Face. |
| 22 | + |
| 23 | +#### 2. Write the model deployment and inference code |
| 24 | + |
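| 25 | +For this step, you can refer to the [ChatGLM](https://github.com/THUDM/ChatGLM-6B/blob/main/api.py) API implementation. Only the model-loading part needs to be replaced with the InternLM equivalent. The full code is as follows: |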
| 26 | + |
| 27 | +<details> |
| 28 | + |
| 29 | +```python {linenos=table} |
| 30 | +from fastapi import FastAPI, Request |
| 31 | +from transformers import AutoTokenizer, AutoModel |
| 32 | +import uvicorn, json, datetime |
| 33 | +import torch |
| 34 | + |
| 35 | +DEVICE = "cuda" |
| 36 | +DEVICE_ID = "0" |
| 37 | +CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE |
| 38 | + |
| 39 | + |
| 40 | +def torch_gc(): |
| 41 | +    if torch.cuda.is_available(): |
| 42 | +        with torch.cuda.device(CUDA_DEVICE): |
| 43 | +            torch.cuda.empty_cache() |
| 44 | +            torch.cuda.ipc_collect() |
| 45 | + |
| 46 | + |
| 47 | +app = FastAPI() |
| 48 | + |
| 49 | + |
| 50 | +@app.post("/") |
| 51 | +async def create_item(request: Request): |
| 52 | +    global model, tokenizer |
| 53 | +    json_post_raw = await request.json() |
| 54 | +    json_post = json.dumps(json_post_raw) |
| 55 | +    json_post_list = json.loads(json_post) |
| 56 | +    prompt = json_post_list.get('prompt') |
| 57 | +    history = json_post_list.get('history') or []  # fall back to an empty history if none was sent |
| 58 | +    max_length = json_post_list.get('max_length') |
| 59 | +    top_p = json_post_list.get('top_p') |
| 60 | +    temperature = json_post_list.get('temperature') |
| 61 | +    response, history = model.chat(tokenizer, |
| 62 | +                                   prompt, |
| 63 | +                                   history=history, |
| 64 | +                                   max_new_tokens=max_length if max_length else 2048, |
| 65 | +                                   top_p=top_p if top_p else 0.7, |
| 66 | +                                   temperature=temperature if temperature else 0.95) |
| 67 | +    now = datetime.datetime.now() |
| 68 | +    time = now.strftime("%Y-%m-%d %H:%M:%S") |
| 69 | +    answer = { |
| 70 | +        "response": response, |
| 71 | +        "history": history, |
| 72 | +        "status": 200, |
| 73 | +        "time": time |
| 74 | +    } |
| 75 | +    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"' |
| 76 | +    print(log) |
| 77 | +    torch_gc()  # release cached GPU memory after each request |
| 78 | +    return answer |
| 79 | + |
| 80 | + |
| 81 | +if __name__ == '__main__': |
| 82 | +    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b-v1_1", trust_remote_code=True) |
| 83 | +    model = AutoModel.from_pretrained("internlm/internlm-chat-7b-v1_1", trust_remote_code=True).half().cuda() |
| 84 | +    model.eval() |
| 85 | +    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) |
| 86 | +``` |
| 87 | + |
| 88 | +</details> |
| 89 | + |
| 90 | +#### 3. Write the API client code |
| 91 | + |
| 92 | +Create an `internlm_7b.py` file in the project's `knowledge_qa_llm/llm/` directory with the following code: |
| 93 | + |
| 94 | +<details> |
| 95 | + |
| 96 | +```python {linenos=table} |
| 97 | +import json |
| 98 | +from typing import List, Optional |
| 99 | + |
| 100 | +import requests |
| 101 | + |
| 102 | + |
| 103 | +class InternLM_7B: |
| 104 | +    def __init__(self, api_url: Optional[str] = None): |
| 105 | +        self.api_url = api_url |
| 106 | + |
| 107 | +    def __call__(self, prompt: str, history: Optional[List] = None, **kwargs): |
| 108 | +        if not history: |
| 109 | +            history = [] |
| 110 | + |
| 111 | +        data = {"prompt": prompt, "history": history} |
| 112 | +        if kwargs: |
| 113 | +            temperature = kwargs.get("temperature", 0.1) |
| 114 | +            top_p = kwargs.get("top_p", 0.7) |
| 115 | +            max_length = kwargs.get("max_length", 4096) |
| 116 | + |
| 117 | +            data.update( |
| 118 | +                {"temperature": temperature, "top_p": top_p, "max_length": max_length} |
| 119 | +            ) |
| 120 | +        try: |
| 121 | +            req = requests.post(self.api_url, data=json.dumps(data), timeout=60)  # inside try so connection errors are caught |
| 122 | +            rdata = req.json() |
| 123 | +            if rdata["status"] == 200: |
| 124 | +                return rdata["response"] |
| 125 | +            return "Network error" |
| 126 | +        except Exception as e: |
| 127 | +            return f"Network error:{e}" |
| 128 | +``` |
| 129 | + |
| 130 | +</details> |
| 131 | + |
| 132 | +#### 4. Add the import statement |
| 133 | + |
| 134 | +Add the corresponding `import` line to `knowledge_qa_llm/llm/__init__.py` and include the class in `__all__`, for example: |
| 135 | + |
| 136 | +```python {linenos=table} |
| 137 | +from .baichuan_7b import BaiChuan7B |
| 138 | +from .chatglm2_6b import ChatGLM2_6B |
| 139 | +from .ernie_bot_turbo import ERNIEBotTurbo |
| 140 | +from .qwen7b_chat import Qwen7B_Chat |
| 141 | +from .internlm_7b import InternLM_7B |
| 142 | + |
| 143 | +__all__ = ["BaiChuan7B", "ChatGLM2_6B", "ERNIEBotTurbo", "Qwen7B_Chat", "InternLM_7B"] |
| 144 | +``` |
| 145 | + |
| 146 | +#### 5. Update the configuration file |
| 147 | + |
| 148 | +Update `knowledge_qa_llm/config.yaml` so that the new class name maps to its API URL: |
| 149 | + |
| 150 | +```yaml {linenos=table} |
| 151 | +LLM_API: |
| 152 | +    InternLM_7B: your_api |
| 153 | +    Qwen7B_Chat: your_api |
| 154 | +    ChatGLM2_6B: your_api |
| 155 | +    BaiChuan7B: your_api |
| 156 | +``` |
| 157 | + |
| 158 | +#### 6. Launch |
| 159 | + |
| 160 | +```bash {linenos=table} |
| 161 | +streamlit run web_ui.py |
| 162 | +``` |
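
Before launching the UI, it can help to check the request/response contract between the client wrapper (step 3) and the FastAPI service (step 2) without a running model. The following is a minimal sketch, assuming the service reads `prompt`, `history`, `temperature`, `top_p`, and `max_length` from the JSON body and replies with `response` and `status` fields as shown above. `build_payload` and `parse_reply` are hypothetical helper names used only for illustration and are not part of the project.

```python
import json
from typing import Any, Dict, List, Optional


def build_payload(prompt: str, history: Optional[List] = None, **kwargs: Any) -> str:
    """Serialize the JSON body that the step-2 service reads from the request."""
    data: Dict[str, Any] = {"prompt": prompt, "history": history or []}
    # Assumption: forward only the sampling options the caller actually set,
    # so the service-side fallbacks apply to everything else.
    for key in ("temperature", "top_p", "max_length"):
        if key in kwargs:
            data[key] = kwargs[key]
    return json.dumps(data, ensure_ascii=False)


def parse_reply(body: str) -> str:
    """Return the model's answer, or an error string for an unusable reply."""
    try:
        rdata = json.loads(body)
    except json.JSONDecodeError as e:
        return f"Invalid JSON: {e}"
    if isinstance(rdata, dict) and rdata.get("status") == 200:
        return rdata.get("response", "")
    return "Unexpected reply"
```

Note that this sketch differs slightly from the wrapper in step 3, which sends its own defaults for all three sampling options whenever any keyword argument is passed. Which behavior is preferable depends on whether the defaults should live in the client or in the service.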