Skip to content

Commit

Permalink
Fix pylint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
yukariatlas committed Oct 29, 2024
1 parent b973fbf commit 53b6d63
Show file tree
Hide file tree
Showing 30 changed files with 1,273 additions and 55 deletions.
98 changes: 91 additions & 7 deletions remote_llm/remote_llm.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,55 @@
"""
Customized LLM for support-ai
"""

import argparse
import logging
import os
from typing import Dict, List, Optional
import torch
import yaml

from kserve import Model, ModelServer
from sentence_transformers import SentenceTransformer
from transformers import LlamaForCausalLM, LlamaTokenizer
from typing import Dict, List, Optional


CONFIG_INFERENCE_MODEL_PATH = 'inference_model_path'

class RemoteLlamaModel(Model):
"""
A KServe model wrapper for the Llama causal language model and a sentence
transformer embeddings model.
"""

def __init__(self, config):
"""
Initializes RemoteLlamaModel by loading the specified configuration.
Args:
config: Configuration dictionary with model paths and
settings.
Raises:
ValueError: If CONFIG_INFERENCE_MODEL_PATH is missing in config.
Exception: If model loading fails.
"""
super().__init__('llama-model')
self.load(config)

def load(self, config):
"""
Loads the Llama tokenizer, inference model, and sentence embeddings
model.
Args:
config: Configuration dictionary with model paths and
settings.
Raises:
ValueError: If CONFIG_INFERENCE_MODEL_PATH is missing in config.
Exception: If model loading fails.
"""
if CONFIG_INFERENCE_MODEL_PATH not in config:
raise ValueError(f'The config doesn\'t contain {CONFIG_INFERENCE_MODEL_PATH}')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand All @@ -28,17 +60,31 @@ def load(self, config):
config[CONFIG_INFERENCE_MODEL_PATH], token=self.token
)
self.inference_model = LlamaForCausalLM.from_pretrained(
config[CONFIG_INFERENCE_MODEL_PATH], token=self.token, device_map='auto', load_in_4bit=True
config[CONFIG_INFERENCE_MODEL_PATH], token=self.token,
device_map='auto', load_in_4bit=True
)

logging.info("Loading Sentence Transformer embeddings model...")
self.embeddings_model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
self.ready = True
except Exception as e:
logging.error(f"Failed to load models: {str(e)}")
logging.error("Failed to load models: %s", str(e))
raise

async def predict(self, payload: Dict[str, List[str]], headers: Optional[Dict[str, str]] = None) -> Dict:
async def predict(self, payload: Dict[str, List[str]],
_headers: Optional[Dict[str, str]] = None) -> Dict:
"""
Handles prediction requests, performing inference or embeddings
generation based on request type.
Args:
payload: Contains 'texts' (List[str]) for inference or
embedding and 'type' to specify the operation.
headers: Optional headers for the request.
Returns:
dict: Contains the result 'outputs' or an error message.
"""
texts = payload.get('texts', [])
response_type = payload.get('type', 'unknown')
outputs = []
Expand All @@ -53,13 +99,22 @@ async def predict(self, payload: Dict[str, List[str]], headers: Optional[Dict[st
outputs = self.__generate_embeddings(texts)
else:
return {'error': f'Unknown request type: {response_type}'}
except Exception as e:
logging.error(f"Prediction failed: {str(e)}")
except Exception as e: # pylint: disable=broad-except
logging.exception(e)
return {'error': 'Prediction failed due to internal error.'}

return {'outputs': outputs}

def _perform_inference(self, texts: List[str]) -> List[str]:
"""
Performs text generation using the Llama model.
Args:
texts: List of input texts to generate responses for.
Returns:
List[str]: Generated responses for each input text.
"""
results = []
for text in texts:
input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
Expand All @@ -73,20 +128,49 @@ def _perform_inference(self, texts: List[str]) -> List[str]:
return results

def __generate_embeddings(self, texts: List[str]) -> List[List[float]]:
"""
Generates embeddings for input texts using the SentenceTransformer
model.
Args:
texts: List of input texts to generate embeddings for.
Returns:
List[List[float]]: Generated embeddings for each input text.
"""
return [self.embeddings_model.encode(text).tolist() for text in texts]

def get_model_config(path):
"""
Reads and returns the model configuration from a YAML file.
Args:
path: Path to the YAML configuration file.
Returns:
dict: Loaded configuration data.
"""
config = None
with open(path) as stream:
with open(path, encoding="utf-8") as stream:
config = yaml.safe_load(stream)
return config

def parse_args():
"""
Parses command-line arguments.
Returns:
argparse.Namespace: Parsed arguments with model config path.
"""
parser = argparse.ArgumentParser(description='remote-llm')
parser.add_argument('--model_config', type=str, default='config.yaml', help='Config path')
return parser.parse_args()

def main():
"""
Initializes logging, loads model configuration, and starts the model
server.
"""
logging.basicConfig(level=logging.INFO)
args = parse_args()
config = get_model_config(args.model_config)
Expand Down
17 changes: 17 additions & 0 deletions src/support_ai/ai_bot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Support-AI Command Line Tool
"""

import argparse
import uuid

Expand All @@ -6,6 +10,12 @@


def parse_args():
"""
Parses command-line arguments for the support-ai command line tool.
Returns:
argparse.Namespace: Parsed arguments, including the config file path.
"""
parser = argparse.ArgumentParser(
description='Command line tool for support-ai')
parser.add_argument('--config', type=str, default=None,
Expand All @@ -14,6 +24,13 @@ def parse_args():


def main():
"""
Main function to execute the support-ai tool. Initializes configuration,
session ID, and input loop for querying the support-ai chain.
Prompts the user for input queries and processes each input through the
chain. Exits if 'exit', 'quit', 'q', 'e', or 'x' is entered.
"""
args = parse_args()
config = get_config(args.config)
chain = Chain(config)
Expand Down
72 changes: 64 additions & 8 deletions src/support_ai/api_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Support-AI API Server
"""

import argparse

from flask import Blueprint, Flask, jsonify, request, Response
Expand All @@ -9,13 +13,25 @@


app = Flask(__name__)
chain = None
chain = None # pylint: disable=invalid-name
api_blueprint = Blueprint('api', __name__)
api = Api(api_blueprint)


class AI(Resource):
def get(self):
class AI(Resource): # pylint: disable=too-few-public-methods
"""
Resource class for the AI endpoint.
"""

def get(self): # pylint: disable=no-self-use
"""
Handles GET requests to the /api/ai endpoint. Queries the AI model
with the provided query and optional datasource and session arguments.
Returns:
Response: A text/plain response with the model's response or
a JSON error message if the query parameter is missing.
"""
query = request.args.get('query')
datasource = request.args.get('datasource')
session = request.args.get('session')
Expand All @@ -30,8 +46,24 @@ def get(self):
return {'message': 'Service unavailable'}, 400


class Salesforce(Resource):
def get(self, case_number):
class Salesforce(Resource): # pylint: disable=too-few-public-methods
"""
Resource class for the Salesforce endpoint.
"""

def get(self, case_number): # pylint: disable=no-self-use
"""
Handles GET requests to the /api/salesforce/<case_number>/summary
endpoint.
Args:
case_number: The case number for which to retrieve a summary.
Returns:
Response: A text/plain response with the case summary or a JSON
error message if the case number is missing or service is
unavailable.
"""
if case_number is None:
return {'message': 'Case number not specified'}, 400

Expand All @@ -44,8 +76,20 @@ def get(self, case_number):
return {'message': 'Service unavailable'}, 400


class History(Resource):
def delete(self):
class History(Resource): # pylint: disable=too-few-public-methods
"""
Resource class for the History endpoint.
"""

def delete(self): # pylint: disable=no-self-use
"""
Handles DELETE requests to the /api/history endpoint. Clears the
history for the specified session.
Returns:
Response: A JSON response with success status, or an error
message if the session parameter is missing.
"""
session = request.args.get('session')

if session is None:
Expand All @@ -55,14 +99,26 @@ def delete(self):


def parse_args():
"""
Parses command-line arguments for the support-ai API server.
Returns:
argparse.Namespace: Parsed arguments, including the config file path.
"""
parser = argparse.ArgumentParser(
description='Command line tool for support-ai')
parser.add_argument('--config', type=str, default=None, help='Config path')
return parser.parse_args()


def main():
global chain
"""
Main function to initialize the support-ai server. Loads configuration,
initializes the Chain, and sets up API routes.
The server listens on all interfaces at port 8080.
"""
global chain # pylint: disable=global-statement
args = parse_args()
config = get_config(args.config)
chain = Chain(config)
Expand Down
17 changes: 17 additions & 0 deletions src/support_ai/ds_updater.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Support-AI Data Source Updater
"""

import argparse
import signal
import sys
Expand All @@ -8,13 +12,26 @@


def parse_args():
"""
Parses command-line arguments for the support-ai data source updater.
Returns:
argparse.Namespace: Parsed arguments, including the config file path.
"""
parser = argparse.ArgumentParser(
description='Command line tool for support-ai')
parser.add_argument('--config', type=str, default=None, help='Config path')
return parser.parse_args()


def main():
"""
Main function to initialize and start the data source updater.
Loads the configuration, initializes the DSUpdater, and sets up signal
handling for graceful termination. Runs the update thread continuously
until a termination signal is received.
"""
args = parse_args()
config = get_config(args.config)
ds_updater = DSUpdater(config)
Expand Down
Loading

0 comments on commit 53b6d63

Please sign in to comment.