Skip to content

Commit 9bb7364

Browse files
committed
feat: add multimodal inference service for transformer + finbert
1 parent 7cd0fa2 commit 9bb7364

File tree

4 files changed

+81
-143
lines changed

4 files changed

+81
-143
lines changed

backend/api/predict.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,28 @@
1-
from fastapi import APIRouter, HTTPException
2-
from pydantic import BaseModel
3-
from backend.services import inference
4-
import logging
1+
import torch
2+
from fastapi import APIRouter
3+
4+
from backend.schemas.predict import PredictRequest, PredictResponse
5+
from backend.services.inference import InferenceService
56

67
router = APIRouter()
78

8-
class PredictionRequest(BaseModel):
9-
ticker: str
10-
date: str
9+
# Module-level singleton: the model is loaded once at import time and shared by every request
10+
inference_service = InferenceService(
11+
model_path="models/fusion_model.pt",
12+
input_dim=10,
13+
)
14+
15+
16+
@router.post("/predict", response_model=PredictResponse)
17+
def predict(req: PredictRequest):
18+
market_x = torch.tensor(req.market_sequence).unsqueeze(0)
19+
input_ids = torch.tensor(req.input_ids).unsqueeze(0)
20+
attention_mask = torch.tensor(req.attention_mask).unsqueeze(0)
1121

12-
class PredictionResponse(BaseModel):
13-
predicted_price: float
22+
out = inference_service.predict(
23+
market_x,
24+
input_ids,
25+
attention_mask,
26+
)
1427

15-
@router.post("/predict", response_model=PredictionResponse)
16-
def predict(request: PredictionRequest):
17-
try:
18-
prediction = inference.make_prediction(ticker=request.ticker, date_str=request.date)
19-
return {"predicted_price": prediction}
20-
except FileNotFoundError as e:
21-
raise HTTPException(status_code=404, detail=str(e))
22-
except ValueError as e:
23-
raise HTTPException(status_code=400, detail=str(e))
24-
except Exception as e:
25-
logging.error(f"Error during prediction: {e}")
26-
raise HTTPException(status_code=500, detail="An internal error occurred during prediction.")
28+
return PredictResponse(**out)

backend/main.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,6 @@
11
from fastapi import FastAPI
2-
from contextlib import asynccontextmanager
3-
from backend.services import inference
4-
from backend.routers import predict, backtest, drift, screener
5-
import logging
2+
from backend.api import predict
63

7-
logging.basicConfig(level=logging.INFO)
4+
app = FastAPI(title="ArthaQuant API")
85

9-
# --- Lifespan Events ---
10-
@asynccontextmanager
11-
async def lifespan(app: FastAPI):
12-
# Load all ML resources on startup
13-
inference.load_resources()
14-
yield
15-
# Clean up resources on shutdown
16-
inference.cache.clear()
17-
18-
app = FastAPI(
19-
title="Stock Market AI API (Refactored)",
20-
lifespan=lifespan
21-
)
22-
23-
# --- Include Routers ---
24-
# This brings in all the endpoints from your other files
256
app.include_router(predict.router)
26-
app.include_router(backtest.router)
27-
app.include_router(drift.router)
28-
app.include_router(screener.router)
29-
30-
# --- Other General Endpoints ---
31-
from backend.services.inference import cache
32-
from fastapi import HTTPException
33-
34-
@app.get("/tickers")
35-
def get_tickers():
36-
data_df = cache.get("data")
37-
if data_df is None: raise HTTPException(status_code=500, detail="Data not loaded.")
38-
return data_df['Ticker'].unique().tolist()
39-
40-
@app.get("/history/{ticker}")
41-
def get_history(ticker: str):
42-
data_df = cache.get("data")
43-
if data_df is None: raise HTTPException(status_code=500, detail="Data not loaded.")
44-
ticker_data = data_df[data_df['Ticker'] == ticker.upper()]
45-
if ticker_data.empty: raise HTTPException(status_code=404, detail=f"Ticker '{ticker}' not found.")
46-
return ticker_data[['Date', 'Close']].to_dict('records')
47-
48-
@app.get("/")
49-
def read_root():
50-
return {"message": "Welcome to the Stock Market AI API!"}

backend/schemas/predict.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from pydantic import BaseModel
2+
from typing import List
3+
4+
5+
class PredictRequest(BaseModel):
6+
symbol: str
7+
market_sequence: List[List[float]] # (T, F)
8+
input_ids: List[int]
9+
attention_mask: List[int]
10+
11+
12+
class PredictResponse(BaseModel):
13+
p_up: float
14+
expected_return: float
15+
uncertainty: float

backend/services/inference.py

Lines changed: 41 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,43 @@
11
import torch
2-
import numpy as np
3-
import pandas as pd
4-
from pathlib import Path
5-
import logging
6-
import joblib
7-
import sys
82

9-
# Add project root to path
10-
project_root = str(Path(__file__).resolve().parents[2])
11-
if project_root not in sys.path:
12-
sys.path.insert(0, project_root)
13-
14-
from pipelines.train_transformer_pipeline import TransformerModel
15-
16-
# This dictionary will act as a global cache for loaded resources
17-
cache = {}
18-
19-
def load_resources():
20-
"""Loads all necessary ML resources into the cache."""
21-
if "model" in cache:
22-
logging.info("Resources already loaded.")
23-
return
24-
25-
logging.info("Loading resources for Transformer model...")
26-
device = "cuda" if torch.cuda.is_available() else "cpu"
27-
cache["device"] = device
28-
29-
try:
30-
cache["x_scaler"] = joblib.load(Path("data/sequences_sentiment/x_scaler.joblib"))
31-
cache["y_scaler"] = joblib.load(Path("data/sequences_sentiment/y_scaler.joblib"))
32-
data_path = Path("data/processed/final_fused_data.csv")
33-
df = pd.read_csv(data_path, parse_dates=['Date'])
34-
cache["data"] = df
35-
36-
model_path = Path("models/transformer_v1.pt")
37-
input_size = cache["x_scaler"].n_features_in_
38-
39-
model = TransformerModel(input_size=input_size).to(device)
40-
model.load_state_dict(torch.load(model_path, map_location=device))
41-
model.eval()
42-
cache["model"] = model
43-
logging.info("Transformer model and all resources loaded successfully.")
44-
except Exception as e:
45-
logging.error(f"Failed to load resources on startup: {e}")
46-
47-
def make_prediction(ticker: str, date_str: str) -> float:
48-
"""Makes a single stock prediction for a given ticker and date."""
49-
if "model" not in cache:
50-
raise ValueError("Model and resources are not loaded.")
51-
52-
model = cache["model"]
53-
data_df = cache["data"]
54-
x_scaler = cache["x_scaler"]
55-
y_scaler = cache["y_scaler"]
56-
device = cache["device"]
57-
58-
sequence_length = 60
59-
end_date = pd.to_datetime(date_str)
60-
61-
ticker_data = data_df[data_df['Ticker'] == ticker.upper()]
62-
if ticker_data.empty:
63-
raise FileNotFoundError(f"Data for ticker '{ticker}' not found.")
64-
65-
data_up_to_date = ticker_data[ticker_data['Date'] <= end_date]
66-
if len(data_up_to_date) < sequence_length:
67-
raise ValueError(f"Not enough historical data for {ticker} before {date_str}.")
68-
69-
sequence_to_predict = data_up_to_date.tail(sequence_length)
70-
feature_cols = x_scaler.feature_names_in_
71-
sequence_scaled = x_scaler.transform(sequence_to_predict[feature_cols])
72-
73-
input_tensor = torch.from_numpy(sequence_scaled).float().unsqueeze(0).to(device)
74-
with torch.no_grad():
75-
prediction_scaled = model(input_tensor)
76-
77-
prediction_unscaled = y_scaler.inverse_transform(prediction_scaled.cpu().numpy())[0][0]
78-
return prediction_unscaled
3+
from ml.models.multimodal_model import MultimodalTradingModel
4+
5+
6+
class InferenceService:
7+
"""
8+
Stateless inference service.
9+
Loads the model once at construction, then serves predictions.
10+
"""
11+
12+
def __init__(self, model_path: str, input_dim: int):
13+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14+
15+
self.model = MultimodalTradingModel(input_dim=input_dim)
16+
self.model.load_state_dict(
17+
torch.load(model_path, map_location=self.device)
18+
)
19+
self.model.to(self.device)
20+
self.model.eval()
21+
22+
@torch.no_grad()
23+
def predict(
24+
self,
25+
market_x: torch.Tensor,
26+
input_ids: torch.Tensor,
27+
attention_mask: torch.Tensor,
28+
) -> dict:
29+
market_x = market_x.to(self.device)
30+
input_ids = input_ids.to(self.device)
31+
attention_mask = attention_mask.to(self.device)
32+
33+
output = self.model(
34+
market_x,
35+
input_ids,
36+
attention_mask,
37+
)
38+
39+
return {
40+
"p_up": float(output["p_up"].mean().cpu()),
41+
"expected_return": float(output["expected_return"].mean().cpu()),
42+
"uncertainty": float(output["uncertainty"].mean().cpu()),
43+
}

0 commit comments

Comments
 (0)