Skip to content

Commit 5cc90a9

Browse files
committed
Added local LLM class.
1 parent 7ea2df8 commit 5cc90a9

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

examples/llms/demo_local_llm.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import asyncio
2+
3+
from oxygent import MAS, oxy
4+
5+
# Components handed to the MAS: a LocalLLM entry and a chat agent that
# refers to that LLM by its registered name.
oxy_space = [
    # Placeholder path: point model_path at the model you want to load.
    oxy.LocalLLM(name="default_llm", model_path="/path/to/your_model"),
    # llm_model must match the name of the LocalLLM entry above.
    oxy.ChatAgent(name="master_agent", llm_model="default_llm"),
]
15+
16+
17+
async def main():
    """Open a MAS over ``oxy_space`` and start its web service.

    The service is started with ``"hello"`` as the first query; the MAS
    context is exited once ``start_web_service`` returns.
    """
    async with MAS(oxy_space=oxy_space) as system:
        await system.start_web_service(first_query="hello")
20+
21+
22+
if __name__ == "__main__":
    # Only run the demo when executed as a script, not when imported.
    asyncio.run(main())

oxygent/oxy/llms/local_llm.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
class LocalLLM(BaseLLM):
1313
model_path: str = Field("")
14-
device: str = Field("auto")
15-
model_name: str = Field("")
14+
device_map: str = Field("auto")
15+
dtype: str = Field("bfloat16")
1616

1717
async def init(self):
1818
try:
@@ -27,20 +27,28 @@ async def init(self):
2727
await super().init()
2828
# Load model directly
2929
self._model = AutoModelForCausalLM.from_pretrained(
30-
self.model_path, device_map=self.device, torch_dtype=torch.bfloat16
30+
self.model_path, device_map=self.device_map, dtype=self.dtype
3131
)
3232
self._tokenizer = AutoTokenizer.from_pretrained(self.model_path)
3333

3434
async def _execute(self, oxy_request: OxyRequest) -> OxyResponse:
35-
payload = {"model": self.model_name, "stream": False}
36-
payload.update(Config.get_llm_config())
35+
payload = Config.get_llm_config()
3736
for k, v in self.llm_params.items():
3837
payload[k] = v
3938
for k, v in oxy_request.arguments.items():
4039
if k == "messages":
4140
continue
4241
payload[k] = v
43-
payload = {"max_new_tokens": 512}
42+
43+
replace_dict = {
44+
"max_tokens": "max_new_tokens",
45+
"stream": "",
46+
}
47+
for k, v in replace_dict.items():
48+
if k in payload:
49+
if v:
50+
payload[v] = payload[k]
51+
del payload[k]
4452

4553
messages = oxy_request.arguments["messages"]
4654

0 commit comments

Comments
 (0)