Skip to content

Commit e7a6221

Browse files
committed
Add functions to count classes appearing in prediction sets
1 parent 47b8f6a commit e7a6221

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

src/conformist/validation_run.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,25 @@ def mean_fnrs_by_class(self, sets, class_names):
117117
averages[key] = sum(sizes) / len(sizes)
118118
return averages
119119

120+
def prediction_counts_by_class(self, class_names, coocurring_only=False):
121+
counts = {}
122+
for i in range(len(self.prediction_sets)):
123+
labels = self.prediction_sets[i]
124+
125+
# Get corresponding values from class_names
126+
pset_class_names = [class_names[i] for i, label in enumerate(labels) if label == 1]
127+
128+
do_count = True
129+
if coocurring_only:
130+
do_count = len(pset_class_names) > 1
131+
132+
if do_count:
133+
for class_name in pset_class_names:
134+
class_counts = counts.get(class_name, 0)
135+
class_counts += 1
136+
counts[class_name] = class_counts
137+
return counts
138+
120139
def run_reports(self, base_output_dir):
121140
mean_set_sizes = self.mean_set_sizes_by_class(self.class_names)
122141
mean_fnrs = self.mean_fnrs_by_class(self.prediction_sets, self.class_names)

src/conformist/validation_trial.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,24 @@ def mean_fnrs_by_class(self, class_names):
115115
means[class_name] = statistics.mean(d_means)
116116
return means
117117

118+
def mean_prediction_counts_by_class(self, class_names, coocurring_only=False):
119+
prediction_count_dicts = []
120+
for run in self.runs:
121+
prediction_count_dicts.append(run.prediction_counts_by_class(class_names, coocurring_only))
122+
123+
means = {}
124+
for class_name in class_names:
125+
d_means = []
126+
for d in prediction_count_dicts:
127+
if class_name in d:
128+
d_means.append(d[class_name])
129+
if len(d_means) > 0:
130+
means[class_name] = statistics.mean(d_means)
131+
132+
# Sort by count descending
133+
means = dict(sorted(means.items(), key=lambda item: item[1], reverse=True))
134+
return means
135+
118136
def run_reports(self, base_output_dir):
119137
self.create_output_dir(base_output_dir)
120138
self.visualize_empirical_fnr()

0 commit comments

Comments
 (0)