|
1 | 1 | # app.py |
2 | 2 | from flask import Flask, request, jsonify |
3 | | -from inference import generate, load_model_from_hub |
4 | | - |
5 | | - |
| 3 | +from inference import generate, load_model_and_tokenizer ,format_prompt |
6 | 4 |
|
7 | 5 | app = Flask(__name__) |
8 | 6 |
|
9 | | -# Specify the folder we want to save the downloaded model |
10 | | -destination_folder = "models" |
11 | | - |
12 | | -# Load the mpt-7b-chat model from the Hugging Face Model Hub |
| 7 | +# Load the mpt-7b-chat model and tokenizer from the Hugging Face Model Hub |
13 | 8 | model_name = "mosaicml/mpt-7b-chat" |
14 | | -llm = load_model_from_hub(model_name) |
| 9 | +llm, tokenizer = load_model_and_tokenizer(model_name, trust_remote_code=True) # Update model loading |
15 | 10 |
|
16 | | - |
17 | | -# Use the model path in the generate function |
| 11 | +# Use the model and tokenizer in the generate function |
18 | 12 | @app.route('/predict', methods=['POST']) |
19 | 13 | def predict(): |
20 | 14 | data = request.json |
21 | 15 | user_prompt = data.get('user_prompt') |
22 | 16 |
|
23 | | -# Update the call to use the loaded model directly |
24 | | - assistant_response = generate(llm, user_prompt) |
| 17 | + # Update the call to use the loaded model and tokenizer directly |
| 18 | + generation_config = { |
| 19 | + "temperature": 0.2, |
| 20 | + "top_k": 0, |
| 21 | + "top_p": 0.9, |
| 22 | + "repetition_penalty": 1.0, |
| 23 | + "max_new_tokens": 512, |
| 24 | + } |
| 25 | + |
| 26 | + # Format the prompt using the system prompt |
| 27 | + system_prompt = "A conversation between a user and an LLM-based AI assistant named Local Assistant. Local Assistant gives helpful and honest answers." |
| 28 | + prompt = format_prompt(system_prompt, user_prompt) |
| 29 | + |
| 30 | + # Generate the assistant's response |
| 31 | + assistant_response = generate(llm, tokenizer, generation_config, prompt) |
25 | 32 |
|
26 | 33 | return jsonify({'assistant_response': assistant_response}) |
27 | 34 |
|
28 | 35 | if __name__ == '__main__': |
29 | 36 | app.run(host='0.0.0.0', port=5000) |
30 | | - |
31 | | - |
|
0 commit comments