+import logging
+
+from pydantic import Field
+
+from ...config import Config
+from ...schemas import OxyRequest, OxyResponse, OxyState
+from .base_llm import BaseLLM
+
+logger = logging.getLogger(__name__)
+
+
+class LocalLLM(BaseLLM):
+    """Run a locally loaded Hugging Face causal LM behind the BaseLLM interface."""
+
+    model_path: str = Field("")
+    device: str = Field("auto")
+    model_name: str = Field("")
+
+    async def init(self):
+        try:
+            import torch
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+        except ImportError as e:
+            raise ImportError(
+                "LocalLLM requires the 'torch' and 'transformers' packages. "
+                "Please install them using "
+                "'pip install torch transformers einops transformers_stream_generator accelerate'."
+            ) from e
+
+        await super().init()
+        # Load the model and tokenizer from the local path (or hub id),
+        # letting `device_map` place the bf16 weights according to `device`.
+        self._model = AutoModelForCausalLM.from_pretrained(
+            self.model_path, device_map=self.device, torch_dtype=torch.bfloat16
+        )
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+    async def _execute(self, oxy_request: OxyRequest) -> OxyResponse:
+        # Assemble keyword arguments for generate(): global LLM config first,
+        # then instance-level llm_params, then per-request overrides.
+        payload = dict(Config.get_llm_config())
+        payload.update(self.llm_params)
+        for k, v in oxy_request.arguments.items():
+            if k == "messages":
+                continue
+            payload[k] = v
+        # Remote-API style fields are not valid generate() kwargs for a
+        # locally loaded model, so drop them if they were configured.
+        for key in ("model", "stream"):
+            payload.pop(key, None)
+        payload.setdefault("max_new_tokens", 512)
+
+        messages = oxy_request.arguments["messages"]
+
+        # Render the chat history with the model's chat template, then tokenize.
+        input_text = self._tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        input_ids = self._tokenizer.encode(input_text, return_tensors="pt")
+        input_ids = input_ids.to(self._model.device)
+        outputs = self._model.generate(input_ids=input_ids, **payload)[0]
+        # generate() returns prompt + completion; keep only the newly generated tokens.
+        outputs = outputs[len(input_ids[0]) :]
+
+        return OxyResponse(
+            state=OxyState.COMPLETED,
+            output=self._tokenizer.decode(outputs, skip_special_tokens=True),
+        )
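
For context, a minimal usage sketch is below. It assumes BaseLLM exposes a `name` field and that `OxyRequest` can be constructed with an `arguments` dict carrying the chat `messages`; only the attributes actually read in this diff are relied on, and the model path and names are placeholders, not part of the change.

import asyncio

async def main():
    # Hypothetical example: `name` and the OxyRequest constructor shape are
    # assumptions about BaseLLM/OxyRequest, not shown in this diff.
    llm = LocalLLM(
        name="local_llm",
        model_path="/models/Qwen2.5-7B-Instruct",  # placeholder path
        device="auto",
    )
    await llm.init()
    response = await llm._execute(
        OxyRequest(arguments={"messages": [{"role": "user", "content": "Hello!"}]})
    )
    print(response.output)

asyncio.run(main())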