@@ -21,7 +21,8 @@ def heatmap(
2121 annotate_values = True ,
2222 cluster_rows : bool = True , # cluster parameters
2323 cluster_cols : bool = False , # cluster outputs
24- transpose : bool = False
24+ transpose : bool = False ,
25+ title : Optional [str ] = None ,
2526):
2627 """Creates heatmap of model sensitivity"""
2728
@@ -61,12 +62,12 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
6162
6263 n_outputs = df_subset .shape [1 ]
6364 n_parameters = df_subset .shape [0 ]
64- figsize = (10 , 15 )
65+ figsize = (int ( n_outputs / n_parameters * 30 ) , 15 )
6566
6667 colorbar_range = 2.0
6768
6869 # plot heatmap
69- ax = sns .clustermap (
70+ cg = sns .clustermap (
7071 df_subset ,
7172 center = 0 ,
7273 vmin = - colorbar_range ,
@@ -90,17 +91,26 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
9091 figsize = figsize ,
9192 )
9293 plt .setp (
93- ax .ax_heatmap .get_xticklabels (),
94+ cg .ax_heatmap .get_xticklabels (),
9495 rotation = 45 ,
9596 horizontalalignment = "right" ,
9697 size = 20 ,
9798 )
9899 label_fontsize = 10
99- plt .setp (ax .ax_heatmap .get_yticklabels (), size = label_fontsize )
100- plt .setp (ax .ax_heatmap .get_xticklabels (), size = label_fontsize )
101- ax .ax_cbar .tick_params (labelsize = label_fontsize )
102- ax .ax_row_dendrogram .set_visible (False )
103- ax .ax_col_dendrogram .set_visible (False )
100+ plt .setp (cg .ax_heatmap .get_yticklabels (), size = label_fontsize )
101+ plt .setp (cg .ax_heatmap .get_xticklabels (), size = label_fontsize )
102+ cg .ax_cbar .tick_params (labelsize = label_fontsize )
103+ cg .ax_row_dendrogram .set_visible (False )
104+ cg .ax_col_dendrogram .set_visible (False )
105+
106+ if title :
107+ plt .suptitle (title )
108+
109+ # for label in cg.ax_heatmap.get_xticklabels():
110+ # label.set_bbox(dict(facecolor='tab:blue', edgecolor='black', alpha=0.8))
111+ #
112+ # for label in cg.ax_heatmap.get_yticklabels():
113+ # label.set_bbox(dict(facecolor='tab:orange', edgecolor='black', alpha=0.8))
104114
105115 # create custom legend containing yticklabels and their description
106116 # handles = [t.get_text() for t in ax.ax_heatmap.get_yticklabels()]
0 commit comments