-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
132 lines (109 loc) · 4.18 KB
/
main.py
File metadata and controls
132 lines (109 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import multiprocessing
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, GCN2Conv, SAGEConv, GATConv, HGTConv, Linear
from torch_geometric.utils import to_undirected
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, ParameterGrid
from sklearn.metrics import roc_auc_score
import os
from utils import get_data
from train_model import CV_train
from sklearn.metrics import roc_curve, auc
import time
from xgboost import XGBClassifier
from sklearn.model_selection import GridSearchCV, ParameterGrid, train_test_split
import joblib
from models import ModelSelector,MGTCDA
def set_seed(seed):
    """Seed the random number generators used by this script so runs repeat.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU, and
    every CUDA device when CUDA is available). It also switches cuDNN to
    deterministic, non-benchmarking mode.

    Args:
        seed: integer seed applied to every generator.

    Note:
        Assigning ``PYTHONHASHSEED`` here does not change ``str`` hashing in
        the running interpreter, which is fixed at start-up. It only takes
        effect in child processes started afterwards.
    """
    import random  # local import: the module header does not import `random`

    random.seed(seed)  # previously unseeded
    torch.manual_seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernels (harmless on CPU).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class Config:
    """Default settings for one training run.

    All settings are plain instance attributes. Grid-search parameters are
    attached later with ``setattr``, and their names and values are recorded
    in ``other_args``.
    """

    def __init__(self):
        # Input data directory and output directory prefix.
        self.datapath, self.save_file = './data/', 'save_file/'
        # Cross-validation / encoding settings.
        self.kfold = 5
        self.maskMDI = False
        self.self_encode_len = 256
        # Attribute name kept as spelled; the original comment noted 120 as another value.
        self.globel_random = 222
        # Names and values of the grid-searched arguments (filled in per grid point).
        self.other_args = dict(arg_name=[], arg_value=[])
        # Decoding parameters (original comments noted 1000 and 20 as other values).
        self.epochs, self.print_epoch = 1000, 100
def set_attr(config, param_search):
    """Yield *config* once for every combination in the hyper-parameter grid.

    The combinations come from ``ParameterGrid(param_search)``. For each one,
    every key of *param_search* (in that dict's order) is:

    * written onto *config* with ``setattr``,
    * printed as ``name value``,
    * recorded in ``config.other_args``, which is reset for each combination.

    The same *config* object is mutated and re-yielded each time, so callers
    must use it before advancing the generator.
    """
    names = param_search.keys()
    for combo in list(ParameterGrid(param_search)):
        arg_names, arg_values = [], []
        config.other_args = {'arg_name': arg_names, 'arg_value': arg_values}
        for name in names:
            value = combo[name]
            setattr(config, name, value)
            arg_names.append(name)
            print(name, value)
            arg_values.append(value)
        yield config  # hand the updated config to the caller for this grid point
    # Kept for parity: becomes StopIteration.value when the generator finishes.
    return 0
class Data_paths:
    """Locations of the CSV input files under ``./data/``.

    ``cd`` is a single association-matrix file; ``cc`` and ``dd`` each list two
    similarity files. Judging by the file names these are circRNA (cc) and
    disease (dd) similarities — confirm against the loader.
    """

    def __init__(self):
        root = './data/'
        self.paths = root
        # File name is spelled exactly as on disk ("Matrixss").
        self.cd = root + 'Association Matrixss.csv'
        self.cc = [root + name for name in ('integrated_circ_sim.csv', 'Pearson_cc.csv')]
        self.dd = [root + name for name in ('MeSHSemanticSimilarity.csv', 'integrated_dise_sim.csv')]
def wait_for_file(path, poll_seconds=1.0):
    """Block until *path* exists on disk, checking every *poll_seconds* seconds.

    There is no timeout: if nothing ever creates *path*, this never returns.
    """
    while not os.path.exists(path):
        time.sleep(poll_seconds)


def metric_means(results, n_metrics=7):
    """Average the metric columns of a saved results array over its first axis.

    Drops the first two entries of the last axis and flattens the rest to
    ``(results.shape[0], n_metrics)``. Returns ``np.mean(..., axis=0)``.
    Assumes *results* is 4-D with exactly ``n_metrics`` values per row after
    slicing — TODO confirm against ModelSelector.train_with_grid_search output.
    """
    metrics = results[:, :, :, 2:].reshape(results.shape[0], n_metrics)
    return np.mean(metrics, axis=0)


if __name__ == '__main__':
    set_seed(521)
    best_param_search = {
        'hidden_channels': [128],
        'num_heads': [4],
        'num_layers': [5],
    }
    param_search = best_param_search
    save_file = '5cv_data_1000'
    params_all = Config()
    data_list = []
    filepath = Data_paths()
    # set_attr re-yields the same (mutated) Config once per grid point. A for
    # loop stops only on StopIteration; the previous bare `except:` also hid
    # real errors raised inside the generator, and KeyboardInterrupt.
    for params in set_attr(params_all, param_search):
        # Load the data once; it used to be loaded twice, with the first copy discarded.
        # NOTE(review): if get_data draws random numbers, results may differ from
        # runs made with the duplicate call.
        data_tuple = get_data(file_pair=filepath, params=params)
        data_idx, auc_name = CV_train(params, data_tuple)  # cross-validation
        # One intermediate feature file per fold; folds are numbered from 1.
        for kf in range(1, params.kfold + 1):
            # NOTE(review): the '6nl' tag is hard-coded and does not follow
            # params.num_layers (5 in this grid). Confirm it matches the file
            # names CV_train writes, otherwise the wait below never ends.
            file_name = './mid_data/' + str(6) + 'nl' + str(kf) + 'kf_best_cat_data.dict'
            wait_for_file(file_name)
            data_load = joblib.load(file_name)  # unpickles: only load files this pipeline wrote
            print(file_name)  # log the file actually loaded (previously printed a different path)
            selector = ModelSelector()
            X_train, X_test = data_load['train_data'], data_load['test_data']
            y_train, y_test = data_load['y_train'], data_load['y_test']
            # Get the models and train them.
            model_list = []  # model selection list, passed as-is to get_models
            models = selector.get_models(model_list)
            ls_dict = selector.train_with_grid_search(
                X_train, np.reshape(y_train, (-1,)),
                X_test, np.reshape(y_test, (-1,)),
                models,
            )
            data_list.append(ls_dict)

    # `data_list is not None` was always true; an empty run then crashed while reshaping.
    if data_list:
        os.makedirs(params_all.save_file, exist_ok=True)  # np.save does not create directories
        result_path = params_all.save_file + save_file + '.npy'
        np.save(result_path, np.array(data_list, dtype=object))
        # Other operations: reload the saved results and summarise them.
        data_idx = np.load(result_path, allow_pickle=True)
        mean = metric_means(data_idx)
        print()
        print('7个评价指标的平均值分别为:')
        print(mean)
    else:
        print('No results were produced; nothing to save or summarise.')