-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_rest_api.py
More file actions
211 lines (177 loc) · 6.48 KB
/
predict_rest_api.py
File metadata and controls
211 lines (177 loc) · 6.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import asyncio
import base64
import logging
import os
from contextlib import asynccontextmanager
import tensorflow as tf
import uvicorn
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from src.exception import CustomException
from src.models.image_payload import ImagePayload
from src.pipeline.predict_pipeline import PredictPipeline
from src.services.ImagePreprocessingService import ImagePreprocessingService
logger = logging.getLogger(__name__)


def get_predict_pipeline(app: FastAPI):
    """
    Return the prediction pipeline stored on FastAPI's application state.

    Used as a FastAPI dependency by the ``/sift_images/`` route.

    Parameters:
    - app (FastAPI): Application whose ``state`` holds the pipeline assigned
      by the ``lifespan`` startup hook.

    Returns:
    - The stored pipeline, or None if ``app.state.predict_pipeline`` was never
      assigned or holds a falsy value.
    """
    # Starlette's State raises AttributeError for attributes that were never
    # assigned, so read with a default rather than plain attribute access.
    pipeline = getattr(app.state, "predict_pipeline", None)
    return pipeline if pipeline else None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager handling application startup and shutdown.

    On startup, builds ``PredictPipeline("densenet", "binary")`` and stores it
    on ``app.state.predict_pipeline`` so request handlers can retrieve it.
    On shutdown, drops that reference again.

    Parameters:
    - app (FastAPI): The application whose state receives the pipeline.

    Yields:
    - None. Control returns to FastAPI while the application serves requests.
    """
    # Startup logic: if constructing the pipeline raises, startup fails and
    # the exception propagates to the server.
    logger.info("Starting application...")
    app.state.predict_pipeline = PredictPipeline("densenet", "binary")
    logger.info("Application started successfully.")
    try:
        # Yield control to the application
        yield
    finally:
        # Shutdown logic runs even if an exception is thrown in at the yield.
        logger.info("Shutting down resources...")
        # Release the model reference so it can be garbage-collected.
        app.state.predict_pipeline = None
class Host:
    """
    Build and run the Frostfire Chart Sifter FastAPI application.

    Creates the FastAPI app with the ``lifespan`` startup/shutdown hook and a
    CORS middleware, registers the HTTP routes, and serves the app with
    uvicorn.
    """

    def __init__(self, args: object = None):
        """
        Initialize the Host class for Frostfire Stock Analysis AI Hub.

        Parameters:
        - args: Currently unused by this class. Defaults to None.
        """
        self.logger = logging.getLogger(__name__)
        # Configuration (overridable via the HOST / PORT environment variables)
        self.host = os.getenv("HOST", "0.0.0.0")
        self.port = int(os.getenv("PORT", 8000))
        # Initialize FastAPI
        self.app = FastAPI(
            title="Frostfire Chart Sifter API", version="1.0", lifespan=lifespan
        )
        # Enable CORS.
        # NOTE(review): wildcard origins combined with allow_credentials=True
        # makes Starlette reflect the caller's Origin instead of sending "*",
        # which admits credentialed requests from any site. Confirm this is
        # intended, or restrict allow_origins.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
            allow_headers=["*"],  # Allow all headers
        )
        self.setup_routes()

    def setup_routes(self):
        """
        Register the API routes on ``self.app``.
        """

        @self.app.options("/sift_images/")
        async def options_handler():
            """
            Answer OPTIONS requests for ``/sift_images/`` with an empty body.
            """
            # Real CORS preflights (OPTIONS + Access-Control-Request-Method)
            # are answered by CORSMiddleware before reaching this route, so
            # this only serves plain OPTIONS calls.
            return JSONResponse(content={}, status_code=200)

        @self.app.post("/sift_images/")
        async def detect_charts(
            request: Request,
            predict_pipeline: PredictPipeline = Depends(
                lambda: get_predict_pipeline(self.app)
            ),
        ):
            """
            Classify Base64-encoded images as charts or non-charts.

            Parameters:
            - request (Request): Raw request. Its JSON body is validated by
              ``ImagePayload``, which exposes ``base64_images``.
            - predict_pipeline: Pipeline injected from ``app.state`` through
              ``get_predict_pipeline``. It is None if the pipeline is not loaded.

            Returns:
            - dict: ``{"code", "code_text", "message", "data"}``.
              On success, ``code`` is 0 and ``data`` is a list of
              ``{"index": i, "is_chart": label}``, one entry per prediction.
              On failure, ``data`` is None and ``code`` is 400 (invalid
              payload) or 500 (model not loaded). Errors are reported in the
              body; the HTTP status stays 200.
            """
            try:
                # Parse and validate the payload using ImagePayload.
                # json.JSONDecodeError subclasses ValueError, so malformed JSON
                # is handled here too.
                body = await request.json()
                payload = ImagePayload(body)
                base64_images = payload.base64_images
            except ValueError as e:
                self.logger.error("Invalid request payload: %s", e)
                return {
                    "code": 400,
                    "code_text": "error",
                    "message": str(e),
                    "data": None,
                }

            # Without this guard a missing model surfaced as an AttributeError
            # on `None.predict` and an unstructured server error.
            if predict_pipeline is None:
                message = "Chart detection model is not initialized."
                self.logger.error("Cannot sift images: %s", message)
                return {
                    "code": 500,
                    "code_text": "error",
                    "message": message,
                    "data": None,
                }

            preprocessing = ImagePreprocessingService()
            features = preprocessing.parse_and_preprocess_images(base64_images)
            # Run inference on the preprocessed features.
            # NOTE(review): predict() runs synchronously inside this async
            # handler and blocks the event loop for the whole inference; consider
            # offloading it to a worker thread if the pipeline is thread-safe.
            predictions = predict_pipeline.predict(features)
            # Each prediction is read as a mapping with a "label" key.
            results = [
                {"index": idx, "is_chart": pred["label"]}
                for idx, pred in enumerate(predictions)
            ]
            # Structured response
            return {
                "code": 0,
                "code_text": "ok",
                "message": "Processed successfully.",
                "data": results,
            }

        @self.app.get("/health")
        async def health_check():
            """
            Report whether the chart-detection pipeline is loaded.

            Returns:
            - dict: code 0 with ``{"sift_images": "loaded"}`` when
              ``app.state.predict_pipeline`` is set and truthy. Otherwise code
              500 with ``{"sift_images": "not loaded"}``. The HTTP status is
              200 in both cases.
            """
            try:
                # Check if the model is loaded
                if not getattr(self.app.state, "predict_pipeline", None):
                    raise ValueError("Chart detection model is not initialized.")
                # Structured response
                return {
                    "code": 0,
                    "code_text": "ok",
                    "message": "All services are running.",
                    "data": {"sift_images": "loaded"},
                }
            except Exception as e:
                self.logger.error("Health check failed: %s", e)
                return {
                    "code": 500,
                    "code_text": "error",
                    "message": str(e),
                    "data": {"sift_images": "not loaded"},
                }

    def run(self):
        """
        Start the FastAPI server and block until it shuts down.

        ``asyncio.run`` drives ``start_fastapi`` to completion on a fresh event
        loop and returns its result (None), not a task. There is therefore
        nothing left to cancel once it returns.
        """
        self.logger.info("Starting host process.")
        try:
            asyncio.run(self.start_fastapi())
        except (asyncio.CancelledError, KeyboardInterrupt):
            # Cancellation or Ctrl+C (which may surface as KeyboardInterrupt
            # after the server has shut down) is a normal stop, not a crash.
            self.logger.info("Stopping host process.")

    async def start_fastapi(self):
        """
        Serve ``self.app`` with uvicorn on ``self.host``:``self.port``.

        Returns once the uvicorn server has stopped serving.
        """
        config = uvicorn.Config(
            self.app, host=self.host, port=self.port, log_level="info"
        )
        server = uvicorn.Server(config)
        await server.serve()
if __name__ == "__main__":
    # The root logger defaults to WARNING; without this, the module's
    # INFO-level startup/shutdown messages are silently dropped.
    logging.basicConfig(level=logging.INFO)
    try:
        # No command-line arguments are parsed yet; Host ignores `args`.
        args = None
        instance = Host(args)
        # Blocks until the server shuts down.
        instance.run()
    except CustomException as e:
        logging.exception("Critical error: %s. Application cannot start.", e)
        # Exit non-zero so supervisors and CI notice the failed start.
        raise SystemExit(1) from e