Skip to content

Commit 1518282

Browse files
authored
feat(src): initialize a framework (#6)
* update * update
1 parent 00d8036 commit 1518282

File tree

15 files changed

+1698
-20
lines changed

15 files changed

+1698
-20
lines changed

data/yfinance_data/price_data/AAPL_data_formatted.json

Lines changed: 716 additions & 0 deletions
Large diffs are not rendered by default.

examples/demo.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import argparse
2+
3+
from trading_bench.bench import SimBench
4+
from trading_bench.model_wrapper import RuleBasedModel
5+
from trading_bench.utils import setup_logging
6+
7+
8+
def main(argv=None) -> None:
    """Run the simulated trading bench from the command line.

    Parses the CLI options, builds a ``SimBench`` around a ``RuleBasedModel``
    and prints every entry of the summary returned by ``bench.run()``
    (floats are printed with four decimals).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument ``main()``.
    """
    setup_logging()

    parser = argparse.ArgumentParser(
        description='Run the simulated trading bench on yfinance price data'
    )
    parser.add_argument('--ticker', required=True, help='Stock ticker symbol')
    parser.add_argument('--start_date', required=True, help='Start date (YYYY-MM-DD)')
    parser.add_argument('--end_date', required=True, help='End date (YYYY-MM-DD)')
    parser.add_argument(
        '--data_dir', required=True, help='Root directory where data is stored'
    )
    parser.add_argument(
        '--eval_delay', type=int, default=5, help='Evaluation delay in data points'
    )
    parser.add_argument(
        '--resolution',
        choices=['1', '5', '15', '30', '60', 'D', 'W', 'M'],
        default='D',
        help="Data resolution code (e.g., '1','5','15','30','60','D','W','M')",
    )

    args = parser.parse_args(argv)

    # The bench schedules each evaluation at `idx + eval_delay`; a negative
    # delay would point at a bar that has already been processed.
    if args.eval_delay < 0:
        parser.error('--eval_delay must be a non-negative integer')

    model = RuleBasedModel()
    bench = SimBench(
        ticker=args.ticker,
        start_date=args.start_date,
        end_date=args.end_date,
        data_dir=args.data_dir,
        model=model,
        eval_delay=args.eval_delay,
        resolution=args.resolution,
    )
    summary = bench.run()

    print('Performance Summary:')
    for key, value in summary.items():
        if isinstance(value, float):
            print(f'{key}: {value:.4f}')
        else:
            print(f'{key}: {value}')
49+
50+
51+
if __name__ == '__main__':
52+
main()

poetry.lock

Lines changed: 615 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
[tool.poetry]
2-
name = "python-project-template"
2+
name = "trading-bench"
33
version = "0.0.1"
44
description = "A template for python-based research project"
55
authors = ["Haofei Yu <yuhaofei44@gmail.com>"]
66
license = "Apache 2.0 License"
77
readme = "README.md"
8-
packages = [
9-
{ include = "src" }
10-
]
118

129
[tool.poetry.dependencies]
1310
python = ">=3.9, <3.12"
1411
mypy = "^1.8.0"
1512
beartype = "^0.17.1"
1613
pydantic = "^2.8.2"
14+
finnhub-python = "*"
15+
yfinance = "*"
1716

1817
[tool.poetry.group.dev.dependencies]
1918
pre-commit = "^3.6.0"

scripts/demo.sh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/usr/bin/env bash
# Runs examples/demo.py with a fixed demo configuration.

# Abort on the first failing command, on unset variables and on pipe failures.
set -euo pipefail

# Resolve paths relative to this script so it can be launched from any
# working directory, not only from inside scripts/.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# === Configuration ===
DATA_DIR="${SCRIPT_DIR}/../data"   # e.g. ~/projects/data
TICKER="AAPL"
START_DATE="2025-01-01"
END_DATE="2025-06-01"
EVAL_DELAY=5
RESOLUTION="D"                     # resolution code: 1,5,15,30,60,D,W,M

# === Run the bench ===
python "${SCRIPT_DIR}/../examples/demo.py" \
  --ticker "${TICKER}" \
  --start_date "${START_DATE}" \
  --end_date "${END_DATE}" \
  --data_dir "${DATA_DIR}" \
  --eval_delay "${EVAL_DELAY}" \
  --resolution "${RESOLUTION}"
File renamed without changes.

trading_bench/bench.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import json
2+
import os
3+
from collections import defaultdict, deque
4+
from datetime import datetime
5+
6+
from .data_fetcher import fetch_price_data
7+
from .evaluator import ReturnEvaluator
8+
from .metrics import MetricsLogger
9+
from .model_wrapper import BaseModel
10+
from .signal import Signal
11+
12+
13+
class SimBench:
    """
    Simulated backtest bench that asks a model for buy signals and evaluates returns.

    On construction the price series for ``ticker`` is downloaded through
    ``fetch_price_data`` (yfinance based), re-read from the JSON file written
    under ``<data_dir>/yfinance_data/price_data`` and kept as a chronologically
    sorted list of ``(datetime, close_price)`` pairs. ``run`` then replays that
    series one data point at a time.
    """

    def __init__(
        self,
        ticker: str,
        start_date: str,
        end_date: str,
        data_dir: str,
        model: BaseModel,
        eval_delay: int = 5,
        resolution: str = 'D',
    ):
        """Fetch the price data and prepare the bench state.

        Args:
            ticker: Stock ticker symbol.
            start_date: Start date, passed through to ``fetch_price_data``.
            end_date: End date, passed through to ``fetch_price_data``.
            data_dir: Root data directory.
            model: Object whose ``should_buy(prices)`` is queried on every step.
            eval_delay: Number of data points between a buy signal and its
                evaluation. Must be non-negative.
            resolution: Resolution code forwarded to ``fetch_price_data``.

        Raises:
            ValueError: If ``eval_delay`` is negative.
            FileNotFoundError: If the expected JSON file is missing after the fetch.
        """
        # A negative delay would schedule evaluations at indices that have
        # already been processed, silently dropping every signal.
        if eval_delay < 0:
            raise ValueError(f'eval_delay must be non-negative, got {eval_delay}')

        self.ticker = ticker
        self.start_date = start_date
        self.end_date = end_date
        self.data_dir = data_dir
        self.model = model
        self.eval_delay = eval_delay
        self.resolution = resolution

        # Fetch and save price data (yfinance) into yfinance_data/price_data
        fetch_price_data(
            ticker=self.ticker,
            start_date=self.start_date,
            end_date=self.end_date,
            data_dir=self.data_dir,
            resolution=self.resolution,
        )

        # Load fetched JSON data from yfinance_data
        data_path = os.path.join(
            self.data_dir,
            'yfinance_data',
            'price_data',
            f'{self.ticker}_data_formatted.json',
        )
        if not os.path.isfile(data_path):
            raise FileNotFoundError(f'Expected data file not found at {data_path}')

        # Chronologically sorted (datetime, close_price) pairs.
        self.data: list[tuple[datetime, float]] = self._load_close_series(data_path)
        self.history: deque = deque(self.data)

        self.evaluator = ReturnEvaluator()
        self.logger = MetricsLogger()

    @staticmethod
    def _load_close_series(data_path: str) -> list[tuple[datetime, float]]:
        """Read the price JSON at ``data_path`` and return sorted (datetime, close) pairs."""
        with open(data_path, encoding='utf-8') as f:
            raw_data = json.load(f)

        parsed: list[tuple[datetime, float]] = []
        for date_str, v in raw_data.items():
            # v is normally a dict with keys open, high, low, close, volume;
            # a bare number is accepted as the close price itself.
            close = v['close'] if isinstance(v, dict) else v
            parsed.append((datetime.fromisoformat(date_str), float(close)))

        parsed.sort(key=lambda item: item[0])
        return parsed

    def run(self) -> dict[str, float]:
        """Replay the price series, collect buy signals and evaluate them.

        For every data point the model is shown all closing prices up to and
        including that point. A buy signal is scheduled for evaluation
        ``eval_delay`` points later, clamped to the last available point. When
        a signal falls due, the evaluator receives it together with the
        ``(datetime, price)`` pairs from the buy point through the evaluation
        point, and the result is recorded by the metrics logger.

        Returns:
            The value of ``self.logger.summary()`` after the replay.
        """
        n = len(self.data)
        prices = [price for _, price in self.data]

        # 1. start with an empty "past history" and a place to stash pending
        #    signals, keyed by the index at which they fall due
        self.history = deque()
        pending: dict[int, list[tuple[int, Signal]]] = defaultdict(list)

        for idx, (date, price) in enumerate(self.data):
            # 2. append this step into the history
            self.history.append((date, price))

            # 3. if model says BUY, schedule evaluation at idx + eval_delay
            if self.model.should_buy(prices[: idx + 1]):
                eval_idx = min(idx + self.eval_delay, n - 1)
                eval_time = self.data[eval_idx][0]
                pending[eval_idx].append((idx, Signal(date, price, eval_time)))

            # 4. evaluate every signal that is due at this idx
            for buy_idx, signal in pending.pop(idx, ()):
                # Slice from the buy point to now. For signals clamped to the
                # final data point this is shorter than eval_delay + 1 and,
                # unlike a fixed-length tail, never starts before the buy.
                window = self.data[buy_idx : idx + 1]
                ret = self.evaluator.evaluate(signal, window)
                self.logger.record(ret)

        return self.logger.summary()

trading_bench/config.py

Whitespace-only changes.

trading_bench/data_fetcher.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import json
2+
import os
3+
import time
4+
5+
import yfinance as yf
6+
7+
8+
def fetch_price_data(
    ticker: str, start_date: str, end_date: str, data_dir: str, resolution: str = 'D'
) -> None:
    """
    Fetches historical OHLCV price data for a ticker via yfinance and saves it as formatted JSON.

    The output is written to
    ``<data_dir>/yfinance_data/price_data/<ticker>_data_formatted.json`` as a
    mapping from timestamp string to a dict with the keys ``open``, ``high``,
    ``low``, ``close`` and ``volume``. Daily, weekly and monthly bars are keyed
    by ``YYYY-MM-DD``; intraday bars are keyed by ``YYYY-MM-DDTHH:MM:SS`` so
    that several bars of the same day do not overwrite each other.

    Args:
        ticker: Stock ticker symbol.
        start_date: YYYY-MM-DD
        end_date: YYYY-MM-DD
        data_dir: Root data directory where 'yfinance_data/price_data' will live.
        resolution: '1', '5', '15', '30', '60', 'D', 'W', 'M'. Unknown codes
            fall back to daily bars.

    Raises:
        RuntimeError: If yfinance returns no rows for the request.
    """
    # map your resolution codes to yfinance intervals
    interval_map = {
        '1': '1m',
        '5': '5m',
        '15': '15m',
        '30': '30m',
        '60': '60m',
        'D': '1d',
        'W': '1wk',
        'M': '1mo',
    }
    interval = interval_map.get(resolution.upper(), '1d')

    # download data
    df = yf.download(
        tickers=ticker,
        start=start_date,
        end=end_date,
        interval=interval,
        progress=False,
    )

    if df.empty:
        raise RuntimeError(
            f'No data returned for {ticker} {start_date} to {end_date} @ {interval}'
        )

    # Newer yfinance releases can return (field, ticker) MultiIndex columns even
    # for a single ticker; keep only the field level so row['Open'] is a scalar.
    if df.columns.nlevels > 1:
        df.columns = df.columns.get_level_values(0)

    # Minute-based intervals end in 'm' ('1mo' ends in 'o'); they need the time
    # of day in the key, otherwise every bar of a day collapses into one entry.
    key_format = '%Y-%m-%dT%H:%M:%S' if interval.endswith('m') else '%Y-%m-%d'

    # Build timestamp-indexed dict
    data = {}
    for idx, row in df.iterrows():
        # idx is a pandas.Timestamp
        data[idx.strftime(key_format)] = {
            'open': float(row['Open']),
            'high': float(row['High']),
            'low': float(row['Low']),
            'close': float(row['Close']),
            'volume': int(row['Volume']),
        }

    # Ensure target directory exists
    out_dir = os.path.join(data_dir, 'yfinance_data', 'price_data')
    os.makedirs(out_dir, exist_ok=True)

    out_path = os.path.join(out_dir, f'{ticker}_data_formatted.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)

    # be polite with any rate limits
    time.sleep(1)

0 commit comments

Comments
 (0)