Skip to content

Commit 2575215

Browse files
committed
chore: add candlestick chart generation and sample data functions
1 parent 9e57a7c commit 2575215

File tree

1 file changed

+215
-0
lines changed

1 file changed

+215
-0
lines changed

Diff for: example/candle_stick/example_candle_stick_mpl.py

+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
from typing import Dict, List, Optional, Tuple
2+
3+
import matplotlib.dates as mdates
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import pandas as pd
7+
from mplfinance.original_flavor import candlestick_ohlc
8+
9+
10+
def generate_candlestick_chart(
11+
data_dict: Dict[str, List],
12+
output_file: str = "candle_stick.svg",
13+
width: float = 0.6,
14+
colorup: str = "g",
15+
colordown: str = "r",
16+
title: str = "Stock Price Candlestick Chart",
17+
show_volume: bool = False,
18+
) -> Tuple[plt.Figure, plt.Axes]: # type: ignore
19+
"""
20+
Generate and save a candlestick chart from OHLC (Open, High, Low, Close) data.
21+
22+
Parameters
23+
----------
24+
data_dict : Dict[str, List]
25+
Dictionary with keys 'Date', 'Open', 'High', 'Low', 'Close'
26+
containing the trading data. Optional key 'Volume' for volume data.
27+
output_file : str, optional
28+
Filename to save the chart, by default "candle_stick.svg"
29+
width : float, optional
30+
Width of the candlesticks, by default 0.6
31+
colorup : str, optional
32+
Color for upward price movements, by default "g"
33+
colordown : str, optional
34+
Color for downward price movements, by default "r"
35+
title : str, optional
36+
Title for the chart, by default "Stock Price Candlestick Chart"
37+
show_volume : bool, optional
38+
Whether to display volume data in a subplot, by default False
39+
40+
Returns
41+
-------
42+
Tuple[plt.Figure, plt.Axes]
43+
Figure and Axes objects of the created chart
44+
45+
Examples
46+
--------
47+
>>> data = {
48+
>>> "Date": ["2021-01-01", "2021-01-02"],
49+
>>> "Open": [100, 102],
50+
>>> "High": [105, 106],
51+
>>> "Low": [99, 101],
52+
>>> "Close": [104, 105]
53+
>>> }
54+
>>> fig, ax = generate_candlestick_chart(data)
55+
"""
56+
array_lengths = [len(arr) for arr in data_dict.values()]
57+
if len(set(array_lengths)) > 1:
58+
raise ValueError("All arrays in data_dict must be of the same length")
59+
60+
df = pd.DataFrame(data_dict)
61+
df["Date"] = pd.to_datetime(df["Date"])
62+
df["Date_num"] = df["Date"].apply(mdates.date2num) # type: ignore
63+
64+
ohlc = df[["Date_num", "Open", "High", "Low", "Close"]].values
65+
66+
if show_volume and "Volume" in data_dict:
67+
fig, (ax1, ax2) = plt.subplots(
68+
2, 1, figsize=(12, 8), gridspec_kw={"height_ratios": [3, 1]}, sharex=True
69+
)
70+
71+
candlestick_ohlc(
72+
ax1, ohlc, width=width, colorup=colorup, colordown=colordown, alpha=0.8
73+
)
74+
75+
ax1.xaxis_date()
76+
ax1.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
77+
78+
ax1.grid(True, linestyle="--", alpha=0.6)
79+
ax1.set_title(title, fontsize=14)
80+
ax1.set_ylabel("Price", fontsize=12)
81+
82+
ax2.bar(df["Date_num"], df["Volume"], width=width, color="blue", alpha=0.4)
83+
ax2.set_ylabel("Volume", fontsize=12)
84+
ax2.set_xlabel("Date", fontsize=12)
85+
ax2.grid(True, linestyle="--", alpha=0.4)
86+
87+
plt.xticks(rotation=45)
88+
ax = ax1
89+
else:
90+
fig, ax = plt.subplots(figsize=(12, 6))
91+
candlestick_ohlc(
92+
ax, ohlc, width=width, colorup=colorup, colordown=colordown, alpha=0.8
93+
)
94+
95+
ax.xaxis_date()
96+
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
97+
plt.xticks(rotation=45)
98+
99+
ax.grid(True, linestyle="--", alpha=0.6)
100+
ax.set_title(title, fontsize=14)
101+
ax.set_xlabel("Date", fontsize=12)
102+
ax.set_ylabel("Price", fontsize=12)
103+
104+
plt.tight_layout()
105+
plt.savefig(output_file)
106+
107+
return fig, ax
108+
109+
110+
def generate_sample_data(
111+
start_date: str = "2021-01-01",
112+
end_date: str = "2021-01-31",
113+
start_price: float = 100.0,
114+
volatility: float = 0.02,
115+
seed: Optional[int] = 42,
116+
) -> Dict[str, List]:
117+
"""
118+
Generate sample OHLCV data for stock price simulation.
119+
120+
Parameters
121+
----------
122+
start_date : str, optional
123+
Start date for the data in YYYY-MM-DD format, by default "2021-01-01"
124+
end_date : str, optional
125+
End date for the data in YYYY-MM-DD format, by default "2021-01-31"
126+
start_price : float, optional
127+
Starting price for the simulation, by default 100.0
128+
volatility : float, optional
129+
Daily price volatility as a decimal, by default 0.02 (2%)
130+
seed : Optional[int], optional
131+
Random seed for reproducibility, by default 42
132+
133+
Returns
134+
-------
135+
Dict[str, List]
136+
Dictionary with keys 'Date', 'Open', 'High', 'Low', 'Close', 'Volume'
137+
containing the generated trading data
138+
139+
Examples
140+
--------
141+
>>> data = generate_sample_data(start_date="2021-01-01", end_date="2021-01-10")
142+
>>> len(data["Date"]) # Number of trading days
143+
"""
144+
if seed is not None:
145+
np.random.seed(seed)
146+
147+
all_dates = pd.date_range(start=start_date, end=end_date)
148+
trading_dates = all_dates[all_dates.dayofweek < 5] # 0-4 are Monday to Friday
149+
dates = [d.strftime("%Y-%m-%d") for d in trading_dates]
150+
151+
opens = []
152+
closes = []
153+
highs = []
154+
lows = []
155+
volumes = []
156+
157+
current_price = start_price
158+
for i in range(len(dates)):
159+
if i == 0:
160+
opens.append(current_price)
161+
else:
162+
opens.append(closes[i - 1])
163+
164+
price_change = np.random.normal(0, volatility * opens[i])
165+
166+
if opens[i] > start_price * 1.1:
167+
price_change -= volatility * opens[i] * 0.05
168+
elif opens[i] < start_price * 0.9:
169+
price_change += volatility * opens[i] * 0.05
170+
171+
close = opens[i] + price_change
172+
closes.append(round(close, 2))
173+
174+
daily_range = abs(price_change) + (volatility * opens[i])
175+
high = max(opens[i], close) + abs(np.random.normal(0, daily_range / 2))
176+
low = min(opens[i], close) - abs(np.random.normal(0, daily_range / 2))
177+
178+
highs.append(round(high, 2))
179+
lows.append(round(low, 2))
180+
181+
base_volume = 100000
182+
vol_factor = 1.0 + 2.0 * (abs(price_change) / (volatility * opens[i]))
183+
volume = int(base_volume * vol_factor * np.random.uniform(0.8, 1.2))
184+
volumes.append(volume)
185+
186+
return {
187+
"Date": dates,
188+
"Open": opens,
189+
"High": highs,
190+
"Low": lows,
191+
"Close": closes,
192+
"Volume": volumes,
193+
}
194+
195+
196+
if __name__ == "__main__":
197+
data = generate_sample_data(
198+
start_date="2021-01-01",
199+
end_date="2021-03-31",
200+
start_price=100.0,
201+
volatility=0.015,
202+
)
203+
204+
fig, ax = generate_candlestick_chart(
205+
data, output_file="candle_stick.svg", show_volume=False
206+
)
207+
print(f"Generated basic candlestick chart with {len(data['Date'])} trading days.")
208+
209+
fig, ax = generate_candlestick_chart(
210+
data,
211+
output_file="candle_stick_with_volume.svg",
212+
show_volume=True,
213+
title="Stock Price with Volume",
214+
)
215+
print("Generated candlestick chart with volume.")

0 commit comments

Comments
 (0)