Skip to content

Commit 01a66e5

Browse files
committed
Break up reporting function
1 parent bfdc061 commit 01a66e5

File tree

1 file changed

+34
-29
lines changed

1 file changed

+34
-29
lines changed

src/conformist/performance_report.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
class PerformanceReport(OutputDir):
88
def __init__(self, base_output_dir):
9-
self.base_output_dir = base_output_dir
9+
self.create_output_dir(self.base_output_dir)
1010

1111
def mean_set_size(prediction_sets):
1212
return sum(sum(prediction_set) for
@@ -33,12 +33,36 @@ def pct_trio_plus_sets(prediction_sets):
3333
prediction_set in prediction_sets) / \
3434
len(prediction_sets)
3535

36-
def report_class_statistics(self,
37-
mean_set_sizes_by_class,
38-
mean_fnrs_by_class):
36+
def visualize_mean_set_sizes_by_class(self,
37+
mean_set_sizes_by_class):
38+
# Reset plt
39+
plt.figure()
40+
41+
# Sort the dictionary by its values
42+
mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(),
43+
key=lambda item: item[1]))
44+
45+
# Convert dictionary to dataframe and transpose
46+
df = pd.DataFrame(mean_set_sizes, index=[0]).T
47+
48+
# Save as csv
49+
df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv',
50+
index=True, header=False)
51+
52+
# Visualize this dict as a bar chart
53+
sns.set_style('whitegrid')
54+
fig, ax = plt.subplots(figsize=(10, 6))
55+
ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1])
56+
ax.set_xticks(range(len(mean_set_sizes)))
57+
ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical')
58+
ax.set_ylabel('Mean set size')
59+
ax.set_xlabel('True class')
60+
plt.tight_layout()
61+
plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png')
3962

63+
def visualize_mean_fnrs_by_class(self,
64+
mean_fnrs_by_class):
4065
# Setup
41-
self.create_output_dir(self.base_output_dir)
4266
plt.figure()
4367

4468
# Sort the dictionary by its values
@@ -66,27 +90,8 @@ def report_class_statistics(self,
6690
df.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv',
6791
index=True, header=False)
6892

69-
# Reset plt
70-
plt.figure()
71-
72-
# Sort the dictionary by its values
73-
mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(),
74-
key=lambda item: item[1]))
75-
76-
# Convert dictionary to dataframe and transpose
77-
df = pd.DataFrame(mean_set_sizes, index=[0]).T
78-
79-
# Save as csv
80-
df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv',
81-
index=True, header=False)
82-
83-
# Visualize this dict as a bar chart
84-
sns.set_style('whitegrid')
85-
fig, ax = plt.subplots(figsize=(10, 6))
86-
ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1])
87-
ax.set_xticks(range(len(mean_set_sizes)))
88-
ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical')
89-
ax.set_ylabel('Mean set size')
90-
ax.set_xlabel('True class')
91-
plt.tight_layout()
92-
plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png')
93+
def report_class_statistics(self,
94+
mean_set_sizes_by_class,
95+
mean_fnrs_by_class):
96+
self.visualize_mean_fnrs_by_class(mean_fnrs_by_class)
97+
self.visualize_mean_set_sizes_by_class(mean_set_sizes_by_class)

0 commit comments

Comments
 (0)