-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmain.py
More file actions
272 lines (246 loc) · 14 KB
/
main.py
File metadata and controls
272 lines (246 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import argparse
import torch
import random
import numpy as np
import pandas as pd
import os
import time
import uuid
from pathlib import Path
import yfinance as yf
from exp.exp_DeepAries import Exp_DeepAries
import torch.multiprocessing as mp
from utils.tools import initialize_logger, fix_seed
from data_provider import data_prep  # YfinancePreprocessor and the related helper functions are defined in the data_prep module
##############################################
# MAIN FUNCTION: DATA PREPARATION & TRAINING
##############################################
def _str2bool(value):
    """Parse a command-line string into a real boolean.

    argparse's ``type=bool`` is a well-known trap: every non-empty string is
    truthy, so ``--use_gpu False`` parsed as ``True``.  This parser accepts
    the common true/false spellings and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def _build_parser():
    """Build the CLI argument parser (same options/defaults as before)."""
    parser = argparse.ArgumentParser(
        description='Transformer Family and DeepAries for Time Series Forecasting'
    )
    # [General Settings]
    parser.add_argument('--seed', type=int, help='Random seed for reproducibility')
    parser.add_argument('--model', type=str, default='iTransformer',
                        help='Model type. Options: [Transformer, iTransformer]')
    parser.add_argument('--is_training', type=int, default=1,
                        help='Training flag (1: train, 0: inference)')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='Temperature parameter for softmax (used in model components)')
    # [Data Settings]
    parser.add_argument('--market', type=str, default='dj30',
                        help='Market dataset to use. Options: [dj30, kospi, csi300, ftse, etc.]')
    parser.add_argument('--data', type=str, default='general',
                        help='Data type. Options: [general, alpha158]')
    parser.add_argument('--root_path', type=str,
                        help='Root path for the dataset. Defaults to "./data/<market>/" if not provided')
    parser.add_argument('--data_path', type=str,
                        help='CSV filename of the dataset. Defaults to "<market>_data.csv" if not provided')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/',
                        help='Directory to save model checkpoints')
    parser.add_argument('--checkpoint_dir', type=str, default=None,
                        help='Checkpoint directory name for loading pre-trained model. If not specified, will auto-search for matching checkpoint.')
    # [Forecasting Task Settings]
    parser.add_argument('--valid_year', type=str, default='2020-12-31',
                        help='Validation period end date')
    parser.add_argument('--test_year', type=str, default='2021-12-31',
                        help='Test period start date')
    parser.add_argument('--seq_len', type=int, default=20,
                        help='Input sequence length')
    parser.add_argument('--label_len', type=int, default=5,
                        help='Label length for decoder start token')
    parser.add_argument('--pred_len', type=int, default=20,
                        help='Prediction sequence length')
    parser.add_argument('--freq', type=str, default='d',
                        help='Frequency for time feature encoding (e.g., s, t, h, d, b, w, m)')
    parser.add_argument('--horizons', nargs='+', type=int, default=[1, 5, 20],
                        help='List of forecasting horizons for multi-horizon trading (e.g., 1 5 20)')
    # [Model Architecture Settings]
    parser.add_argument('--enc_in', type=int, help='Encoder input size (auto-detected from data)', required=False)
    parser.add_argument('--dec_in', type=int, help='Decoder input size (auto-detected from data)', required=False)
    parser.add_argument('--c_out', type=int, default=1, help='Output size')
    parser.add_argument('--d_model', type=int, default=512, help='Model (hidden) dimension')
    parser.add_argument('--n_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--e_layers', type=int, default=2, help='Number of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='Number of decoder layers')
    parser.add_argument('--d_ff', type=int, default=2048, help='Dimension of the feed-forward network')
    parser.add_argument('--dropout', type=float, default=0.05, help='Dropout rate')
    parser.add_argument('--output_attention', action='store_true',
                        help='Output attention weights from the encoder')
    parser.add_argument('--factor', type=int, default=1, help='Attention factor (specific to certain models)')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='Time features encoding type. Options: [timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='Activation function')
    # NOTE: was type=bool, which made "--use_adappo False" evaluate to True.
    parser.add_argument('--use_adappo', type=_str2bool, default=True,
                        help='Use Adappo (reinforcement learning) if True, else use supervised learning')
    # [Optimization Settings]
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of workers for data loading')
    parser.add_argument('--itr', type=int, default=1,
                        help='Number of experiment iterations')
    parser.add_argument('--train_epochs', type=int, default=1,
                        help='Number of training epochs')
    parser.add_argument('--patience', type=int, default=3,
                        help='Early stopping patience')
    parser.add_argument('--learning_rate', type=float, default=1e-5,
                        help='Learning rate')
    parser.add_argument('--use_amp', action='store_true', default=False,
                        help='Enable automatic mixed precision training')
    parser.add_argument('--lradj', type=str, default='type1',
                        help='Learning rate adjustment type')
    # [GPU Settings]
    # NOTE: was type=bool, which made "--use_gpu False" evaluate to True.
    parser.add_argument('--use_gpu', type=_str2bool, default=True,
                        help='Use GPU if available')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU id to use')
    parser.add_argument('--use_multi_gpu', action='store_true', default=False,
                        help='Use multiple GPUs')
    parser.add_argument('--devices', type=str, default='0,1',
                        help='Comma separated list of GPU device ids')
    # [Portfolio Management Settings]
    parser.add_argument('--fee_rate', type=float, default=0.0001,
                        help='Transaction fee rate')
    parser.add_argument('--complex_fee', action='store_true', default=False,
                        help='Enable complex fee calculation')
    parser.add_argument('--num_stocks', type=int, default=20,
                        help='Number of stocks to include in the portfolio')
    parser.add_argument('--total_stocks', type=int, required=False,
                        help='Total number of stocks in the dataset')
    return parser


def _prepare_data(args):
    """Fetch raw ticker data and generate features when the files are missing.

    Returns:
        The path to the processed data CSV, or ``None`` when the required
        ticker-list file is unavailable (caller should abort).
    """
    raw_data_csv = os.path.join(args.root_path, args.data_path)
    if not os.path.exists(raw_data_csv):
        tickers_csv = f'complete_{args.market.lower()}_tickers.csv'
        if not os.path.exists(tickers_csv):
            print(f"Ticker list file {tickers_csv} not found. Please provide ticker list.")
            return None
        tickers = [str(ticker) for ticker in pd.read_csv(tickers_csv)['ticker'].tolist()]
        print("Tickers to fetch:", data_prep.format_tickers(tickers))
        data_prep.fetch_and_save_ticker_data(
            ticker_list=tickers,
            output_csv=raw_data_csv,
            start_date='2000-01-01',
            end_date='2025-02-09',
            threshold=0.1
        )
    else:
        print(f"Raw data file {raw_data_csv} already exists.")
    # Processed (feature-engineered) dataset is derived from the raw CSV.
    processed_data_csv = os.path.join(args.root_path, f"{args.market}_{args.data}_data.csv")
    if not os.path.exists(processed_data_csv):
        preprocessor = data_prep.YfinancePreprocessor(input_path=raw_data_csv, output_path=processed_data_csv)
        preprocessor.make_feature()
    else:
        print(f"Processed data file {processed_data_csv} already exists.")
    return processed_data_csv


def _detect_dimensions(args, processed_data_csv):
    """Infer model input sizes and stock universe from the processed CSV.

    Sets ``args.enc_in`` / ``args.dec_in`` (unless given on the CLI),
    ``args.total_stocks``, and clamps ``args.num_stocks``.

    Returns:
        ``True`` on success, ``False`` if the CSV could not be read.
    """
    try:
        data = pd.read_csv(processed_data_csv)
        # Assumes the CSV layout is: date, tic, then feature columns.
        num_features = data.shape[1] - 2
        args.enc_in = num_features if args.enc_in is None else args.enc_in
        args.dec_in = num_features if args.dec_in is None else args.dec_in
        args.total_stocks = data['tic'].nunique()
        if (not args.num_stocks) or (args.num_stocks > args.total_stocks):
            args.num_stocks = args.total_stocks
        print(f"Detected {num_features} input features across {args.total_stocks} stocks. Using {args.num_stocks} stocks for training.")
        print(f"Setting enc_in={args.enc_in}, dec_in={args.dec_in}.")
        return True
    except Exception as e:
        print(f"Error loading data from {processed_data_csv}: {e}")
        return False


def _resolve_checkpoint(args, setting, unique_setting, logger):
    """Locate the checkpoint directory name to use for inference.

    Search order: explicit ``--checkpoint_dir`` > exact "<setting>_pretrained"
    > any "<setting>_*" > default dj30 pretrained pattern > any dj30 pattern.
    When several directories match a glob, the most recently modified one is
    chosen (the previous code claimed "most recent" but took glob's
    arbitrary first entry).

    Returns:
        A checkpoint directory basename; falls back to *unique_setting*
        (which will fail downstream with a clear error) when nothing matches.
    """
    if args.checkpoint_dir:
        logger.info(f"Using specified checkpoint directory: {args.checkpoint_dir}")
        return args.checkpoint_dir
    import glob
    pretrained_path = os.path.join(args.checkpoints, f"{setting}_pretrained")
    if os.path.exists(pretrained_path):
        checkpoint_setting = os.path.basename(pretrained_path)
        logger.info(f"Auto-found pretrained checkpoint: {checkpoint_setting}")
        return checkpoint_setting
    search_order = (
        (os.path.join(args.checkpoints, f"{setting}_*"),
         "Auto-found checkpoint directory"),
        (os.path.join(args.checkpoints, "iTransformer_DeepAries_dj30_*_pretrained"),
         "Using default pretrained checkpoint"),
        (os.path.join(args.checkpoints, "iTransformer_DeepAries_dj30_*"),
         "Using fallback checkpoint"),
    )
    for pattern, message in search_order:
        matches = glob.glob(pattern)
        if matches:
            checkpoint_setting = os.path.basename(max(matches, key=os.path.getmtime))
            logger.info(f"{message}: {checkpoint_setting}")
            return checkpoint_setting
    logger.warning(f"No matching checkpoint found for setting: {setting}")
    logger.warning("Please specify --checkpoint_dir or ensure checkpoint exists")
    return unique_setting  # Fallback to new setting (will fail with clear error)


def main():
    """Entry point: parse CLI args, prepare data, then train or backtest DeepAries."""
    args = _build_parser().parse_args()
    # --- Automatically set data paths ---
    if not args.root_path:
        args.root_path = f'./data/{args.market}/'
    if not args.data_path:
        args.data_path = f'{args.market}_data.csv'
    # Create data directory if it doesn't exist
    Path(args.root_path).mkdir(parents=True, exist_ok=True)
    # The longest horizon doubles as the decoder prediction length.
    args.pred_len = args.horizons[-1]
    # --- Data preparation (fetch + feature engineering as needed) ---
    processed_data_csv = _prepare_data(args)
    if processed_data_csv is None:
        return
    if not _detect_dimensions(args, processed_data_csv):
        return
    # --- Create a unique experiment identifier and result directory ---
    setting = "_".join([
        f"{args.model}",
        "DeepAries",
        args.market,
        args.data,
        f"num_stocks({args.num_stocks})",
        f"sl({args.seq_len})",
        f"pl({args.pred_len})",
    ])
    unique_setting = f"{setting}_{uuid.uuid4().hex[:8]}"
    result_dir = os.path.join("./results", unique_setting)
    os.makedirs(result_dir, exist_ok=True)
    # --- Initialize logger (module-global so other code in this file can use it) ---
    global logger
    logger = initialize_logger(result_dir)
    logger.info(f"Dataset root path: {args.root_path}")
    logger.info(f"Dataset file: {args.data_path}")
    # --- GPU configuration ---
    args.use_gpu = bool(torch.cuda.is_available() and args.use_gpu)
    if args.use_gpu and args.use_multi_gpu:
        args.devices = args.devices.replace(' ', '')
        args.device_ids = [int(id_) for id_ in args.devices.split(',')]
        args.gpu = args.device_ids[0]
    # --- Set random seed for reproducibility ---
    if args.seed is None:
        args.seed = int(time.time()) % (2 ** 32)
        print(f"No seed provided. Generated random seed: {args.seed}")
    fix_seed(args.seed)
    # --- Run training or inference using DeepAries ---
    if args.is_training:
        exp = Exp_DeepAries(args, unique_setting)
        exp.train(unique_setting)
        logger.info(f"DeepAries Backtesting: {setting}")
        # After training the model is already in memory; no checkpoint load needed.
        exp.backtest(unique_setting, load=False)
    else:
        checkpoint_setting = _resolve_checkpoint(args, setting, unique_setting, logger)
        exp = Exp_DeepAries(args, unique_setting)
        logger.info(f"DeepAries Backtesting: {setting}")
        exp.backtest(checkpoint_setting, 1)  # Pass checkpoint_setting instead of unique_setting
    torch.cuda.empty_cache()
# Standard script entry point: run main() only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()