@@ -1,23 +1,55 @@
+"""
+Customized LLM for support-ai
+"""
+
 import argparse
 import logging
 import os
+from typing import Dict, List, Optional
 import torch
 import yaml

 from kserve import Model, ModelServer
 from sentence_transformers import SentenceTransformer
 from transformers import LlamaForCausalLM, LlamaTokenizer
-from typing import Dict, List, Optional


 CONFIG_INFERENCE_MODEL_PATH = 'inference_model_path'

 class RemoteLlamaModel(Model):
+    """
+    A KServe model wrapper for the Llama causal language model and a
+    sentence transformer embeddings model.
+    """
+
     def __init__(self, config):
+        """
+        Initializes RemoteLlamaModel by loading the specified configuration.
+
+        Args:
+            config: Configuration dictionary with model paths and
+                settings.
+
+        Raises:
+            ValueError: If CONFIG_INFERENCE_MODEL_PATH is missing in config.
+            Exception: If model loading fails.
+        """
         super().__init__('llama-model')
         self.load(config)

     def load(self, config):
+        """
+        Loads the Llama tokenizer, inference model, and sentence embeddings
+        model.
+
+        Args:
+            config: Configuration dictionary with model paths and
+                settings.
+
+        Raises:
+            ValueError: If CONFIG_INFERENCE_MODEL_PATH is missing in config.
+            Exception: If model loading fails.
+        """
         if CONFIG_INFERENCE_MODEL_PATH not in config:
             raise ValueError(f'The config doesn\'t contain {CONFIG_INFERENCE_MODEL_PATH}')
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
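For reference, the smallest configuration that satisfies the check above is a
single-key mapping. A minimal sketch, assuming an arbitrary Hugging Face
checkpoint id (the model path shown is illustrative, not from this change):

    # Minimal config dict accepted by load(); equivalent to a one-line
    # config.yaml. The checkpoint id below is an illustrative assumption.
    config = {'inference_model_path': 'meta-llama/Llama-2-7b-chat-hf'}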
@@ -28,17 +60,31 @@ def load(self, config):
                config[CONFIG_INFERENCE_MODEL_PATH], token=self.token
            )
            self.inference_model = LlamaForCausalLM.from_pretrained(
-               config[CONFIG_INFERENCE_MODEL_PATH], token=self.token, device_map='auto', load_in_4bit=True
+               config[CONFIG_INFERENCE_MODEL_PATH], token=self.token,
+               device_map='auto', load_in_4bit=True
            )

            logging.info("Loading Sentence Transformer embeddings model...")
            self.embeddings_model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
            self.ready = True
        except Exception as e:
-           logging.error(f"Failed to load models: {str(e)}")
+           logging.error("Failed to load models: %s", str(e))
            raise

-   async def predict(self, payload: Dict[str, List[str]], headers: Optional[Dict[str, str]] = None) -> Dict:
+   async def predict(self, payload: Dict[str, List[str]],
+                     _headers: Optional[Dict[str, str]] = None) -> Dict:
+       """
+       Handles prediction requests, performing inference or embeddings
+       generation based on request type.
+
+       Args:
+           payload: Contains 'texts' (List[str]) for inference or
+               embedding and 'type' to specify the operation.
+           _headers: Optional headers for the request (unused).
+
+       Returns:
+           dict: Contains the result 'outputs' or an error message.
+       """
        texts = payload.get('texts', [])
        response_type = payload.get('type', 'unknown')
        outputs = []
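Passing load_in_4bit=True straight to from_pretrained requires bitsandbytes at
runtime, and recent transformers releases deprecate the bare kwarg in favor of
an explicit quantization config. A sketch of the equivalent call, where
model_path and token stand in for config[CONFIG_INFERENCE_MODEL_PATH] and
self.token:

    # Equivalent 4-bit load via BitsAndBytesConfig (assumes bitsandbytes
    # is installed; model_path and token are placeholders).
    from transformers import BitsAndBytesConfig, LlamaForCausalLM

    model = LlamaForCausalLM.from_pretrained(
        model_path, token=token, device_map='auto',
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )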
@@ -53,13 +99,22 @@ async def predict(self, payload: Dict[str, List[str]], headers: Optional[Dict[str, str]] = None) -> Dict:
                outputs = self.__generate_embeddings(texts)
            else:
                return {'error': f'Unknown request type: {response_type}'}
-       except Exception as e:
-           logging.error(f"Prediction failed: {str(e)}")
+       except Exception as e:  # pylint: disable=broad-except
+           logging.exception(e)
            return {'error': 'Prediction failed due to internal error.'}

        return {'outputs': outputs}

    def _perform_inference(self, texts: List[str]) -> List[str]:
+       """
+       Performs text generation using the Llama model.
+
+       Args:
+           texts: List of input texts to generate responses for.
+
+       Returns:
+           List[str]: Generated responses for each input text.
+       """
        results = []
        for text in texts:
            input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
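The branch that routes to __generate_embeddings sits outside this hunk, so the
accepted 'type' values are not visible here. Illustrative payloads for
predict(), with the type strings as labeled guesses:

    # Only the 'texts' and 'type' keys are confirmed by the code above;
    # the 'inference' and 'embeddings' values are assumptions.
    inference_request = {'type': 'inference', 'texts': ['How do I reset my password?']}
    embeddings_request = {'type': 'embeddings', 'texts': ['reset password']}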
@@ -73,20 +128,49 @@ def _perform_inference(self, texts: List[str]) -> List[str]:
        return results

    def __generate_embeddings(self, texts: List[str]) -> List[List[float]]:
+       """
+       Generates embeddings for input texts using the SentenceTransformer
+       model.
+
+       Args:
+           texts: List of input texts to generate embeddings for.
+
+       Returns:
+           List[List[float]]: Generated embeddings for each input text.
+       """
        return [self.embeddings_model.encode(text).tolist() for text in texts]

 def get_model_config(path):
+    """
+    Reads and returns the model configuration from a YAML file.
+
+    Args:
+        path: Path to the YAML configuration file.
+
+    Returns:
+        dict: Loaded configuration data.
+    """
     config = None
-    with open(path) as stream:
+    with open(path, encoding="utf-8") as stream:
         config = yaml.safe_load(stream)
     return config

 def parse_args():
+    """
+    Parses command-line arguments.
+
+    Returns:
+        argparse.Namespace: Parsed arguments with model config path.
+    """
     parser = argparse.ArgumentParser(description='remote-llm')
     parser.add_argument('--model_config', type=str, default='config.yaml', help='Config path')
     return parser.parse_args()

 def main():
+    """
+    Initializes logging, loads model configuration, and starts the model
+    server.
+    """
     logging.basicConfig(level=logging.INFO)
     args = parse_args()
     config = get_model_config(args.model_config)
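The diff cuts off before the tail of main(). Given the kserve imports at the
top of the file, the server would typically be started as below; a sketch, not
necessarily this commit's exact code:

    # Typical KServe startup for a custom model (sketch).
    model = RemoteLlamaModel(config)
    ModelServer().start([model])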