|
4 | 4 | import pandas as pd |
5 | 5 |
|
6 | 6 | from .output_dir import OutputDir |
| 7 | +from .performance_report import PerformanceReport |
7 | 8 |
|
8 | 9 |
|
9 | 10 | class ValidationTrial(OutputDir): |
@@ -102,10 +103,10 @@ def mean_fnrs_by_class(self, class_names): |
102 | 103 | means[class_name] = statistics.mean(d_means) |
103 | 104 | return means |
104 | 105 |
|
105 | | - def run_reports(self, base_output_dir, display_classes=None): |
| 106 | + def run_reports(self, base_output_dir): |
106 | 107 | self.create_output_dir(base_output_dir) |
107 | 108 | self.visualize_empirical_fnr() |
108 | | - self.visualize_class_performance(display_classes) |
| 109 | + self.visualize_class_performance() |
109 | 110 | print(f'Reports saved to {self.output_dir}') |
110 | 111 |
|
111 | 112 | def visualize_empirical_fnr(self): |
@@ -140,68 +141,8 @@ def visualize_empirical_fnr(self): |
140 | 141 | # Save the figure |
141 | 142 | plt.savefig(f'{self.output_dir}/empirical_fnr.png') |
142 | 143 |
|
143 | | - def visualize_class_performance(self, display_classes=None): |
144 | | - def translate_class_name(class_name): |
145 | | - if display_classes is None or class_name not in display_classes: |
146 | | - return class_name |
147 | | - else: |
148 | | - return display_classes[class_name] |
149 | | - |
150 | | - def translate_dict(d): |
151 | | - return {translate_class_name(key): value for key, value in d.items()} |
152 | | - |
153 | | - plt.figure() |
154 | | - |
155 | | - # FNRs by class |
| 144 | + def visualize_class_performance(self): |
156 | 145 | fnrs_by_class = self.mean_fnrs_by_class(self.class_names) |
157 | | - |
158 | | - # Sort the dictionary by its values |
159 | | - mean_fnrs = translate_dict(dict(sorted(fnrs_by_class.items(), key=lambda item: item[1]))) |
160 | | - |
161 | | - # Visualize this dict as a bar chart |
162 | | - sns.set_style('whitegrid') |
163 | | - palette = sns.color_palette("deep") # You can change "deep" to any other palette |
164 | | - fig, ax = plt.subplots(figsize=(10, 6)) |
165 | | - ax.bar(range(len(mean_fnrs)), mean_fnrs.values(), color=palette[0]) |
166 | | - ax.set_xticks(range(len(mean_fnrs))) |
167 | | - ax.set_xticklabels(mean_fnrs.keys(), rotation='vertical') |
168 | | - ax.set_ylabel('Mean FNR') |
169 | | - ax.set_xlabel('True class') |
170 | | - plt.tight_layout() |
171 | | - |
172 | | - # Export as fig and text |
173 | | - plt.savefig(f'{self.output_dir}/mean_fnrs_by_class.png') |
174 | | - |
175 | | - # Convert dictionary to dataframe |
176 | | - df = pd.DataFrame(mean_fnrs, index=[0]).T |
177 | | - |
178 | | - # Transpose the dataframe |
179 | | - df.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv', |
180 | | - index=True, header=False) |
181 | | - |
182 | | - # Reset plt |
183 | | - plt.figure() |
184 | | - |
185 | | - # Set sizes by class |
186 | 146 | mean_set_sizes = self.mean_set_sizes_by_class(self.class_names) |
187 | | - |
188 | | - # Sort the dictionary by its values |
189 | | - mean_set_sizes = translate_dict(dict(sorted(mean_set_sizes.items(), key=lambda item: item[1]))) |
190 | | - |
191 | | - # Convert dictionary to dataframe |
192 | | - df = pd.DataFrame(mean_set_sizes, index=[0]).T |
193 | | - |
194 | | - # Transpose the dataframe |
195 | | - df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv', |
196 | | - index=True, header=False) |
197 | | - |
198 | | - # Visualize this dict as a bar chart |
199 | | - sns.set_style('whitegrid') |
200 | | - fig, ax = plt.subplots(figsize=(10, 6)) |
201 | | - ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1]) |
202 | | - ax.set_xticks(range(len(mean_set_sizes))) |
203 | | - ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical') |
204 | | - ax.set_ylabel('Mean set size') |
205 | | - ax.set_xlabel('True class') |
206 | | - plt.tight_layout() |
207 | | - plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png') |
| 147 | + pr = PerformanceReport(self.output_dir) |
| 148 | + pr.report_class_statistics(mean_set_sizes, fnrs_by_class) |
0 commit comments