@@ -84,6 +84,8 @@ def train_evaluate(args):
8484 """
8585 Train TabPFN and predict
8686 """
87+ MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES = 1000
88+ SEED = 42
8789 # prepare train data
8890 tr_features , tr_labels = separate_features_labels (args ["train_data" ], args ["train_header" ])
8991 # prepare test data
@@ -94,7 +96,10 @@ def train_evaluate(args):
9496 te_labels = []
9597 s_time = time .time ()
9698 if args ["selected_task" ] == "Classification" :
97- classifier = TabPFNClassifier (random_state = 42 , model_path = args ["model_path" ], ignore_pretraining_limits = True )
99+ if tr_features .shape [0 ] <= MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES :
100+ classifier = TabPFNClassifier (random_state = SEED , model_path = args ["model_path" ])
101+ else :
102+ classifier = TabPFNClassifier (random_state = SEED , model_path = args ["model_path" ], ignore_pretraining_limits = True )
98103 classifier .fit (tr_features , tr_labels )
99104 y_eval = classifier .predict (te_features )
100105 pred_probas_test = classifier .predict_proba (te_features )
@@ -105,7 +110,10 @@ def train_evaluate(args):
105110 "output_predicted_data" , sep = "\t " , index = None
106111 )
107112 else :
108- regressor = TabPFNRegressor (random_state = 42 , model_path = args ["model_path" ], ignore_pretraining_limits = True )
113+ if tr_features .shape [0 ] <= MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES :
114+ regressor = TabPFNRegressor (random_state = SEED , model_path = args ["model_path" ])
115+ else :
116+ regressor = TabPFNRegressor (random_state = SEED , model_path = args ["model_path" ], ignore_pretraining_limits = True )
109117 regressor .fit (tr_features , tr_labels )
110118 y_eval = regressor .predict (te_features )
111119 if len (te_labels ) > 0 :
0 commit comments