66
77class PerformanceReport (OutputDir ):
88 def __init__ (self , base_output_dir ):
9- self .base_output_dir = base_output_dir
9+ self .create_output_dir ( self . base_output_dir )
1010
1111 def mean_set_size (prediction_sets ):
1212 return sum (sum (prediction_set ) for
@@ -33,12 +33,36 @@ def pct_trio_plus_sets(prediction_sets):
3333 prediction_set in prediction_sets ) / \
3434 len (prediction_sets )
3535
36- def report_class_statistics (self ,
37- mean_set_sizes_by_class ,
38- mean_fnrs_by_class ):
36+ def visualize_mean_set_sizes_by_class (self ,
37+ mean_set_sizes_by_class ):
38+ # Reset plt
39+ plt .figure ()
40+
41+ # Sort the dictionary by its values
42+ mean_set_sizes = dict (sorted (mean_set_sizes_by_class .items (),
43+ key = lambda item : item [1 ]))
44+
45+ # Convert dictionary to dataframe and transpose
46+ df = pd .DataFrame (mean_set_sizes , index = [0 ]).T
47+
48+ # Save as csv
49+ df .to_csv (f'{ self .output_dir } /mean_set_sizes_class.csv' ,
50+ index = True , header = False )
51+
52+ # Visualize this dict as a bar chart
53+ sns .set_style ('whitegrid' )
54+ fig , ax = plt .subplots (figsize = (10 , 6 ))
55+ ax .bar (range (len (mean_set_sizes )), mean_set_sizes .values (), color = palette [1 ])
56+ ax .set_xticks (range (len (mean_set_sizes )))
57+ ax .set_xticklabels (mean_set_sizes .keys (), rotation = 'vertical' )
58+ ax .set_ylabel ('Mean set size' )
59+ ax .set_xlabel ('True class' )
60+ plt .tight_layout ()
61+ plt .savefig (f'{ self .output_dir } /mean_set_sizes_by_class.png' )
3962
63+ def visualize_mean_fnrs_by_class (self ,
64+ mean_fnrs_by_class ):
4065 # Setup
41- self .create_output_dir (self .base_output_dir )
4266 plt .figure ()
4367
4468 # Sort the dictionary by its values
@@ -66,27 +90,8 @@ def report_class_statistics(self,
6690 df .to_csv (f'{ self .output_dir } /mean_fnrs_by_class.csv' ,
6791 index = True , header = False )
6892
69- # Reset plt
70- plt .figure ()
71-
72- # Sort the dictionary by its values
73- mean_set_sizes = dict (sorted (mean_set_sizes_by_class .items (),
74- key = lambda item : item [1 ]))
75-
76- # Convert dictionary to dataframe and transpose
77- df = pd .DataFrame (mean_set_sizes , index = [0 ]).T
78-
79- # Save as csv
80- df .to_csv (f'{ self .output_dir } /mean_set_sizes_class.csv' ,
81- index = True , header = False )
82-
83- # Visualize this dict as a bar chart
84- sns .set_style ('whitegrid' )
85- fig , ax = plt .subplots (figsize = (10 , 6 ))
86- ax .bar (range (len (mean_set_sizes )), mean_set_sizes .values (), color = palette [1 ])
87- ax .set_xticks (range (len (mean_set_sizes )))
88- ax .set_xticklabels (mean_set_sizes .keys (), rotation = 'vertical' )
89- ax .set_ylabel ('Mean set size' )
90- ax .set_xlabel ('True class' )
91- plt .tight_layout ()
92- plt .savefig (f'{ self .output_dir } /mean_set_sizes_by_class.png' )
93+ def report_class_statistics (self ,
94+ mean_set_sizes_by_class ,
95+ mean_fnrs_by_class ):
96+ self .visualize_mean_fnrs_by_class (mean_fnrs_by_class )
97+ self .visualize_mean_set_sizes_by_class (mean_set_sizes_by_class )
0 commit comments