Skip to content

Commit 53b6d63

Browse files
committed
Fix pylint errors
1 parent b973fbf commit 53b6d63

30 files changed

+1273
-55
lines changed

remote_llm/remote_llm.py

Lines changed: 91 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,55 @@
1+
"""
2+
Customized LLM for support-ai
3+
"""
4+
15
import argparse
26
import logging
37
import os
8+
from typing import Dict, List, Optional
49
import torch
510
import yaml
611

712
from kserve import Model, ModelServer
813
from sentence_transformers import SentenceTransformer
914
from transformers import LlamaForCausalLM, LlamaTokenizer
10-
from typing import Dict, List, Optional
1115

1216

1317
CONFIG_INFERENCE_MODEL_PATH = 'inference_model_path'
1418

1519
class RemoteLlamaModel(Model):
20+
"""
21+
A KServe model wrapper for the Llama causal language model and a sentence
22+
transformer embeddings model.
23+
"""
24+
1625
def __init__(self, config):
26+
"""
27+
Initializes RemoteLlamaModel by loading the specified configuration.
28+
29+
Args:
30+
config: Configuration dictionary with model paths and
31+
settings.
32+
33+
Raises:
34+
ValueError: If CONFIG_INFERENCE_MODEL_PATH is missing in config.
35+
Exception: If model loading fails.
36+
"""
1737
super().__init__('llama-model')
1838
self.load(config)
1939

2040
def load(self, config):
41+
"""
42+
Loads the Llama tokenizer, inference model, and sentence embeddings
43+
model.
44+
45+
Args:
46+
config: Configuration dictionary with model paths and
47+
settings.
48+
49+
Raises:
50+
ValueError: If CONFIG_INFERENCE_MODEL_PATH is missing in config.
51+
Exception: If model loading fails.
52+
"""
2153
if CONFIG_INFERENCE_MODEL_PATH not in config:
2254
raise ValueError(f'The config doesn\'t contain {CONFIG_INFERENCE_MODEL_PATH}')
2355
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -28,17 +60,31 @@ def load(self, config):
2860
config[CONFIG_INFERENCE_MODEL_PATH], token=self.token
2961
)
3062
self.inference_model = LlamaForCausalLM.from_pretrained(
31-
config[CONFIG_INFERENCE_MODEL_PATH], token=self.token, device_map='auto', load_in_4bit=True
63+
config[CONFIG_INFERENCE_MODEL_PATH], token=self.token,
64+
device_map='auto', load_in_4bit=True
3265
)
3366

3467
logging.info("Loading Sentence Transformer embeddings model...")
3568
self.embeddings_model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
3669
self.ready = True
3770
except Exception as e:
38-
logging.error(f"Failed to load models: {str(e)}")
71+
logging.error("Failed to load models: %s", str(e))
3972
raise
4073

41-
async def predict(self, payload: Dict[str, List[str]], headers: Optional[Dict[str, str]] = None) -> Dict:
74+
async def predict(self, payload: Dict[str, List[str]],
75+
_headers: Optional[Dict[str, str]] = None) -> Dict:
76+
"""
77+
Handles prediction requests, performing inference or embeddings
78+
generation based on request type.
79+
80+
Args:
81+
payload: Contains 'texts' (List[str]) for inference or
82+
embedding and 'type' to specify the operation.
83+
headers: Optional headers for the request.
84+
85+
Returns:
86+
dict: Contains the result 'outputs' or an error message.
87+
"""
4288
texts = payload.get('texts', [])
4389
response_type = payload.get('type', 'unknown')
4490
outputs = []
@@ -53,13 +99,22 @@ async def predict(self, payload: Dict[str, List[str]], headers: Optional[Dict[st
5399
outputs = self.__generate_embeddings(texts)
54100
else:
55101
return {'error': f'Unknown request type: {response_type}'}
56-
except Exception as e:
57-
logging.error(f"Prediction failed: {str(e)}")
102+
except Exception as e: # pylint: disable=broad-except
103+
logging.exception(e)
58104
return {'error': 'Prediction failed due to internal error.'}
59105

60106
return {'outputs': outputs}
61107

62108
def _perform_inference(self, texts: List[str]) -> List[str]:
109+
"""
110+
Performs text generation using the Llama model.
111+
112+
Args:
113+
texts: List of input texts to generate responses for.
114+
115+
Returns:
116+
List[str]: Generated responses for each input text.
117+
"""
63118
results = []
64119
for text in texts:
65120
input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
@@ -73,20 +128,49 @@ def _perform_inference(self, texts: List[str]) -> List[str]:
73128
return results
74129

75130
def __generate_embeddings(self, texts: List[str]) -> List[List[float]]:
131+
"""
132+
Generates embeddings for input texts using the SentenceTransformer
133+
model.
134+
135+
Args:
136+
texts: List of input texts to generate embeddings for.
137+
138+
Returns:
139+
List[List[float]]: Generated embeddings for each input text.
140+
"""
76141
return [self.embeddings_model.encode(text).tolist() for text in texts]
77142

78143
def get_model_config(path):
144+
"""
145+
Reads and returns the model configuration from a YAML file.
146+
147+
Args:
148+
path: Path to the YAML configuration file.
149+
150+
Returns:
151+
dict: Loaded configuration data.
152+
"""
79153
config = None
80-
with open(path) as stream:
154+
with open(path, encoding="utf-8") as stream:
81155
config = yaml.safe_load(stream)
82156
return config
83157

84158
def parse_args():
159+
"""
160+
Parses command-line arguments.
161+
162+
Returns:
163+
argparse.Namespace: Parsed arguments with model config path.
164+
"""
85165
parser = argparse.ArgumentParser(description='remote-llm')
86166
parser.add_argument('--model_config', type=str, default='config.yaml', help='Config path')
87167
return parser.parse_args()
88168

89169
def main():
170+
"""
171+
Initializes logging, loads model configuration, and starts the model
172+
server.
173+
"""
90174
logging.basicConfig(level=logging.INFO)
91175
args = parse_args()
92176
config = get_model_config(args.model_config)

src/support_ai/ai_bot.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Support-AI Command Line Tool
3+
"""
4+
15
import argparse
26
import uuid
37

@@ -6,6 +10,12 @@
610

711

812
def parse_args():
13+
"""
14+
Parses command-line arguments for the support-ai command line tool.
15+
16+
Returns:
17+
argparse.Namespace: Parsed arguments, including the config file path.
18+
"""
919
parser = argparse.ArgumentParser(
1020
description='Command line tool for support-ai')
1121
parser.add_argument('--config', type=str, default=None,
@@ -14,6 +24,13 @@ def parse_args():
1424

1525

1626
def main():
27+
"""
28+
Main function to execute the support-ai tool. Initializes configuration,
29+
session ID, and input loop for querying the support-ai chain.
30+
31+
Prompts the user for input queries and processes each input through the
32+
chain. Exits if 'exit', 'quit', 'q', 'e', or 'x' is entered.
33+
"""
1734
args = parse_args()
1835
config = get_config(args.config)
1936
chain = Chain(config)

src/support_ai/api_server.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Support-AI API Server
3+
"""
4+
15
import argparse
26

37
from flask import Blueprint, Flask, jsonify, request, Response
@@ -9,13 +13,25 @@
913

1014

1115
app = Flask(__name__)
12-
chain = None
16+
chain = None # pylint: disable=invalid-name
1317
api_blueprint = Blueprint('api', __name__)
1418
api = Api(api_blueprint)
1519

1620

17-
class AI(Resource):
18-
def get(self):
21+
class AI(Resource): # pylint: disable=too-few-public-methods
22+
"""
23+
Resource class for the AI endpoint.
24+
"""
25+
26+
def get(self): # pylint: disable=no-self-use
27+
"""
28+
Handles GET requests to the /api/ai endpoint. Queries the AI model
29+
with the provided query and optional datasource and session arguments.
30+
31+
Returns:
32+
Response: A text/plain response with the model's response or
33+
a JSON error message if the query parameter is missing.
34+
"""
1935
query = request.args.get('query')
2036
datasource = request.args.get('datasource')
2137
session = request.args.get('session')
@@ -30,8 +46,24 @@ def get(self):
3046
return {'message': 'Service unavailable'}, 400
3147

3248

33-
class Salesforce(Resource):
34-
def get(self, case_number):
49+
class Salesforce(Resource): # pylint: disable=too-few-public-methods
50+
"""
51+
Resource class for the Salesforce endpoint.
52+
"""
53+
54+
def get(self, case_number): # pylint: disable=no-self-use
55+
"""
56+
Handles GET requests to the /api/salesforce/<case_number>/summary
57+
endpoint.
58+
59+
Args:
60+
case_number: The case number for which to retrieve a summary.
61+
62+
Returns:
63+
Response: A text/plain response with the case summary or a JSON
64+
error message if the case number is missing or service is
65+
unavailable.
66+
"""
3567
if case_number is None:
3668
return {'message': 'Case number not specified'}, 400
3769

@@ -44,8 +76,20 @@ def get(self, case_number):
4476
return {'message': 'Service unavailable'}, 400
4577

4678

47-
class History(Resource):
48-
def delete(self):
79+
class History(Resource): # pylint: disable=too-few-public-methods
80+
"""
81+
Resource class for the History endpoint.
82+
"""
83+
84+
def delete(self): # pylint: disable=no-self-use
85+
"""
86+
Handles DELETE requests to the /api/history endpoint. Clears the
87+
history for the specified session.
88+
89+
Returns:
90+
Response: A JSON response with success status, or an error
91+
message if the session parameter is missing.
92+
"""
4993
session = request.args.get('session')
5094

5195
if session is None:
@@ -55,14 +99,26 @@ def delete(self):
5599

56100

57101
def parse_args():
102+
"""
103+
Parses command-line arguments for the support-ai API server.
104+
105+
Returns:
106+
argparse.Namespace: Parsed arguments, including the config file path.
107+
"""
58108
parser = argparse.ArgumentParser(
59109
description='Command line tool for support-ai')
60110
parser.add_argument('--config', type=str, default=None, help='Config path')
61111
return parser.parse_args()
62112

63113

64114
def main():
65-
global chain
115+
"""
116+
Main function to initialize the support-ai server. Loads configuration,
117+
initializes the Chain, and sets up API routes.
118+
119+
The server listens on all interfaces at port 8080.
120+
"""
121+
global chain # pylint: disable=global-statement
66122
args = parse_args()
67123
config = get_config(args.config)
68124
chain = Chain(config)

src/support_ai/ds_updater.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Support-AI Data Source Updater
3+
"""
4+
15
import argparse
26
import signal
37
import sys
@@ -8,13 +12,26 @@
812

913

1014
def parse_args():
15+
"""
16+
Parses command-line arguments for the support-ai data source updater.
17+
18+
Returns:
19+
argparse.Namespace: Parsed arguments, including the config file path.
20+
"""
1121
parser = argparse.ArgumentParser(
1222
description='Command line tool for support-ai')
1323
parser.add_argument('--config', type=str, default=None, help='Config path')
1424
return parser.parse_args()
1525

1626

1727
def main():
28+
"""
29+
Main function to initialize and start the data source updater.
30+
31+
Loads the configuration, initializes the DSUpdater, and sets up signal
32+
handling for graceful termination. Runs the update thread continuously
33+
until a termination signal is received.
34+
"""
1835
args = parse_args()
1936
config = get_config(args.config)
2037
ds_updater = DSUpdater(config)

0 commit comments

Comments
 (0)