Skip to content

Commit ac7e713

Browse files
committed
Make output dir optional
1 parent b1eaae5 commit ac7e713

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/conformist/base_cop.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,10 @@ def do_validation_trial(self,
113113

114114
def predict(self,
115115
pds,
116-
output_dir,
116+
output_dir=None,
117117
validate=False,
118118
upset_plot_color="black"):
119119

120-
self.create_output_dir(output_dir)
121-
122120
self.smx = pds.smx
123121
self.val_smx = pds.smx
124122
self.val_labels = pds.labels_idx
@@ -131,6 +129,11 @@ def predict(self,
131129
prediction_sets_text = self.prediction_sets_to_text(prediction_sets)
132130
formatted_predictions = pds.prediction_sets_df(prediction_sets_text)
133131

132+
if not output_dir:
133+
return formatted_predictions
134+
135+
self.create_output_dir(output_dir)
136+
134137
# -- WRITE PREDICTIONS TO CSV
135138
formatted_predictions.to_csv(f'{self.output_dir}/prediction_sets.csv')
136139

0 commit comments

Comments
 (0)