Skip to content

Commit c1037cd

Browse files
Updated fillna function in base data_processor. Some changes in akshare, tushare and baostock (#268)
* updated fillna function in base data_processor. Some changes in akshare, tushare and baostock * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * beautified by flake8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * beautified by flake8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f4f89d1 commit c1037cd

File tree

4 files changed

+112
-348
lines changed

4 files changed

+112
-348
lines changed

meta/data_processors/_base.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,51 @@ def clean_data(self):
9494
]
9595
]
9696

97+
def fillna(self):
    """
    Make ``self.dataframe`` dense: one row per (tic, time) pair, gaps filled.

    Steps:
      1. Build the full grid of every unique ``tic`` crossed with every
         unique ``time`` found in ``self.dataframe``.
      2. Left-join the existing rows onto that grid, so missing
         (tic, time) pairs appear as NaN rows.
      3. Per ticker, back-fill then forward-fill the NaN values. Rows are
         ordered by *descending* time inside each ticker, so the back-fill
         copies values from the nearest earlier timestamp, and the
         forward-fill covers gaps at the start of a ticker's history with
         the nearest later timestamp.
      4. Replace whatever is still NaN with 0, sort by ``["time", "tic"]``
         and reset the index.

    The result is stored back into ``self.dataframe``; nothing is returned.
    """
    df = self.dataframe

    # Nothing to expand or fill; keep the frame untouched.
    if df.empty:
        print("Shape of DataFrame: ", df.shape)
        return

    tickers = df["tic"].unique()
    # Descending time order is what makes bfill-before-ffill behave as
    # described in the docstring.
    timestamps = df["time"].drop_duplicates().sort_values(ascending=False)

    # Cartesian product of tickers and timestamps. MultiIndex.from_product
    # does not depend on pd.merge(how="cross"), so no pandas-version
    # fallback (and no row-by-row DataFrame.append) is needed.
    grid = pd.MultiIndex.from_product(
        [tickers, timestamps], names=["tic", "time"]
    ).to_frame(index=False)

    merged = pd.merge(grid, df, how="left", on=["tic", "time"])

    # Fill each ticker independently so values never leak across tickers;
    # collect the pieces and concatenate once instead of inside the loop.
    pieces = [
        group.bfill().ffill() for _, group in merged.groupby("tic", sort=False)
    ]
    df_new = pd.concat(pieces, ignore_index=True) if pieces else merged.iloc[0:0]

    # Columns that are empty for a whole ticker are still NaN here.
    df_new = df_new.fillna(0)

    # reshape dataframe
    df_new = df_new.sort_values(by=["time", "tic"]).reset_index(drop=True)

    print("Shape of DataFrame: ", df_new.shape)

    self.dataframe = df_new
141+
97142
def get_trading_days(self, start: str, end: str) -> List[str]:
98143
if self.data_source in [
99144
"binance",
@@ -108,8 +153,12 @@ def get_trading_days(self, start: str, end: str) -> List[str]:
108153
return None
109154

110155
# select_stockstats_talib: 0 (stockstats, default), or 1 (use talib). Users can choose the method.
156+
# drop_na_timesteps: 0 (keep timesteps that contain NaN), or 1 (drop timesteps that contain NaN, default). Users can choose the method.
111157
def add_technical_indicator(
112-
self, tech_indicator_list: List[str], select_stockstats_talib: int = 0
158+
self,
159+
tech_indicator_list: List[str],
160+
select_stockstats_talib: int = 0,
161+
drop_na_timesteps: int = 1,
113162
):
114163
"""
115164
calculate technical indicators
@@ -189,8 +238,11 @@ def add_technical_indicator(
189238
self.dataframe = final_df
190239

191240
self.dataframe.sort_values(by=["time", "tic"], inplace=True)
192-
time_to_drop = self.dataframe[self.dataframe.isna().any(axis=1)].time.unique()
193-
self.dataframe = self.dataframe[~self.dataframe.time.isin(time_to_drop)]
241+
if drop_na_timesteps:
242+
time_to_drop = self.dataframe[
243+
self.dataframe.isna().any(axis=1)
244+
].time.unique()
245+
self.dataframe = self.dataframe[~self.dataframe.time.isin(time_to_drop)]
194246
print("Succesfully add technical indicators")
195247

196248
def add_turbulence(self):

meta/data_processors/akshare.py

Lines changed: 19 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def download_data(
5656
):
5757
"""
5858
`pd.DataFrame`
59-
7 columns: A tick symbol, date, open, high, low, close and volume
59+
7 columns: A tick symbol, time, open, high, low, close and volume
6060
for the specified stock ticker
6161
"""
6262
assert self.time_interval in [
@@ -79,7 +79,7 @@ def download_data(
7979
time.sleep(0.25)
8080

8181
self.dataframe.columns = [
82-
"date",
82+
"time",
8383
"open",
8484
"close",
8585
"high",
@@ -93,23 +93,23 @@ def download_data(
9393
"tic",
9494
]
9595

96-
self.dataframe.sort_values(by=["date", "tic"], inplace=True)
96+
self.dataframe.sort_values(by=["time", "tic"], inplace=True)
9797
self.dataframe.reset_index(drop=True, inplace=True)
9898

9999
self.dataframe = self.dataframe[
100-
["tic", "date", "open", "high", "low", "close", "volume"]
100+
["tic", "time", "open", "high", "low", "close", "volume"]
101101
]
102102
# self.dataframe.loc[:, 'tic'] = pd.DataFrame((self.dataframe['tic'].tolist()))
103-
self.dataframe["date"] = pd.to_datetime(
104-
self.dataframe["date"], format="%Y-%m-%d"
103+
self.dataframe["time"] = pd.to_datetime(
104+
self.dataframe["time"], format="%Y-%m-%d"
105105
)
106-
self.dataframe["day"] = self.dataframe["date"].dt.dayofweek
107-
self.dataframe["date"] = self.dataframe.date.apply(
106+
self.dataframe["day"] = self.dataframe["time"].dt.dayofweek
107+
self.dataframe["time"] = self.dataframe.time.apply(
108108
lambda x: x.strftime("%Y-%m-%d")
109109
)
110110

111111
self.dataframe.dropna(inplace=True)
112-
self.dataframe.sort_values(by=["date", "tic"], inplace=True)
112+
self.dataframe.sort_values(by=["time", "tic"], inplace=True)
113113
self.dataframe.reset_index(drop=True, inplace=True)
114114

115115
self.save_data(save_path)
@@ -118,155 +118,9 @@ def download_data(
118118
f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
119119
)
120120

121-
def clean_data(self):
122-
dfc = copy.deepcopy(self.dataframe)
123-
124-
dfcode = pd.DataFrame(columns=["tic"])
125-
dfdate = pd.DataFrame(columns=["date"])
126-
127-
dfcode.tic = dfc.tic.unique()
128-
129-
if "time" in dfc.columns.values.tolist():
130-
dfc = dfc.rename(columns={"time": "date"})
131-
132-
dfdate.date = dfc.date.unique()
133-
dfdate.sort_values(by="date", ascending=False, ignore_index=True, inplace=True)
134-
135-
# the old pandas may not support pd.merge(how="cross")
136-
try:
137-
df1 = pd.merge(dfcode, dfdate, how="cross")
138-
except:
139-
print("Please wait for a few seconds...")
140-
df1 = pd.DataFrame(columns=["tic", "date"])
141-
for i in range(dfcode.shape[0]):
142-
for j in range(dfdate.shape[0]):
143-
df1 = df1.append(
144-
pd.DataFrame(
145-
data={
146-
"tic": dfcode.iat[i, 0],
147-
"date": dfdate.iat[j, 0],
148-
},
149-
index=[(i + 1) * (j + 1) - 1],
150-
)
151-
)
152-
153-
df2 = pd.merge(df1, dfc, how="left", on=["tic", "date"])
154-
155-
# back fill missing data then front fill
156-
df3 = pd.DataFrame(columns=df2.columns)
157-
for i in self.ticker_list:
158-
df4 = df2[df2.tic == i].fillna(method="bfill").fillna(method="ffill")
159-
df3 = pd.concat([df3, df4], ignore_index=True)
160-
161-
df3 = df3.fillna(0)
162-
163-
# reshape dataframe
164-
df3 = df3.sort_values(by=["date", "tic"]).reset_index(drop=True)
165-
166-
if "date" in self.dataframe.columns.values.tolist():
167-
self.dataframe.rename(columns={"date": "time"}, inplace=True)
168-
if "datetime" in self.dataframe.columns.values.tolist():
169-
self.dataframe.rename(columns={"datetime": "time"}, inplace=True)
170-
171-
print("Shape of DataFrame: ", df3.shape)
172-
173-
self.dataframe = df3
174-
175-
def add_technical_indicator(
176-
self,
177-
tech_indicator_list: List[str],
178-
select_stockstats_talib: int = 0,
179-
drop_na_timestpe: int = 0,
180-
):
181-
"""
182-
calculate technical indicators
183-
use stockstats/talib package to add technical inidactors
184-
:param data: (df) pandas dataframe
185-
:return: (df) pandas dataframe
186-
"""
187-
if "date" in self.dataframe.columns.values.tolist():
188-
self.dataframe.rename(columns={"date": "time"}, inplace=True)
189-
190-
if self.data_source == "ccxt":
191-
self.dataframe.rename(columns={"index": "time"}, inplace=True)
192-
193-
self.dataframe.reset_index(drop=False, inplace=True)
194-
if "level_1" in self.dataframe.columns:
195-
self.dataframe.drop(columns=["level_1"], inplace=True)
196-
if "level_0" in self.dataframe.columns and "tic" not in self.dataframe.columns:
197-
self.dataframe.rename(columns={"level_0": "tic"}, inplace=True)
198-
assert select_stockstats_talib in {0, 1}
199-
print("tech_indicator_list: ", tech_indicator_list)
200-
if select_stockstats_talib == 0: # use stockstats
201-
stock = stockstats.StockDataFrame.retype(self.dataframe)
202-
unique_ticker = stock.tic.unique()
203-
for indicator in tech_indicator_list:
204-
print("indicator: ", indicator)
205-
indicator_df = pd.DataFrame()
206-
for i in range(len(unique_ticker)):
207-
try:
208-
temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
209-
temp_indicator = pd.DataFrame(temp_indicator)
210-
temp_indicator["tic"] = unique_ticker[i]
211-
temp_indicator["time"] = self.dataframe[
212-
self.dataframe.tic == unique_ticker[i]
213-
]["time"].to_list()
214-
indicator_df = pd.concat(
215-
[indicator_df, temp_indicator],
216-
axis=0,
217-
join="outer",
218-
ignore_index=True,
219-
)
220-
except Exception as e:
221-
print(e)
222-
if not indicator_df.empty:
223-
self.dataframe = self.dataframe.merge(
224-
indicator_df[["tic", "time", indicator]],
225-
on=["tic", "time"],
226-
how="left",
227-
)
228-
else: # use talib
229-
final_df = pd.DataFrame()
230-
for i in self.dataframe.tic.unique():
231-
tic_df = self.dataframe[self.dataframe.tic == i]
232-
(
233-
tic_df.loc["macd"],
234-
tic_df.loc["macd_signal"],
235-
tic_df.loc["macd_hist"],
236-
) = talib.MACD(
237-
tic_df["close"],
238-
fastperiod=12,
239-
slowperiod=26,
240-
signalperiod=9,
241-
)
242-
tic_df.loc["rsi"] = talib.RSI(tic_df["close"], timeperiod=14)
243-
tic_df.loc["cci"] = talib.CCI(
244-
tic_df["high"],
245-
tic_df["low"],
246-
tic_df["close"],
247-
timeperiod=14,
248-
)
249-
tic_df.loc["dx"] = talib.DX(
250-
tic_df["high"],
251-
tic_df["low"],
252-
tic_df["close"],
253-
timeperiod=14,
254-
)
255-
final_df = pd.concat([final_df, tic_df], axis=0, join="outer")
256-
self.dataframe = final_df
257-
258-
self.dataframe.sort_values(by=["time", "tic"], inplace=True)
259-
if drop_na_timestpe:
260-
time_to_drop = self.dataframe[
261-
self.dataframe.isna().any(axis=1)
262-
].time.unique()
263-
self.dataframe = self.dataframe[~self.dataframe.time.isin(time_to_drop)]
264-
self.dataframe.rename(columns={"time": "date"}, inplace=True)
265-
print("Succesfully add technical indicators")
266-
267-
def data_split(self, df, start, end, target_date_col="date"):
121+
def data_split(self, df, start, end, target_date_col="time"):
268122
"""
269-
split the dataset into training or testing using date
123+
split the dataset into training or testing using time
270124
:param data: (df) pandas dataframe, start, end
271125
:return: (df) pandas dataframe
272126
"""
@@ -285,11 +139,11 @@ def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
285139
# assert alpha in ["XSHG", "XSHE"], "Wrong alpha"
286140
return n
287141

288-
def transfer_date(self, date: str) -> str:
289-
if "-" in date:
290-
date = "".join(date.split("-"))
291-
elif "." in date:
292-
date = "".join(date.split("."))
293-
elif "/" in date:
294-
date = "".join(date.split("/"))
295-
return date
142+
def transfer_date(self, time: str) -> str:
    """
    Strip the date separator from a date string.

    The separators "-", "." and "/" are checked in that order; only the
    first kind found in ``time`` is removed (every occurrence of it).
    A string containing none of them is returned unchanged.

    :param time: (str) date string such as "2020-01-02"
    :return: (str) the string without that separator, e.g. "20200102"
    """
    for separator in ("-", ".", "/"):
        if separator in time:
            return time.replace(separator, "")
    return time

meta/data_processors/baostock.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def download_data(
8484
)
8585
bs.logout()
8686

87+
self.dataframe.open = self.dataframe.open.astype(float)
88+
self.dataframe.high = self.dataframe.high.astype(float)
89+
self.dataframe.low = self.dataframe.low.astype(float)
90+
self.dataframe.close = self.dataframe.close.astype(float)
8791
self.save_data(save_path)
8892

8993
print(

0 commit comments

Comments
 (0)