Skip to content

Commit b0120dd

Browse files
committed
update app.py file
Signed-off-by: lyndanajjar <lyndanajjar15@gmail.com>
1 parent 842f462 commit b0120dd

1 file changed

Lines changed: 19 additions & 14 deletions

File tree

src/mpt-7B-inference/app.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,36 @@
11
# app.py
22
from flask import Flask, request, jsonify
3-
from inference import generate, load_model_from_hub
4-
5-
3+
from inference import generate, load_model_and_tokenizer ,format_prompt
64

75
app = Flask(__name__)
86

9-
# Specify the folder we want to save the downloaded model
10-
destination_folder = "models"
11-
12-
# Load the mpt-7b-chat model from the Hugging Face Model Hub
7+
# Load the mpt-7b-chat model and tokenizer from the Hugging Face Model Hub
138
model_name = "mosaicml/mpt-7b-chat"
14-
llm = load_model_from_hub(model_name)
9+
llm, tokenizer = load_model_and_tokenizer(model_name, trust_remote_code=True) # Update model loading
1510

16-
17-
# Use the model path in the generate function
11+
# Use the model and tokenizer in the generate function
1812
@app.route('/predict', methods=['POST'])
1913
def predict():
2014
data = request.json
2115
user_prompt = data.get('user_prompt')
2216

23-
# Update the call to use the loaded model directly
24-
assistant_response = generate(llm, user_prompt)
17+
# Update the call to use the loaded model and tokenizer directly
18+
generation_config = {
19+
"temperature": 0.2,
20+
"top_k": 0,
21+
"top_p": 0.9,
22+
"repetition_penalty": 1.0,
23+
"max_new_tokens": 512,
24+
}
25+
26+
# Format the prompt using the system prompt
27+
system_prompt = "A conversation between a user and an LLM-based AI assistant named Local Assistant. Local Assistant gives helpful and honest answers."
28+
prompt = format_prompt(system_prompt, user_prompt)
29+
30+
# Generate the assistant's response
31+
assistant_response = generate(llm, tokenizer, generation_config, prompt)
2532

2633
return jsonify({'assistant_response': assistant_response})
2734

2835
if __name__ == '__main__':
2936
app.run(host='0.0.0.0', port=5000)
30-
31-

0 commit comments

Comments
 (0)