Skip to content

Commit 03e7b4c

Browse files
committed
fix review comments
1 parent 84fa20d commit 03e7b4c

File tree

2 files changed

+50
-63
lines changed

2 files changed

+50
-63
lines changed

tools/tabpfn/main.py

Lines changed: 33 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import matplotlib.pyplot as plt
88
import numpy as np
99
import pandas as pd
10-
from catboost import CatBoostClassifier, CatBoostRegressor
10+
1111
from sklearn.metrics import (
1212
average_precision_score,
1313
precision_recall_curve,
@@ -25,7 +25,7 @@ def separate_features_labels(data):
2525
return features, labels
2626

2727

28-
def classification_plot(y_true, y_scores, m_name):
28+
def classification_plot(y_true, y_scores):
2929
plt.figure(figsize=(8, 6))
3030
is_binary = len(np.unique(y_true)) == 2
3131
if is_binary:
@@ -37,7 +37,7 @@ def classification_plot(y_true, y_scores, m_name):
3737
precision,
3838
label=f"Precision-Recall Curve (AP={average_precision:.2f})",
3939
)
40-
plt.title(f"{m_name}: Precision-Recall Curve (binary classification)")
40+
plt.title("Precision-Recall Curve (binary classification)")
4141
else:
4242
y_true_bin = label_binarize(y_true, classes=np.unique(y_true))
4343
n_classes = y_true_bin.shape[1]
@@ -59,16 +59,16 @@ def classification_plot(y_true, y_scores, m_name):
5959
recall, precision, linestyle="--", color="black", label="Micro-average"
6060
)
6161
plt.title(
62-
f"{m_name}: Precision-Recall Curve (Multiclass Classification)"
62+
"Precision-Recall Curve (Multiclass Classification)"
6363
)
6464
plt.xlabel("Recall")
6565
plt.ylabel("Precision")
6666
plt.legend(loc="lower left")
6767
plt.grid(True)
68-
plt.savefig(f"output_plot_{m_name}.png")
68+
plt.savefig("output_plot.png")
6969

7070

71-
def regression_plot(xval, yval, title, xlabel, ylabel, m_name):
71+
def regression_plot(xval, yval, title, xlabel, ylabel):
7272
plt.figure(figsize=(8, 6))
7373
plt.xlabel(xlabel)
7474
plt.ylabel(ylabel)
@@ -78,7 +78,7 @@ def regression_plot(xval, yval, title, xlabel, ylabel, m_name):
7878
plt.scatter(xval, yval, alpha=0.8)
7979
xticks = np.arange(len(xval))
8080
plt.plot(xticks, xticks, color="red", linestyle="--", label="y = x")
81-
plt.savefig(f"output_plot_{m_name}.png")
81+
plt.savefig("output_plot.png")
8282

8383

8484
def train_evaluate(args):
@@ -95,43 +95,34 @@ def train_evaluate(args):
9595
te_labels = []
9696
s_time = time.time()
9797
if args["selected_task"] == "Classification":
98-
models = [
99-
("TabPFN", TabPFNClassifier(random_state=42)),
100-
("CatBoost", CatBoostClassifier(random_state=42, verbose=0)),
101-
]
102-
for m_name, model in models:
103-
model.fit(tr_features, tr_labels)
104-
y_eval = model.predict(te_features)
105-
pred_probas_test = model.predict_proba(te_features)
106-
if len(te_labels) > 0:
107-
classification_plot(te_labels, pred_probas_test, m_name)
108-
te_features["predicted_labels"] = y_eval
109-
te_features.to_csv(
110-
f"output_predicted_data_{m_name}", sep="\t", index=None
111-
)
98+
classifier = TabPFNClassifier(random_state=42)
99+
classifier.fit(tr_features, tr_labels)
100+
y_eval = classifier.predict(te_features)
101+
pred_probas_test = classifier.predict_proba(te_features)
102+
if len(te_labels) > 0:
103+
classification_plot(te_labels, pred_probas_test)
104+
te_features["predicted_labels"] = y_eval
105+
te_features.to_csv(
106+
"output_predicted_data", sep="\t", index=None
107+
)
112108
else:
113-
models = [
114-
("TabPFN", TabPFNRegressor(random_state=42)),
115-
("CatBoost", CatBoostRegressor(random_state=42, verbose=0)),
116-
]
117-
for m_name, model in models:
118-
model.fit(tr_features, tr_labels)
119-
y_eval = model.predict(te_features)
120-
if len(te_labels) > 0:
121-
score = root_mean_squared_error(te_labels, y_eval)
122-
r2_metric_score = r2_score(te_labels, y_eval)
123-
regression_plot(
124-
te_labels,
125-
y_eval,
126-
f"Scatter plot for {m_name}: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
127-
"True values",
128-
"Predicted values",
129-
m_name,
130-
)
131-
te_features["predicted_labels"] = y_eval
132-
te_features.to_csv(
133-
f"output_predicted_data_{m_name}", sep="\t", index=None
109+
regressor = TabPFNRegressor(random_state=42)
110+
regressor.fit(tr_features, tr_labels)
111+
y_eval = regressor.predict(te_features)
112+
if len(te_labels) > 0:
113+
score = root_mean_squared_error(te_labels, y_eval)
114+
r2_metric_score = r2_score(te_labels, y_eval)
115+
regression_plot(
116+
te_labels,
117+
y_eval,
118+
f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
119+
"True values",
120+
"Predicted values",
134121
)
122+
te_features["predicted_labels"] = y_eval
123+
te_features.to_csv(
124+
"output_predicted_data", sep="\t", index=None
125+
)
135126
e_time = time.time()
136127
print(
137128
f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds"

tools/tabpfn/tabpfn.xml

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<description>with PyTorch</description>
33
<macros>
44
<token name="@TOOL_VERSION@">2.0.3</token>
5-
<token name="@VERSION_SUFFIX@">1.2</token>
5+
<token name="@VERSION_SUFFIX@">1.1</token>
66
</macros>
77
<creator>
88
<organization name="European Galaxy Team" url="https://galaxyproject.org/eu/"/>
@@ -11,7 +11,8 @@
1111
</creator>
1212
<requirements>
1313
<requirement type="package" version="@TOOL_VERSION@">tabpfn</requirement>
14-
<requirement type="package" version="1.2.7">catboost</requirement>
14+
<requirement type="package" version="2.2.2">pandas</requirement>
15+
<requirement type="package" version="3.9.2">matplotlib</requirement>
1516
</requirements>
1617
<version_command>echo "@VERSION@"</version_command>
1718
<command detect_errors="aggressive">
@@ -33,67 +34,63 @@
3334
<param name="testhaslabels" type="boolean" truevalue="true" falsevalue="false" checked="false" label="Does test data contain labels?" help="Set this parameter when test data contains labels"/>
3435
</inputs>
3536
<outputs>
36-
<data format="tabular" name="output_predicted_data_TabPFN" from_work_dir="output_predicted_data_TabPFN" label="Predicted data by TabPFN"/>
37-
<data format="tabular" name="output_predicted_data_CatBoost" from_work_dir="output_predicted_data_CatBoost" label="Predicted data by CatBoost"/>
38-
<data format="png" name="output_plot_TabPFN" from_work_dir="output_plot_TabPFN.png" label="Prediction plot on test data TabPFN">
39-
<filter>testhaslabels is True</filter>
40-
</data>
41-
<data format="png" name="output_plot_CatBoost" from_work_dir="output_plot_CatBoost.png" label="Prediction plot on test data using CatBoost">
37+
<data format="tabular" name="output_predicted_data" from_work_dir="output_predicted_data" label="Predicted data"/>
38+
<data format="png" name="output_plot" from_work_dir="output_plot.png" label="Prediction plot on test data">
4239
<filter>testhaslabels is True</filter>
4340
</data>
4441
</outputs>
4542
<tests>
46-
<test expect_num_outputs="2">
43+
<test expect_num_outputs="1">
4744
<param name="selected_task" value="Classification"/>
4845
<param name="train_data" value="classification_local_train_rows.tabular" ftype="tabular"/>
4946
<param name="test_data" value="classification_local_test_rows.tabular" ftype="tabular"/>
5047
<param name="testhaslabels" value="false"/>
51-
<output name="output_predicted_data_TabPFN">
48+
<output name="output_predicted_data">
5249
<assert_contents>
5350
<has_n_columns n="42"/>
5451
<has_n_lines n="3"/>
5552
</assert_contents>
5653
</output>
5754
</test>
58-
<test expect_num_outputs="4">
55+
<test expect_num_outputs="2">
5956
<param name="selected_task" value="Classification"/>
6057
<param name="train_data" value="classification_local_train_rows.tabular" ftype="tabular"/>
6158
<param name="test_data" value="classification_local_test_rows_labels.tabular" ftype="tabular"/>
6259
<param name="testhaslabels" value="true"/>
63-
<output name="output_plot_TabPFN" file="prc_binary.png" compare="sim_size"/>
60+
<output name="output_plot" file="prc_binary.png" compare="sim_size"/>
6461
</test>
65-
<test expect_num_outputs="4">
62+
<test expect_num_outputs="2">
6663
<param name="selected_task" value="Classification"/>
6764
<param name="train_data" value="train_data_multiclass.tabular" ftype="tabular"/>
6865
<param name="test_data" value="test_data_multiclass_labels.tabular" ftype="tabular"/>
6966
<param name="testhaslabels" value="true"/>
70-
<output name="output_plot_TabPFN" file="prc_multiclass.png" compare="sim_size"/>
67+
<output name="output_plot" file="prc_multiclass.png" compare="sim_size"/>
7168
</test>
72-
<test expect_num_outputs="2">
69+
<test expect_num_outputs="1">
7370
<param name="selected_task" value="Classification"/>
7471
<param name="train_data" value="train_data_multiclass.tabular" ftype="tabular"/>
7572
<param name="test_data" value="test_data_multiclass_nolabels.tabular" ftype="tabular"/>
7673
<param name="testhaslabels" value="false"/>
77-
<output name="output_predicted_data_CatBoost">
74+
<output name="output_predicted_data">
7875
<assert_contents>
7976
<has_n_columns n="11"/>
8077
<has_n_lines n="502"/>
8178
</assert_contents>
8279
</output>
8380
</test>
84-
<test expect_num_outputs="4">
81+
<test expect_num_outputs="2">
8582
<param name="selected_task" value="Regression"/>
8683
<param name="train_data" value="regression_local_train_rows.tabular" ftype="tabular"/>
8784
<param name="test_data" value="regression_local_test_rows_labels.tabular" ftype="tabular"/>
8885
<param name="testhaslabels" value="true"/>
89-
<output name="output_plot_TabPFN" file="r2_curve.png" compare="sim_size"/>
86+
<output name="output_plot" file="r2_curve.png" compare="sim_size"/>
9087
</test>
91-
<test expect_num_outputs="2">
88+
<test expect_num_outputs="1">
9289
<param name="selected_task" value="Regression"/>
9390
<param name="train_data" value="regression_local_train_rows.tabular" ftype="tabular"/>
9491
<param name="test_data" value="regression_local_test_rows.tabular" ftype="tabular"/>
9592
<param name="testhaslabels" value="false"/>
96-
<output name="output_predicted_data_TabPFN">
93+
<output name="output_predicted_data">
9794
<assert_contents>
9895
<has_n_columns n="14"/>
9996
<has_n_lines n="105"/>
@@ -110,7 +107,6 @@
110107
**Input files**
111108
- Training data: the training data should contain features and the last column should be the class labels. It should be in tabular format.
112109
- Test data: the test data should also contain the same features as the training data and the last column should be the class labels if labels are available. It should be in tabular format. It is not required for the test data to have labels.
113-
- Above files show performance comparison of TabPFN with CatBoost (https://github.com/catboost/catboost).
114110
115111
**Output files**
116112
- Predicted data along with predicted labels.

0 commit comments

Comments
 (0)