-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path機器學習(SVM)(optuna)(標準化).py
More file actions
111 lines (98 loc) · 3.69 KB
/
機器學習(SVM)(optuna)(標準化).py
File metadata and controls
111 lines (98 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import yfinance as yf
import pandas as pd
import pandas_datareader as data
from datetime import datetime
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn import ensemble,metrics
from sklearn.metrics import classification_report,confusion_matrix
from xgboost.sklearn import XGBClassifier
from sklearn import tree
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt
from scikitplot.metrics import plot_confusion_matrix
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.model_selection import cross_val_score
import optuna
from sklearn import preprocessing
# yf.pdr_override() patched pandas_datareader to fetch prices through yfinance.
# Nothing below downloads data (everything is read from CSV), and recent
# yfinance releases removed this function (calling it raises AttributeError),
# so only call it when it is still available.
if hasattr(yf, 'pdr_override'):
    yf.pdr_override()

# Load the training CSV. NOTE: this rebinds `data`, shadowing the
# `pandas_datareader as data` import above.
data = pd.read_csv('機器學習_資料(股價)(訓練集).csv')

# Replace the '2308 adjclose' price column with a binary target:
# row t is labelled 1 when the close on day t+1 is higher than on day t, else 0.
close = data['2308 adjclose']
rose_vs_previous_day = (close > close.shift(1)).astype(int)
data['2308 adjclose'] = rose_vs_previous_day.shift(-1)
# The final row has no following day (shift(-1) leaves NaN there). Drop it by
# that condition instead of hard-coding its index label (previously 2266).
data = data.dropna(subset=['2308 adjclose']).copy()
data['2308 adjclose'] = data['2308 adjclose'].astype(int)

# Features: every column except the date and the target.
x = data.drop(['date', '2308 adjclose'], axis=1).copy()
y = data['2308 adjclose'].to_numpy()
print(y)

# NOTE(review): train_test_split shuffles rows, so test days are interleaved
# with training days. For a price series, a chronological split
# (shuffle=False) gives a less optimistic estimate. Kept as-is so the
# experiment setup (and random_state=5 row selection) is unchanged.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=5)

# Standardize using statistics from the training split only. Fitting the
# scaler on all rows before splitting leaked test-set mean/variance into
# training. `zscore` keeps the fitted training scaler for later reuse.
zscore = preprocessing.StandardScaler()
x_train = zscore.fit_transform(x_train)
x_test = zscore.transform(x_test)
print(x_train)
print(y_train)
from sklearn.svm import SVR
from sklearn.svm import SVC
def objective(trial):
    """Score one Optuna trial's SVC hyperparameters; lower is better.

    Candidate models are scored with 5-fold cross-validated F1 on the
    training split (``x_train`` / ``y_train``) only. The held-out
    ``x_test`` / ``y_test`` split is used later to report final metrics.
    Selecting hyperparameters on it as well would make those metrics
    optimistically biased.

    Args:
        trial: the Optuna trial proposing ``kernel``, ``C`` and ``gamma``.

    Returns:
        ``1.0 - mean fold F1``, so a minimizing study maximizes F1.
    """
    kernel = trial.suggest_categorical('kernel', ['rbf', 'sigmoid'])
    c = trial.suggest_float("C", 0.1, 10.0)
    gamma = trial.suggest_categorical('gamma', ['auto', 'scale'])
    # probability=True is not needed: scoring only calls predict(), which SVC
    # bases on decision_function. Enabling Platt scaling would add an internal
    # cross-validated fit to every candidate.
    model = SVC(
        kernel=kernel,
        gamma=gamma,
        C=c,
        random_state=4,
    )
    fold_f1 = cross_val_score(model, x_train, y_train, cv=5, scoring='f1')
    return 1.0 - fold_f1.mean()
# Run 100 trials of `objective`. create_study() minimizes by default, and the
# objective returns 1 - F1, so the best trial is the one with the highest F1.
study = optuna.create_study()
study.optimize(objective, n_trials=100)
print(study.best_params)
print(1.0 - study.best_value)  # convert back to the best F1 reached

from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_param_importances
import plotly.express as plotly

# Show both diagnostic figures as static (non-interactive) plotly renders.
plotly_config = {"staticPlot": True}
for make_figure in (plot_optimization_history, plot_param_importances):
    fig = make_figure(study)
    fig.show(config=plotly_config)
# Refit an SVC with the study's best hyperparameters on the training split.
# probability=True is kept because predict_proba is called on the test CSV later.
best = study.best_params
model = SVC(
    C=best['C'],
    gamma=best['gamma'],
    kernel=best['kernel'],
    probability=True,
    random_state=4,
)
y_hat = model.fit(x_train, y_train).predict(x_test)

# Hold-out metrics on the x_test / y_test split.
accuracy, f1_score, recall, precision = (
    scorer(y_test, y_hat)
    for scorer in (
        metrics.accuracy_score,
        metrics.f1_score,
        metrics.recall_score,
        metrics.precision_score,
    )
)
print(f'Accuracy:{accuracy:.5f}/ F1 Score:{f1_score:.5f}/Recall:{recall:.5f}/ Precision:{precision:.5f}')
print('****svm***')
plot_confusion_matrix(y_test, y_hat, cmap='Accent')
# Score the separate test CSV and save the predicted class probabilities.
test_data = pd.read_csv('機器學習_資料(測試集)(股價).csv')

# Reuse the scaler fitted on the training features (`zscore`). Fitting a new
# StandardScaler on the test file standardized it with its own mean/std, so the
# model received features on a different scale than the one it was trained on.
train_feature_names = getattr(zscore, 'feature_names_in_', None)
if train_feature_names is not None and set(train_feature_names).issubset(test_data.columns):
    # Pick the training columns by name, in training order; extra columns are ignored.
    test_features = test_data[list(train_feature_names)]
else:
    # NOTE(review): column names do not match the training features; fall back
    # to positional order, which is what the previous code relied on. Confirm
    # the test CSV's column layout.
    test_features = test_data.to_numpy()
test_data_zscore = zscore.transform(test_features)

# Probability columns follow model.classes_ (the sorted training labels).
predict_result = model.predict_proba(test_data_zscore)
predict_result_df = pd.DataFrame(predict_result)
predict_result_df.to_csv('機器學習輸出機率(svm)(標準化).csv')

# Display the pending matplotlib figures (e.g. the confusion matrix).
plt.show()