Skip to content

Commit 7ea2df8

Browse files
committed
Added LocalLLM class for running local models.
1 parent f2a8af8 commit 7ea2df8

File tree

8 files changed

+66
-32
lines changed

8 files changed

+66
-32
lines changed

CHANGELOG_zh.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
### Added
1515
- 新增Prompt管理界面,支持在线修改Prompt
16+
- 支持本地大模型 LocalLLM类
1617

1718
### Changed
1819
- 修改es包,由elasticsearch[async] 改为 elasticsearch

config.json

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
"is_detailed_tool_call": true,
1818
"is_detailed_observation": true
1919
},
20-
"llm": {
21-
"cls": "oxygent.llms.OllamaLLM",
22-
"base_url": "http://localhost:11434",
20+
"llm": {
2321
"temperature": 0.1,
2422
"max_tokens": 4096,
2523
"top_p": 1

oxygent/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ class Config:
5353
"is_detailed_observation": True,
5454
},
5555
"llm": {
56-
"cls": "oxygent.llms.OllamaLLM",
57-
"base_url": "http://localhost:11434",
5856
"temperature": 0.1,
5957
"max_tokens": 4096,
6058
"top_p": 1,

oxygent/oxy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from .function_tools.function_hub import FunctionHub
1818
from .function_tools.function_tool import FunctionTool
19-
from .llms import HttpLLM, MockLLM, OpenAILLM
19+
from .llms import HttpLLM, LocalLLM, MockLLM, OpenAILLM
2020
from .mcp_tools import MCPTool, SSEMCPClient, StdioMCPClient, StreamableMCPClient
2121

2222
__all__ = [
@@ -31,6 +31,7 @@
3131
"HttpLLM",
3232
"OpenAILLM",
3333
"MockLLM",
34+
"LocalLLM",
3435
"MCPTool",
3536
"StdioMCPClient",
3637
"StreamableMCPClient",

oxygent/oxy/llms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from .http_llm import HttpLLM
2+
from .local_llm import LocalLLM
23
from .mock_llm import MockLLM
34
from .openai_llm import OpenAILLM
45

56
__all__ = [
67
"HttpLLM",
78
"OpenAILLM",
89
"MockLLM",
10+
"LocalLLM",
911
]

oxygent/oxy/llms/http_llm.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,24 +77,12 @@ async def _execute(self, oxy_request: OxyRequest) -> OxyResponse:
7777
if k != "messages":
7878
payload[k] = v
7979
else:
80-
llm_config = {
81-
k: v
82-
for k, v in Config.get_llm_config().items()
83-
if k
84-
not in {
85-
"cls",
86-
"base_url",
87-
"api_key",
88-
"name",
89-
"model_name",
90-
}
91-
}
9280
payload = {
9381
"messages": await self._get_messages(oxy_request),
9482
"model": self.model_name,
9583
"stream": True,
9684
}
97-
payload.update(llm_config)
85+
payload.update(Config.get_llm_config())
9886
for k, v in self.llm_params.items():
9987
payload[k] = v
10088
for k, v in oxy_request.arguments.items():

oxygent/oxy/llms/local_llm.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import logging
2+
3+
from pydantic import Field
4+
5+
from ...config import Config
6+
from ...schemas import OxyRequest, OxyResponse, OxyState
7+
from .base_llm import BaseLLM
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class LocalLLM(BaseLLM):
    """LLM backed by a locally loaded Hugging Face causal-LM checkpoint.

    Loads the model and tokenizer with ``transformers`` during ``init`` and
    runs generation in-process, instead of calling a remote HTTP API like
    the sibling ``HttpLLM``/``OpenAILLM`` classes.
    """

    # Filesystem path (or HF hub id) of the pretrained checkpoint to load.
    model_path: str = Field("")
    # Forwarded to ``device_map=`` of ``from_pretrained`` ("auto", "cpu", "cuda", ...).
    device: str = Field("auto")
    # Logical model name; kept for interface parity with the remote LLM classes.
    model_name: str = Field("")

    async def init(self):
        """Load the tokenizer and causal-LM weights from ``model_path``.

        Raises:
            ImportError: if ``torch``/``transformers`` are not installed.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError as e:
            # Fixed: the two message fragments previously concatenated
            # without a separating space ("packages.Please install ...").
            raise ImportError(
                "LocalLLM requires 'torch' and 'transformers' packages. "
                "Please install them using 'pip install torch transformers "
                "einops transformers_stream_generator accelerate'"
            ) from e

        await super().init()
        # Load model directly
        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_path, device_map=self.device, torch_dtype=torch.bfloat16
        )
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_path)

    def _build_generate_kwargs(self, oxy_request: OxyRequest) -> dict:
        """Collect keyword arguments for HF ``generate()``.

        Precedence (lowest to highest): global llm config < ``self.llm_params``
        < per-request arguments. Keys that configure the request rather than
        generation (e.g. ``model``, ``stream``) are dropped because
        ``generate()`` rejects unknown kwargs, and the OpenAI-style
        ``max_tokens`` is mapped to transformers' ``max_new_tokens``.

        Note: the previous implementation discarded all of this and always
        used ``{"max_new_tokens": 512}``, so configured parameters such as
        ``temperature`` never reached the model.
        """
        # Keys meaningful to the request envelope, not to HF generate().
        non_generate_keys = {
            "messages",
            "model",
            "stream",
            "cls",
            "base_url",
            "api_key",
            "name",
            "model_name",
        }
        merged: dict = {}
        merged.update(Config.get_llm_config())
        merged.update(self.llm_params)
        for k, v in oxy_request.arguments.items():
            if k != "messages":
                merged[k] = v
        gen_kwargs = {k: v for k, v in merged.items() if k not in non_generate_keys}
        if "max_tokens" in gen_kwargs:
            gen_kwargs["max_new_tokens"] = gen_kwargs.pop("max_tokens")
        # Preserve the original default output budget when nothing configures it.
        gen_kwargs.setdefault("max_new_tokens", 512)
        return gen_kwargs

    async def _execute(self, oxy_request: OxyRequest) -> OxyResponse:
        """Run chat generation locally and return the decoded completion.

        Args:
            oxy_request: Request whose ``arguments["messages"]`` holds the
                chat history in the usual role/content dict format.

        Returns:
            OxyResponse: COMPLETED state with the generated text (prompt
            tokens stripped, special tokens skipped).
        """
        gen_kwargs = self._build_generate_kwargs(oxy_request)

        messages = oxy_request.arguments["messages"]

        input_text = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self._tokenizer.encode(input_text, return_tensors="pt")
        input_ids = input_ids.to(self._model.device)
        outputs = self._model.generate(input_ids=input_ids, **gen_kwargs)[0]
        # Keep only the newly generated tokens, not the echoed prompt.
        outputs = outputs[len(input_ids[0]) :]

        return OxyResponse(
            state=OxyState.COMPLETED,
            output=self._tokenizer.decode(outputs, skip_special_tokens=True),
        )

oxygent/oxy/llms/openai_llm.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,12 @@ async def _execute(self, oxy_request: OxyRequest) -> OxyResponse:
3838
OxyResponse: The response containing the model's output with COMPLETED state.
3939
"""
4040
# Construct payload for OpenAI API request
41-
llm_config = {
42-
k: v
43-
for k, v in Config.get_llm_config().items()
44-
if k
45-
not in {
46-
"cls",
47-
"base_url",
48-
"api_key",
49-
"name",
50-
"model_name",
51-
}
52-
}
5341
payload = {
5442
"messages": await self._get_messages(oxy_request),
5543
"model": self.model_name,
5644
"stream": True,
5745
}
58-
payload.update(llm_config)
46+
payload.update(Config.get_llm_config())
5947
for k, v in self.llm_params.items():
6048
payload[k] = v
6149
for k, v in oxy_request.arguments.items():

0 commit comments

Comments
 (0)