Skip to content

Commit 56bb815

Browse files
author
Anup Kumar
committed
toggle pretraining limits based on number of samples
1 parent 4be1bb4 commit 56bb815

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

tools/tabpfn/main.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def train_evaluate(args):
8484
"""
8585
Train TabPFN and predict
8686
"""
87+
MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES = 1000
88+
SEED = 42
8789
# prepare train data
8890
tr_features, tr_labels = separate_features_labels(args["train_data"], args["train_header"])
8991
# prepare test data
@@ -94,7 +96,10 @@ def train_evaluate(args):
9496
te_labels = []
9597
s_time = time.time()
9698
if args["selected_task"] == "Classification":
97-
classifier = TabPFNClassifier(random_state=42, model_path=args["model_path"], ignore_pretraining_limits=True)
99+
if tr_features.shape[0] <= MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES:
100+
classifier = TabPFNClassifier(random_state=SEED, model_path=args["model_path"])
101+
else:
102+
classifier = TabPFNClassifier(random_state=SEED, model_path=args["model_path"], ignore_pretraining_limits=True)
98103
classifier.fit(tr_features, tr_labels)
99104
y_eval = classifier.predict(te_features)
100105
pred_probas_test = classifier.predict_proba(te_features)
@@ -105,7 +110,10 @@ def train_evaluate(args):
105110
"output_predicted_data", sep="\t", index=None
106111
)
107112
else:
108-
regressor = TabPFNRegressor(random_state=42, model_path=args["model_path"], ignore_pretraining_limits=True)
113+
if tr_features.shape[0] <= MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES:
114+
regressor = TabPFNRegressor(random_state=SEED, model_path=args["model_path"])
115+
else:
116+
regressor = TabPFNRegressor(random_state=SEED, model_path=args["model_path"], ignore_pretraining_limits=True)
109117
regressor.fit(tr_features, tr_labels)
110118
y_eval = regressor.predict(te_features)
111119
if len(te_labels) > 0:

0 commit comments

Comments
 (0)