forked from bgruening/galaxytools
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
161 lines (150 loc) · 5.5 KB
/
main.py
File metadata and controls
161 lines (150 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
Tabular data prediction using TabPFN
"""
import argparse
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from catboost import CatBoostClassifier, CatBoostRegressor
from sklearn.metrics import (
average_precision_score,
precision_recall_curve,
r2_score,
root_mean_squared_error,
)
from sklearn.preprocessing import label_binarize
from tabpfn import TabPFNClassifier, TabPFNRegressor
def separate_features_labels(data):
    """Split a tab-separated dataset into features and labels.

    The last column of the file is taken as the label vector; every
    preceding column is a feature.

    :param data: path or file-like object readable by ``pandas.read_csv``
    :return: tuple ``(features, labels)`` — a DataFrame of all columns but
        the last, and a Series holding the last column
    """
    frame = pd.read_csv(data, sep="\t")
    # Column split: everything up to (but excluding) the final column is X,
    # the final column is y.
    return frame.iloc[:, :-1], frame.iloc[:, -1]
def classification_plot(y_true, y_scores, m_name):
    """Save a precision-recall curve plot for a classifier's predictions.

    For binary problems a single PR curve is drawn from the positive-class
    scores; for multiclass problems one PR curve per class plus a
    micro-average curve are drawn. The figure is written to
    ``output_plot_<m_name>.png``.

    :param y_true: true class labels (1-D array-like)
    :param y_scores: class probability matrix of shape (n_samples, n_classes),
        e.g. the output of ``predict_proba`` — assumed column order matches
        ``np.unique(y_true)``
    :param m_name: model name used in the plot title and output filename
    """
    plt.figure(figsize=(8, 6))
    is_binary = len(np.unique(y_true)) == 2
    if is_binary:
        # Binary case: use the positive-class (second column) scores.
        precision, recall, _ = precision_recall_curve(y_true, y_scores[:, 1])
        average_precision = average_precision_score(y_true, y_scores[:, 1])
        plt.plot(
            recall,
            precision,
            label=f"Precision-Recall Curve (AP={average_precision:.2f})",
        )
        plt.title("{}: Precision-Recall Curve (binary classification)".format(m_name))
    else:
        # Multiclass case: one-vs-rest binarization, then a PR curve per class.
        y_true_bin = label_binarize(y_true, classes=np.unique(y_true))
        n_classes = y_true_bin.shape[1]
        class_labels = [f"Class {i}" for i in range(n_classes)]
        for i in range(n_classes):
            precision, recall, _ = precision_recall_curve(
                y_true_bin[:, i], y_scores[:, i]
            )
            ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i])
            plt.plot(
                recall, precision, label=f"{class_labels[i]} (AP = {ap_score:.2f})"
            )
        # Micro-average: pool all class decisions into one curve.
        precision, recall, _ = precision_recall_curve(
            y_true_bin.ravel(), y_scores.ravel()
        )
        plt.plot(
            recall, precision, linestyle="--", color="black", label="Micro-average"
        )
        plt.title(
            "{}: Precision-Recall Curve (Multiclass Classification)".format(m_name)
        )
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig("output_plot_{}.png".format(m_name))
    # Close the figure: this function is called once per model, and leaving
    # figures open accumulates memory and triggers matplotlib warnings.
    plt.close()
def regression_plot(xval, yval, title, xlabel, ylabel, m_name):
    """Save a true-vs-predicted scatter plot with a y = x reference line.

    The figure is written to ``output_plot_<m_name>.png``.

    :param xval: true target values
    :param yval: predicted target values
    :param title: plot title
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param m_name: model name used in the output filename
    """
    plt.figure(figsize=(8, 6))
    plt.scatter(xval, yval, alpha=0.8, label="Predictions")
    # Draw the y = x reference over the actual data range. (Previously the
    # line was plotted over np.arange(len(xval)) — the sample indices — which
    # is not y = x for the data unless values happen to equal their indices.)
    lo = min(np.min(xval), np.min(yval))
    hi = max(np.max(xval), np.max(yval))
    plt.plot([lo, hi], [lo, hi], color="red", linestyle="--", label="y = x")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    # Legend must come AFTER the labeled artists are plotted; the original
    # called legend() first, so no labels were registered and matplotlib
    # warned "No artists with labels found".
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig("output_plot_{}.png".format(m_name))
    plt.close()  # avoid accumulating open figures across model runs
def train_evaluate(args):
    """Train TabPFN and CatBoost models on tabular data and write predictions.

    Reads train/test TSV files, fits one TabPFN and one CatBoost model for
    the selected task, writes per-model prediction files
    (``output_predicted_data_<model>``) and, when test labels are available,
    per-model evaluation plots.

    :param args: dict with keys ``train_data``, ``test_data``,
        ``testhaslabels`` (``"haslabels"`` if the last test column is a
        label), and ``selected_task`` (``"Classification"`` or anything else
        for regression)
    """
    # Prepare train data: last column is the label.
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    # Prepare test data: only split off labels when the caller says they exist.
    if args["testhaslabels"] == "haslabels":
        te_features, te_labels = separate_features_labels(args["test_data"])
    else:
        te_features = pd.read_csv(args["test_data"], sep="\t")
        te_labels = []
    s_time = time.time()
    if args["selected_task"] == "Classification":
        models = [
            ("TabPFN", TabPFNClassifier(random_state=42)),
            ("CatBoost", CatBoostClassifier(random_state=42, verbose=0)),
        ]
        for m_name, model in models:
            model.fit(tr_features, tr_labels)
            y_eval = model.predict(te_features)
            pred_probas_test = model.predict_proba(te_features)
            # Only plot evaluation curves when ground-truth labels exist.
            if len(te_labels) > 0:
                classification_plot(te_labels, pred_probas_test, m_name)
            te_features["predicted_labels"] = y_eval
            te_features.to_csv(
                "output_predicted_data_{}".format(m_name), sep="\t", index=False
            )
    else:
        models = [
            ("TabPFN", TabPFNRegressor(random_state=42)),
            ("CatBoost", CatBoostRegressor(random_state=42, verbose=0)),
        ]
        for m_name, model in models:
            model.fit(tr_features, tr_labels)
            y_eval = model.predict(te_features)
            if len(te_labels) > 0:
                score = root_mean_squared_error(te_labels, y_eval)
                r2_metric_score = r2_score(te_labels, y_eval)
                regression_plot(
                    te_labels,
                    y_eval,
                    f"Scatter plot for {m_name}: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
                    "True values",
                    "Predicted values",
                    m_name,
                )
            te_features["predicted_labels"] = y_eval
            te_features.to_csv(
                "output_predicted_data_{}".format(m_name), sep="\t", index=False
            )
    e_time = time.time()
    # The timed span covers BOTH models; the original message incorrectly
    # attributed the whole duration to TabPFN alone.
    print(
        "Time taken by TabPFN and CatBoost for training and prediction: {} seconds".format(
            e_time - s_time
        )
    )
if __name__ == "__main__":
    # Command-line entry point: collect the four required options and hand
    # them to train_evaluate as a plain dict.
    parser = argparse.ArgumentParser()
    parser.add_argument("-trdata", "--train_data", required=True, help="Train data")
    parser.add_argument("-tedata", "--test_data", required=True, help="Test data")
    parser.add_argument(
        "-testhaslabels",
        "--testhaslabels",
        required=True,
        help="if test data contain labels",
    )
    parser.add_argument(
        "-selectedtask",
        "--selected_task",
        required=True,
        help="Type of machine learning task",
    )
    train_evaluate(vars(parser.parse_args()))