Skip to content

Commit 56634bb

Browse files
committed
Create new performancereport class to create reports for validation runs and trials
1 parent d1c2adf commit 56634bb

File tree

4 files changed

+79
-71
lines changed

4 files changed

+79
-71
lines changed

src/conformist/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .validation_trial import ValidationTrial
77
from .validation_run import ValidationRun
88
from .model_vs_cop_fnr import ModelVsCopFNR
9+
from .performance_report import PerformanceReport
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import matplotlib.pyplot as plt
2+
import seaborn as sns
3+
import pandas as pd
4+
from .output_dir import OutputDir
5+
6+
7+
class PerformanceReport(OutputDir):
8+
def __init__(self, base_output_dir):
9+
self.base_output_dir = base_output_dir
10+
11+
def report_class_statistics(self,
12+
mean_set_sizes_by_class,
13+
mean_fnrs_by_class):
14+
15+
# Setup
16+
self.create_output_dir(self.base_output_dir)
17+
plt.figure()
18+
19+
# Sort the dictionary by its values
20+
mean_fnrs = dict(sorted(mean_fnrs_by_class.items(),
21+
key=lambda item: item[1]))
22+
23+
# Visualize this dict as a bar chart
24+
sns.set_style('whitegrid')
25+
palette = sns.color_palette("deep")
26+
fig, ax = plt.subplots(figsize=(10, 6))
27+
ax.bar(range(len(mean_fnrs)), mean_fnrs.values(), color=palette[0])
28+
ax.set_xticks(range(len(mean_fnrs)))
29+
ax.set_xticklabels(mean_fnrs.keys(), rotation='vertical')
30+
ax.set_ylabel('Mean FNR')
31+
ax.set_xlabel('True class')
32+
plt.tight_layout()
33+
34+
# Export as fig and text
35+
plt.savefig(f'{self.output_dir}/mean_fnrs_by_class.png')
36+
37+
# Convert dictionary to dataframe and transpose
38+
df = pd.DataFrame(mean_fnrs, index=[0]).T
39+
40+
# Save as csv
41+
df.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv',
42+
index=True, header=False)
43+
44+
# Reset plt
45+
plt.figure()
46+
47+
# Sort the dictionary by its values
48+
mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(),
49+
key=lambda item: item[1]))
50+
51+
# Convert dictionary to dataframe and transpose
52+
df = pd.DataFrame(mean_set_sizes, index=[0]).T
53+
54+
# Save as csv
55+
df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv',
56+
index=True, header=False)
57+
58+
# Visualize this dict as a bar chart
59+
sns.set_style('whitegrid')
60+
fig, ax = plt.subplots(figsize=(10, 6))
61+
ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1])
62+
ax.set_xticks(range(len(mean_set_sizes)))
63+
ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical')
64+
ax.set_ylabel('Mean set size')
65+
ax.set_xlabel('True class')
66+
plt.tight_layout()
67+
plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png')

src/conformist/validation_run.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pandas as pd
33
from .output_dir import OutputDir
4+
from .performance_report import PerformanceReport
45

56

67
class ValidationRun(OutputDir):
@@ -124,6 +125,10 @@ def mean_fnrs_by_class(self, class_names):
124125
return averages
125126

126127
def run_reports(self, base_output_dir):
128+
pr = PerformanceReport(base_output_dir)
129+
pr.report_class_statistics(self.mean_set_sizes_by_class(self.class_names),
130+
self.mean_fnrs_by_class(self.class_names))
131+
127132
np.seterr(all='raise')
128133
self.create_output_dir(base_output_dir)
129134

@@ -139,10 +144,4 @@ def run_reports(self, base_output_dir):
139144

140145
df.T.to_csv(f'{self.output_dir}/summary.csv', header=False)
141146

142-
df = pd.DataFrame(self.mean_set_sizes_by_class(self.class_names), index=[0])
143-
df.T.to_csv(f'{self.output_dir}/mean_set_sizes_by_class.csv', header=False)
144-
145-
df = pd.DataFrame(self.mean_fnrs_by_class(self.class_names), index=[0])
146-
df.T.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv', header=False)
147-
148147
print(f'Reports saved to {self.output_dir}')

src/conformist/validation_trial.py

Lines changed: 6 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55

66
from .output_dir import OutputDir
7+
from .performance_report import PerformanceReport
78

89

910
class ValidationTrial(OutputDir):
@@ -102,10 +103,10 @@ def mean_fnrs_by_class(self, class_names):
102103
means[class_name] = statistics.mean(d_means)
103104
return means
104105

105-
def run_reports(self, base_output_dir, display_classes=None):
106+
def run_reports(self, base_output_dir):
106107
self.create_output_dir(base_output_dir)
107108
self.visualize_empirical_fnr()
108-
self.visualize_class_performance(display_classes)
109+
self.visualize_class_performance()
109110
print(f'Reports saved to {self.output_dir}')
110111

111112
def visualize_empirical_fnr(self):
@@ -140,68 +141,8 @@ def visualize_empirical_fnr(self):
140141
# Save the figure
141142
plt.savefig(f'{self.output_dir}/empirical_fnr.png')
142143

143-
def visualize_class_performance(self, display_classes=None):
144-
def translate_class_name(class_name):
145-
if display_classes is None or class_name not in display_classes:
146-
return class_name
147-
else:
148-
return display_classes[class_name]
149-
150-
def translate_dict(d):
151-
return {translate_class_name(key): value for key, value in d.items()}
152-
153-
plt.figure()
154-
155-
# FNRs by class
144+
def visualize_class_performance(self):
156145
fnrs_by_class = self.mean_fnrs_by_class(self.class_names)
157-
158-
# Sort the dictionary by its values
159-
mean_fnrs = translate_dict(dict(sorted(fnrs_by_class.items(), key=lambda item: item[1])))
160-
161-
# Visualize this dict as a bar chart
162-
sns.set_style('whitegrid')
163-
palette = sns.color_palette("deep") # You can change "deep" to any other palette
164-
fig, ax = plt.subplots(figsize=(10, 6))
165-
ax.bar(range(len(mean_fnrs)), mean_fnrs.values(), color=palette[0])
166-
ax.set_xticks(range(len(mean_fnrs)))
167-
ax.set_xticklabels(mean_fnrs.keys(), rotation='vertical')
168-
ax.set_ylabel('Mean FNR')
169-
ax.set_xlabel('True class')
170-
plt.tight_layout()
171-
172-
# Export as fig and text
173-
plt.savefig(f'{self.output_dir}/mean_fnrs_by_class.png')
174-
175-
# Convert dictionary to dataframe
176-
df = pd.DataFrame(mean_fnrs, index=[0]).T
177-
178-
# Transpose the dataframe
179-
df.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv',
180-
index=True, header=False)
181-
182-
# Reset plt
183-
plt.figure()
184-
185-
# Set sizes by class
186146
mean_set_sizes = self.mean_set_sizes_by_class(self.class_names)
187-
188-
# Sort the dictionary by its values
189-
mean_set_sizes = translate_dict(dict(sorted(mean_set_sizes.items(), key=lambda item: item[1])))
190-
191-
# Convert dictionary to dataframe
192-
df = pd.DataFrame(mean_set_sizes, index=[0]).T
193-
194-
# Transpose the dataframe
195-
df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv',
196-
index=True, header=False)
197-
198-
# Visualize this dict as a bar chart
199-
sns.set_style('whitegrid')
200-
fig, ax = plt.subplots(figsize=(10, 6))
201-
ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1])
202-
ax.set_xticks(range(len(mean_set_sizes)))
203-
ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical')
204-
ax.set_ylabel('Mean set size')
205-
ax.set_xlabel('True class')
206-
plt.tight_layout()
207-
plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png')
147+
pr = PerformanceReport(self.output_dir)
148+
pr.report_class_statistics(mean_set_sizes, fnrs_by_class)

0 commit comments

Comments
 (0)