1111
1212class LocalLLM (BaseLLM ):
1313 model_path : str = Field ("" )
14- device : str = Field ("auto" )
15- model_name : str = Field ("" )
14+ device_map : str = Field ("auto" )
15+ dtype : str = Field ("bfloat16 " )
1616
1717 async def init (self ):
1818 try :
@@ -27,20 +27,28 @@ async def init(self):
2727 await super ().init ()
2828 # Load model directly
2929 self ._model = AutoModelForCausalLM .from_pretrained (
30- self .model_path , device_map = self .device , torch_dtype = torch . bfloat16
30+ self .model_path , device_map = self .device_map , dtype = self . dtype
3131 )
3232 self ._tokenizer = AutoTokenizer .from_pretrained (self .model_path )
3333
3434 async def _execute (self , oxy_request : OxyRequest ) -> OxyResponse :
35- payload = {"model" : self .model_name , "stream" : False }
36- payload .update (Config .get_llm_config ())
35+ payload = Config .get_llm_config ()
3736 for k , v in self .llm_params .items ():
3837 payload [k ] = v
3938 for k , v in oxy_request .arguments .items ():
4039 if k == "messages" :
4140 continue
4241 payload [k ] = v
43- payload = {"max_new_tokens" : 512 }
42+
43+ replace_dict = {
44+ "max_tokens" : "max_new_tokens" ,
45+ "stream" : "" ,
46+ }
47+ for k , v in replace_dict .items ():
48+ if k in payload :
49+ if v :
50+ payload [v ] = payload [k ]
51+ del payload [k ]
4452
4553 messages = oxy_request .arguments ["messages" ]
4654
0 commit comments