Skip to content

Commit 58953b1

Browse files
committed
feat: add paper trading api endpoints
1 parent b1a3b80 commit 58953b1

File tree

4 files changed

+70
-45
lines changed

4 files changed

+70
-45
lines changed

backend/api/paper_trade.py

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,44 @@
11
from fastapi import APIRouter, HTTPException
2-
from pydantic import BaseModel
3-
from backend.services.inference import cache # Import cache to get resources
4-
import pandas as pd
5-
import torch
6-
import logging
7-
8-
router = APIRouter()
9-
10-
class BacktestRequest(BaseModel):
11-
ticker: str
12-
start_date: str
13-
end_date: str
14-
15-
@router.post("/backtest")
16-
def run_backtest(request: BacktestRequest):
17-
# This logic can also be moved to a 'backtester.py' service later
18-
# For now, we'll keep it here for simplicity.
19-
model = cache.get("model"); data_df = cache.get("data"); x_scaler = cache.get("x_scaler"); y_scaler = cache.get("y_scaler"); device = cache.get("device")
20-
if not all([model, data_df is not None, x_scaler, y_scaler]):
21-
raise HTTPException(status_code=500, detail="Server resources not loaded.")
22-
# ... (Paste the full backtesting for loop and logic from your old main.py here) ...
23-
try:
24-
ticker_data = data_df[data_df['Ticker'] == request.ticker.upper()]
25-
backtest_period = ticker_data[(ticker_data['Date'] >= pd.to_datetime(request.start_date)) & (ticker_data['Date'] <= pd.to_datetime(request.end_date))]
26-
if len(backtest_period) < 61: raise HTTPException(status_code=400, detail="Date range too short. Need at least 61 days for backtesting.")
27-
portfolio_value = 10000.0; equity_curve = []; feature_cols = x_scaler.feature_names_in_
28-
for i in range(60, len(backtest_period)):
29-
sequence_df = backtest_period.iloc[i-60:i]
30-
current_price = sequence_df.iloc[-1]['Close']
31-
sequence_scaled = x_scaler.transform(sequence_df[feature_cols])
32-
input_tensor = torch.from_numpy(sequence_scaled).float().unsqueeze(0).to(device)
33-
with torch.no_grad():
34-
prediction_scaled = model(input_tensor)
35-
predicted_price = y_scaler.inverse_transform(prediction_scaled.cpu().numpy())[0][0]
36-
actual_next_price = backtest_period.iloc[i]['Close']
37-
if predicted_price > current_price:
38-
daily_return = (actual_next_price - current_price) / current_price
39-
portfolio_value *= (1 + daily_return)
40-
equity_curve.append({'Date': backtest_period.iloc[i]['Date'].strftime('%Y-%m-%d'), 'Portfolio Value': portfolio_value})
41-
total_return = (portfolio_value / 10000.0 - 1) * 100
42-
return {"total_return_pct": total_return, "final_portfolio_value": portfolio_value, "equity_curve": equity_curve}
43-
except Exception as e:
44-
logging.error(f"Error during backtest: {e}")
45-
raise HTTPException(status_code=500, detail=str(e))
2+
3+
from backend.schemas.paper_trade import PaperTradeRequest
4+
from backend.schemas.trading import TradeSignal
5+
from backend.services.execution import PaperExecutionEngine
6+
from backend.services.portfolio import Portfolio
7+
8+
router = APIRouter(prefix="/paper", tags=["paper-trading"])
9+
10+
# In-memory portfolio (MVP)
11+
portfolio = Portfolio(initial_cash=1_000_000)
12+
13+
execution_engine = PaperExecutionEngine()
14+
15+
16+
@router.post("/trade")
17+
def paper_trade(req: PaperTradeRequest):
18+
signal = TradeSignal(
19+
symbol=req.symbol,
20+
p_up=req.p_up,
21+
expected_return=req.expected_return,
22+
uncertainty=req.uncertainty,
23+
timestamp=req.timestamp,
24+
)
25+
26+
trade = execution_engine.generate_trade(
27+
signal=signal,
28+
current_price=req.price,
29+
capital=portfolio.cash,
30+
)
31+
32+
if trade:
33+
portfolio.apply_trade(trade)
34+
return {
35+
"status": "executed",
36+
"trade": trade.dict(),
37+
"cash": portfolio.cash,
38+
"positions": portfolio.positions,
39+
}
40+
41+
return {
42+
"status": "no_trade",
43+
"reason": "signal did not meet execution criteria",
44+
}

backend/api/portfolio.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from fastapi import APIRouter
2+
from backend.api.paper_trade import portfolio
3+
4+
router = APIRouter(prefix="/portfolio", tags=["portfolio"])
5+
6+
7+
@router.get("")
8+
def get_portfolio():
9+
return {
10+
"cash": portfolio.cash,
11+
"positions": portfolio.positions,
12+
"trades": [t.dict() for t in portfolio.trade_log],
13+
}

backend/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from fastapi import FastAPI
2-
from backend.api import predict
2+
from backend.api import predict, paper_trade, portfolio
33

44
app = FastAPI(title="ArthaQuant API")
55

66
app.include_router(predict.router)
7+
app.include_router(paper_trade.router)
8+
app.include_router(portfolio.router)

backend/schemas/paper_trade.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from pydantic import BaseModel
2+
from datetime import datetime
3+
4+
5+
class PaperTradeRequest(BaseModel):
6+
symbol: str
7+
p_up: float
8+
expected_return: float
9+
uncertainty: float
10+
price: float
11+
timestamp: datetime

0 commit comments

Comments
 (0)