Skip to content

Commit cefdfdc

Browse files
authored
Merge pull request #1598 from anuprulez/tabpfn_license
TabPFN updates
2 parents 7cdafed + 606bf12 commit cefdfdc

File tree

2 files changed

+89
-80
lines changed

2 files changed

+89
-80
lines changed

tools/tabpfn/main.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ def classification_plot(y_true, y_scores):
5757
plt.plot(
5858
recall, precision, linestyle="--", color="black", label="Micro-average"
5959
)
60-
plt.title("Precision-Recall Curve (Multiclass Classification)")
60+
plt.title(
61+
"Precision-Recall Curve (Multiclass Classification)"
62+
)
6163
plt.xlabel("Recall")
6264
plt.ylabel("Precision")
6365
plt.legend(loc="lower left")
@@ -85,21 +87,25 @@ def train_evaluate(args):
8587
# prepare train data
8688
tr_features, tr_labels = separate_features_labels(args["train_data"])
8789
# prepare test data
88-
if args["testhaslabels"] == "haslabels":
90+
if args["testhaslabels"] == "true":
8991
te_features, te_labels = separate_features_labels(args["test_data"])
9092
else:
9193
te_features = pd.read_csv(args["test_data"], sep="\t")
9294
te_labels = []
9395
s_time = time.time()
9496
if args["selected_task"] == "Classification":
95-
classifier = TabPFNClassifier()
97+
classifier = TabPFNClassifier(random_state=42)
9698
classifier.fit(tr_features, tr_labels)
9799
y_eval = classifier.predict(te_features)
98100
pred_probas_test = classifier.predict_proba(te_features)
99101
if len(te_labels) > 0:
100102
classification_plot(te_labels, pred_probas_test)
103+
te_features["predicted_labels"] = y_eval
104+
te_features.to_csv(
105+
"output_predicted_data", sep="\t", index=None
106+
)
101107
else:
102-
regressor = TabPFNRegressor()
108+
regressor = TabPFNRegressor(random_state=42)
103109
regressor.fit(tr_features, tr_labels)
104110
y_eval = regressor.predict(te_features)
105111
if len(te_labels) > 0:
@@ -112,14 +118,14 @@ def train_evaluate(args):
112118
"True values",
113119
"Predicted values",
114120
)
121+
te_features["predicted_labels"] = y_eval
122+
te_features.to_csv(
123+
"output_predicted_data", sep="\t", index=None
124+
)
115125
e_time = time.time()
116126
print(
117-
"Time taken by TabPFN for training and prediction: {} seconds".format(
118-
e_time - s_time
119-
)
127+
f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds"
120128
)
121-
te_features["predicted_labels"] = y_eval
122-
te_features.to_csv("output_predicted_data", sep="\t", index=None)
123129

124130

125131
if __name__ == "__main__":

tools/tabpfn/tabpfn.xml

Lines changed: 74 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,115 +2,118 @@
22
<description>with PyTorch</description>
33
<macros>
44
<token name="@TOOL_VERSION@">2.0.3</token>
5-
<token name="@VERSION_SUFFIX@">1.1</token>
5+
<token name="@VERSION_SUFFIX@">1.2</token>
66
</macros>
77
<creator>
8-
<organization name="European Galaxy Team" url="https://galaxyproject.org/eu/" />
9-
<person givenName="Anup" familyName="Kumar" email="kumara@informatik.uni-freiburg.de" />
10-
<person givenName="Frank" familyName="Hutter" email="fh@cs.uni-freiburg.de" />
8+
<organization name="European Galaxy Team" url="https://galaxyproject.org/eu/"/>
9+
<person givenName="Anup" familyName="Kumar" email="kumara@informatik.uni-freiburg.de"/>
10+
<person givenName="Frank" familyName="Hutter" email="fh@cs.uni-freiburg.de"/>
1111
</creator>
1212
<requirements>
13-
<requirement type="package" version="@TOOL_VERSION@">tabpfn</requirement>
14-
<requirement type="package" version="2.2.2">pandas</requirement>
15-
<requirement type="package" version="3.9.2">matplotlib</requirement>
13+
<requirement type="package" version="@TOOL_VERSION@">tabpfn</requirement>
14+
<requirement type="package" version="2.2.2">pandas</requirement>
15+
<requirement type="package" version="3.9.2">matplotlib</requirement>
1616
</requirements>
1717
<version_command>echo "@VERSION@"</version_command>
1818
<command detect_errors="aggressive">
19-
<![CDATA[
19+
<![CDATA[
2020
python '$__tool_directory__/main.py'
2121
--selected_task '$selected_task'
2222
--train_data '$train_data'
2323
--testhaslabels '$testhaslabels'
2424
--test_data '$test_data'
25-
]]>
25+
]]>
2626
</command>
2727
<inputs>
28-
<param name="selected_task" type="select" label="Select a machine learning task">
29-
<option value="Classification" selected="true"></option>
30-
<option value="Regression" selected="false"></option>
31-
</param>
32-
<param name="train_data" type="data" format="tabular" label="Train data" help="Please provide training data for training model. It should contain labels/class/target in the last column" />
33-
<param name="test_data" type="data" format="tabular" label="Test data" help="Please provide test data for evaluating model. It may or may not contain labels/class/target in the last column" />
34-
<param name="testhaslabels" type="boolean" truevalue="haslabels" falsevalue="" checked="false" label="Does test data contain labels?" help="Set this parameter when test data contains labels" />
28+
<param name="selected_task" type="select" label="Select a machine learning task">
29+
<option value="Classification" selected="true"/>
30+
<option value="Regression" selected="false"/>
31+
</param>
32+
<param name="train_data" type="data" format="tabular" label="Train data" help="Please provide training data for training model. It should contain labels/class/target in the last column"/>
33+
<param name="test_data" type="data" format="tabular" label="Test data" help="Please provide test data for evaluating model. It may or may not contain labels/class/target in the last column"/>
34+
<param name="testhaslabels" type="boolean" truevalue="true" falsevalue="false" checked="false" label="Does test data contain labels?" help="Set this parameter when test data contains labels"/>
3535
</inputs>
3636
<outputs>
37-
<data format="tabular" name="output_predicted_data" from_work_dir="output_predicted_data" label="Predicted data"></data>
38-
<data format="png" name="output_plot" from_work_dir="output_plot.png" label="Prediction plot on test data">
39-
<filter>testhaslabels is True</filter>
40-
</data>
37+
<data format="tabular" name="output_predicted_data" from_work_dir="output_predicted_data" label="Predicted data"/>
38+
<data format="png" name="output_plot" from_work_dir="output_plot.png" label="Prediction plot on test data">
39+
<filter>testhaslabels is True</filter>
40+
</data>
4141
</outputs>
4242
<tests>
43-
<test expect_num_outputs="1">
44-
<param name="selected_task" value="Classification" />
45-
<param name="train_data" value="classification_local_train_rows.tabular" ftype="tabular" />
46-
<param name="test_data" value="classification_local_test_rows.tabular" ftype="tabular" />
47-
<param name="testhaslabels" value="" />
43+
<test expect_num_outputs="1">
44+
<param name="selected_task" value="Classification"/>
45+
<param name="train_data" value="classification_local_train_rows.tabular" ftype="tabular"/>
46+
<param name="test_data" value="classification_local_test_rows.tabular" ftype="tabular"/>
47+
<param name="testhaslabels" value="false"/>
4848
<output name="output_predicted_data">
49-
<assert_contents>
50-
<has_n_columns n="42" />
51-
<has_n_lines n="3" />
52-
</assert_contents>
53-
</output>
49+
<assert_contents>
50+
<has_n_columns n="42"/>
51+
<has_n_lines n="3"/>
52+
</assert_contents>
53+
</output>
5454
</test>
55-
<test expect_num_outputs="2">
56-
<param name="selected_task" value="Classification" />
57-
<param name="train_data" value="classification_local_train_rows.tabular" ftype="tabular" />
58-
<param name="test_data" value="classification_local_test_rows_labels.tabular" ftype="tabular" />
59-
<param name="testhaslabels" value="haslabels" />
60-
<output name="output_plot" file="prc_binary.png" compare="sim_size" />
55+
<test expect_num_outputs="2">
56+
<param name="selected_task" value="Classification"/>
57+
<param name="train_data" value="classification_local_train_rows.tabular" ftype="tabular"/>
58+
<param name="test_data" value="classification_local_test_rows_labels.tabular" ftype="tabular"/>
59+
<param name="testhaslabels" value="true"/>
60+
<output name="output_plot" file="prc_binary.png" compare="sim_size"/>
6161
</test>
62-
<test expect_num_outputs="2">
63-
<param name="selected_task" value="Classification" />
64-
<param name="train_data" value="train_data_multiclass.tabular" ftype="tabular" />
65-
<param name="test_data" value="test_data_multiclass_labels.tabular" ftype="tabular" />
66-
<param name="testhaslabels" value="haslabels" />
67-
<output name="output_plot" file="prc_multiclass.png" compare="sim_size" />
62+
<test expect_num_outputs="2">
63+
<param name="selected_task" value="Classification"/>
64+
<param name="train_data" value="train_data_multiclass.tabular" ftype="tabular"/>
65+
<param name="test_data" value="test_data_multiclass_labels.tabular" ftype="tabular"/>
66+
<param name="testhaslabels" value="true"/>
67+
<output name="output_plot" file="prc_multiclass.png" compare="sim_size"/>
6868
</test>
6969
<test expect_num_outputs="1">
70-
<param name="selected_task" value="Classification" />
71-
<param name="train_data" value="train_data_multiclass.tabular" ftype="tabular" />
72-
<param name="test_data" value="test_data_multiclass_nolabels.tabular" ftype="tabular" />
73-
<param name="testhaslabels" value="" />
74-
<output name="output_predicted_data">
75-
<assert_contents>
76-
<has_n_columns n="11" />
77-
<has_n_lines n="502" />
70+
<param name="selected_task" value="Classification"/>
71+
<param name="train_data" value="train_data_multiclass.tabular" ftype="tabular"/>
72+
<param name="test_data" value="test_data_multiclass_nolabels.tabular" ftype="tabular"/>
73+
<param name="testhaslabels" value="false"/>
74+
<output name="output_predicted_data">
75+
<assert_contents>
76+
<has_n_columns n="11"/>
77+
<has_n_lines n="502"/>
7878
</assert_contents>
7979
</output>
8080
</test>
81-
<test expect_num_outputs="2">
82-
<param name="selected_task" value="Regression" />
83-
<param name="train_data" value="regression_local_train_rows.tabular" ftype="tabular" />
84-
<param name="test_data" value="regression_local_test_rows_labels.tabular" ftype="tabular" />
85-
<param name="testhaslabels" value="haslabels" />
86-
<output name="output_plot" file="r2_curve.png" compare="sim_size" />
87-
</test>
88-
<test expect_num_outputs="1">
89-
<param name="selected_task" value="Regression" />
90-
<param name="train_data" value="regression_local_train_rows.tabular" ftype="tabular" />
91-
<param name="test_data" value="regression_local_test_rows.tabular" ftype="tabular" />
92-
<param name="testhaslabels" value="" />
93-
<output name="output_predicted_data">
94-
<assert_contents>
95-
<has_n_columns n="14" />
96-
<has_n_lines n="105" />
97-
</assert_contents>
98-
</output>
99-
</test>
81+
<test expect_num_outputs="2">
82+
<param name="selected_task" value="Regression"/>
83+
<param name="train_data" value="regression_local_train_rows.tabular" ftype="tabular"/>
84+
<param name="test_data" value="regression_local_test_rows_labels.tabular" ftype="tabular"/>
85+
<param name="testhaslabels" value="true"/>
86+
<output name="output_plot" file="r2_curve.png" compare="sim_size"/>
87+
</test>
88+
<test expect_num_outputs="1">
89+
<param name="selected_task" value="Regression"/>
90+
<param name="train_data" value="regression_local_train_rows.tabular" ftype="tabular"/>
91+
<param name="test_data" value="regression_local_test_rows.tabular" ftype="tabular"/>
92+
<param name="testhaslabels" value="false"/>
93+
<output name="output_predicted_data">
94+
<assert_contents>
95+
<has_n_columns n="14"/>
96+
<has_n_lines n="105"/>
97+
</assert_contents>
98+
</output>
99+
</test>
100100
</tests>
101101
<help>
102102
<![CDATA[
103103
**What it does**
104104
105-
Classification and Regression on tabular data by TabPFN
105+
Classification and Regression on tabular data by TabPFN. The use of GPU is recommended while training TabPFN to optimize runtime. Currently, TabPFN supports upto 10,000 samples (rows) and 500 features (columns) in a tabular data.
106106
107107
**Input files**
108108
- Training data: the training data should contain features and the last column should be the class labels. It should be in tabular format.
109109
- Test data: the test data should also contain the same features as the training data and the last column should be the class labels if labels are avaialble. It should be in tabular format. It is not required for the test data to have labels.
110110
111111
**Output files**
112-
- Predicted data along with predicted labels.
112+
- Predicted data along with predicted labels.
113113
- Prediction plot (when test data has labels available).
114+
115+
**License**
116+
- TabPFN is available under an open source license (https://github.com/PriorLabs/TabPFN?tab=License-1-ov-file) that combines Apache with a LLama-like attribution clause. It requires you to prominently display "Built with TabPFN" when you use a pipeline including it in production.
114117
]]>
115118
</help>
116119
<citations>

0 commit comments

Comments
 (0)