@@ -56,7 +56,7 @@ def download_data(
5656 ):
5757 """
5858 `pd.DataFrame`
59- 7 columns: A tick symbol, date , open, high, low, close and volume
59+ 7 columns: A tick symbol, time , open, high, low, close and volume
6060 for the specified stock ticker
6161 """
6262 assert self .time_interval in [
@@ -79,7 +79,7 @@ def download_data(
7979 time .sleep (0.25 )
8080
8181 self .dataframe .columns = [
82- "date " ,
82+ "time " ,
8383 "open" ,
8484 "close" ,
8585 "high" ,
@@ -93,23 +93,23 @@ def download_data(
9393 "tic" ,
9494 ]
9595
96- self .dataframe .sort_values (by = ["date " , "tic" ], inplace = True )
96+ self .dataframe .sort_values (by = ["time " , "tic" ], inplace = True )
9797 self .dataframe .reset_index (drop = True , inplace = True )
9898
9999 self .dataframe = self .dataframe [
100- ["tic" , "date " , "open" , "high" , "low" , "close" , "volume" ]
100+ ["tic" , "time " , "open" , "high" , "low" , "close" , "volume" ]
101101 ]
102102 # self.dataframe.loc[:, 'tic'] = pd.DataFrame((self.dataframe['tic'].tolist()))
103- self .dataframe ["date " ] = pd .to_datetime (
104- self .dataframe ["date " ], format = "%Y-%m-%d"
103+ self .dataframe ["time " ] = pd .to_datetime (
104+ self .dataframe ["time " ], format = "%Y-%m-%d"
105105 )
106- self .dataframe ["day" ] = self .dataframe ["date " ].dt .dayofweek
107- self .dataframe ["date " ] = self .dataframe .date .apply (
106+ self .dataframe ["day" ] = self .dataframe ["time " ].dt .dayofweek
107+ self .dataframe ["time " ] = self .dataframe .time .apply (
108108 lambda x : x .strftime ("%Y-%m-%d" )
109109 )
110110
111111 self .dataframe .dropna (inplace = True )
112- self .dataframe .sort_values (by = ["date " , "tic" ], inplace = True )
112+ self .dataframe .sort_values (by = ["time " , "tic" ], inplace = True )
113113 self .dataframe .reset_index (drop = True , inplace = True )
114114
115115 self .save_data (save_path )
@@ -118,155 +118,9 @@ def download_data(
118118 f"Download complete! Dataset saved to { save_path } . \n Shape of DataFrame: { self .dataframe .shape } "
119119 )
120120
def clean_data(self):
    """
    Fill gaps so that every ticker has one row for every date present in the data.

    Reads ``self.dataframe`` (a ``time`` column is treated as ``date``) and
    ``self.ticker_list``, then replaces ``self.dataframe`` with a new frame
    sorted by ``["date", "tic"]``. Only tickers named in ``self.ticker_list``
    are kept. The frame previously held in ``self.dataframe`` is not modified.
    :return: None; the cleaned frame is stored in ``self.dataframe``
    """
    source = self.dataframe
    if "time" in source.columns:
        # rename() returns a new frame, so the frame held by the caller stays untouched
        source = source.rename(columns={"time": "date"})

    tickers = source["tic"].unique()
    # Newest date first: the fill order below relies on this ordering.
    dates = pd.Series(source["date"].unique()).sort_values(ascending=False)

    # Every (tic, date) combination, ticker-major. Built directly instead of
    # pd.merge(how="cross") so that no fallback for older pandas is needed.
    grid = pd.MultiIndex.from_product(
        [tickers, dates], names=["tic", "date"]
    ).to_frame(index=False)

    merged = pd.merge(grid, source, how="left", on=["tic", "date"])

    # Per ticker, with rows ordered newest-first: bfill copies from the next
    # older date, then ffill covers the oldest rows from the nearest newer one.
    filled = [
        merged[merged.tic == ticker].bfill().ffill() for ticker in self.ticker_list
    ]
    if filled:
        result = pd.concat(filled, ignore_index=True)
    else:
        # no tickers requested: keep the column layout, drop every row
        result = merged.iloc[0:0]

    # Anything still missing (a ticker with no data at all) becomes 0.
    result = result.fillna(0)

    # reshape dataframe
    result = result.sort_values(by=["date", "tic"]).reset_index(drop=True)

    print("Shape of DataFrame: ", result.shape)

    self.dataframe = result
174-
def add_technical_indicator(
    self,
    tech_indicator_list: List[str],
    select_stockstats_talib: int = 0,
    drop_na_timestpe: int = 0,
):
    """
    Add technical-indicator columns to ``self.dataframe`` via stockstats or talib.

    :param tech_indicator_list: (List[str]) indicator names computed through
        stockstats; the talib path does not consult it and always adds
        macd, macd_signal, macd_hist, rsi, cci and dx
    :param select_stockstats_talib: (int) 0 selects stockstats, 1 selects talib
    :param drop_na_timestpe: (int) when truthy, drop every timestamp at which
        any row contains a NaN
    :return: None; ``self.dataframe`` is updated, sorted by time and tic, and
        its ``time`` column is renamed to ``date`` before returning
    """
    if "date" in self.dataframe.columns.values.tolist():
        self.dataframe.rename(columns={"date": "time"}, inplace=True)

    if self.data_source == "ccxt":
        self.dataframe.rename(columns={"index": "time"}, inplace=True)

    # Move the index into columns. level_0 / level_1 presumably come from an
    # unnamed two-level index -- TODO confirm against the data sources.
    self.dataframe.reset_index(drop=False, inplace=True)
    if "level_1" in self.dataframe.columns:
        self.dataframe.drop(columns=["level_1"], inplace=True)
    if "level_0" in self.dataframe.columns and "tic" not in self.dataframe.columns:
        self.dataframe.rename(columns={"level_0": "tic"}, inplace=True)
    assert select_stockstats_talib in {0, 1}
    print("tech_indicator_list: ", tech_indicator_list)
    if select_stockstats_talib == 0:  # use stockstats
        stock = stockstats.StockDataFrame.retype(self.dataframe)
        unique_ticker = stock.tic.unique()
        for indicator in tech_indicator_list:
            print("indicator: ", indicator)
            per_ticker = []
            for ticker in unique_ticker:
                try:
                    temp_indicator = pd.DataFrame(stock[stock.tic == ticker][indicator])
                    temp_indicator["tic"] = ticker
                    temp_indicator["time"] = self.dataframe[
                        self.dataframe.tic == ticker
                    ]["time"].to_list()
                    per_ticker.append(temp_indicator)
                except Exception as e:
                    # Best effort: a failure for one ticker is reported and
                    # must not stop the remaining tickers or indicators.
                    print(e)
            if per_ticker:
                indicator_df = pd.concat(
                    per_ticker,
                    axis=0,
                    join="outer",
                    ignore_index=True,
                )
            else:
                indicator_df = pd.DataFrame()
            if not indicator_df.empty:
                self.dataframe = self.dataframe.merge(
                    indicator_df[["tic", "time", indicator]],
                    on=["tic", "time"],
                    how="left",
                )
    else:  # use talib
        per_ticker = []
        for ticker in self.dataframe.tic.unique():
            # Work on a copy so the new columns are not written to a slice
            # of self.dataframe.
            tic_df = self.dataframe[self.dataframe.tic == ticker].copy()
            # Column assignment: ``.loc["macd"] = ...`` would address a ROW
            # labelled "macd" instead of creating the indicator column.
            (
                tic_df["macd"],
                tic_df["macd_signal"],
                tic_df["macd_hist"],
            ) = talib.MACD(
                tic_df["close"],
                fastperiod=12,
                slowperiod=26,
                signalperiod=9,
            )
            tic_df["rsi"] = talib.RSI(tic_df["close"], timeperiod=14)
            tic_df["cci"] = talib.CCI(
                tic_df["high"],
                tic_df["low"],
                tic_df["close"],
                timeperiod=14,
            )
            tic_df["dx"] = talib.DX(
                tic_df["high"],
                tic_df["low"],
                tic_df["close"],
                timeperiod=14,
            )
            per_ticker.append(tic_df)
        if per_ticker:
            self.dataframe = pd.concat(per_ticker, axis=0, join="outer")
        else:
            self.dataframe = pd.DataFrame()

    self.dataframe.sort_values(by=["time", "tic"], inplace=True)
    if drop_na_timestpe:
        # A NaN in any row removes that whole timestamp for every ticker.
        time_to_drop = self.dataframe[
            self.dataframe.isna().any(axis=1)
        ].time.unique()
        self.dataframe = self.dataframe[~self.dataframe.time.isin(time_to_drop)]
    self.dataframe.rename(columns={"time": "date"}, inplace=True)
    print("Succesfully add technical indicators")
266-
267- def data_split (self , df , start , end , target_date_col = "date" ):
121+ def data_split (self , df , start , end , target_date_col = "time" ):
268122 """
269- split the dataset into training or testing using date
123+ split the dataset into training or testing using time
270124 :param data: (df) pandas dataframe, start, end
271125 :return: (df) pandas dataframe
272126 """
@@ -285,11 +139,11 @@ def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
285139 # assert alpha in ["XSHG", "XSHE"], "Wrong alpha"
286140 return n
287141
def transfer_date(self, time: str) -> str:
    """
    Remove the separator from a date string, e.g. "2020-01-02" -> "20200102".

    Only one separator kind is removed: the first of "-", "." and "/" that
    occurs in the string. A string containing none of them is returned as is.
    :param time: (str) date string; the name shadows the ``time`` module
        inside this method only
    :return: (str) the input with that separator removed
    """
    for separator in ("-", ".", "/"):
        if separator in time:
            return time.replace(separator, "")
    return time
0 commit comments