File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -113,12 +113,10 @@ def do_validation_trial(self,
113113
114114 def predict (self ,
115115 pds ,
116- output_dir ,
116+ output_dir = None ,
117117 validate = False ,
118118 upset_plot_color = "black" ):
119119
120- self .create_output_dir (output_dir )
121-
122120 self .smx = pds .smx
123121 self .val_smx = pds .smx
124122 self .val_labels = pds .labels_idx
@@ -131,6 +129,11 @@ def predict(self,
131129 prediction_sets_text = self .prediction_sets_to_text (prediction_sets )
132130 formatted_predictions = pds .prediction_sets_df (prediction_sets_text )
133131
132+ if not output_dir :
133+ return formatted_predictions
134+
135+ self .create_output_dir (output_dir )
136+
134137 # -- WRITE PREDICTIONS TO CSV
135138 formatted_predictions .to_csv (f'{ self .output_dir } /prediction_sets.csv' )
136139
You can’t perform that action at this time.
0 commit comments