@@ -123,8 +123,9 @@ def dot_plot(sc_interactions, evaluation='communication', significance=0.05, sen
123123
124124
125125def generate_dot_plot (pval_df , score_df , significance = 0.05 , xlabel = '' , ylabel = '' , cbar_title = 'Score' , cmap = 'PuOr' ,
126- figsize = (16 , 9 ), label_size = 20 , title_size = 20 , tick_size = 14 , filename = None ):
127- '''Generates a dot plot for given P-values and respective scores.
126+ figsize = None , label_size = 20 , title_size = 20 , tick_size = 14 , filename = None ,
127+ min_row_height = 0.3 , reference_height = 1.0 ):
128+ '''Generates a dot plot for given P-values and respective scores with improved spacing.
128129
129130 Parameters
130131 ----------
@@ -154,8 +155,9 @@ def generate_dot_plot(pval_df, score_df, significance=0.05, xlabel='', ylabel=''
154155 cmap : str, default='PuOr'
155156 A matplotlib color palette name.
156157
157- figsize : tuple, default=(16, 9)
158+ figsize : tuple, default=None
158159 Size of the figure (width*height), each in inches.
160+ If None, it will be automatically calculated based on the data.
159161
160162 label_size : int, default=20
161163 Specifies the size of the labels of both X and Y axes.
@@ -171,6 +173,12 @@ def generate_dot_plot(pval_df, score_df, significance=0.05, xlabel='', ylabel=''
171173 Path to save the figure of the dot plot. If None, the figure is not
172174 saved.
173175
176+ min_row_height : float, default=0.3
177+ Minimum height per row in inches to prevent dot overlap.
178+
179+ reference_height : float, default=1.0
180+ Fixed height in inches for the reference legend subplot.
181+
174182 Returns
175183 -------
176184 fig : matplotlib.figure.Figure
@@ -183,6 +191,21 @@ def generate_dot_plot(pval_df, score_df, significance=0.05, xlabel='', ylabel=''
183191 df = df .T .loc [(df != 0 ).any (axis = 0 )].T
184192 pval_df = pval_df [df .columns ].loc [df .index ].applymap (lambda x : - 1. * np .log10 (x + 1e-9 ))
185193
194+ n_rows = len (pval_df .index )
195+ n_cols = len (pval_df .columns )
196+
197+ # Auto-calculate figure size if not provided
198+ if figsize is None :
199+ # Calculate width based on number of columns (with a reasonable range)
200+ width = max (8 , min (20 , n_cols * 0.4 + 4 ))
201+ # Calculate height based on number of rows to prevent overlap
202+ main_plot_height = max (6 , n_rows * min_row_height )
203+ height = main_plot_height + reference_height + 1.5 # +1.5 for labels and padding
204+ figsize = (width , height )
205+ else :
206+ # If figsize is provided, extract main plot height
207+ main_plot_height = figsize [1 ] - reference_height - 1.5
208+
186209 # Set dot sizes and color range
187210 max_abs = np .max ([np .abs (np .min (np .min (score_df ))), np .abs (np .max (np .max (score_df )))])
188211 norm = mpl .colors .Normalize (vmin = - 1. * max_abs , vmax = max_abs )
@@ -191,8 +214,18 @@ def generate_dot_plot(pval_df, score_df, significance=0.05, xlabel='', ylabel=''
191214 # Colormap
192215 cmap = mpl .cm .get_cmap (cmap )
193216
194- # Dot plot
195- fig , (ax2 , ax ) = plt .subplots (2 , 1 , figsize = figsize , gridspec_kw = {'height_ratios' : [1 , 9 ]})
217+ # Create figure with proper height ratios
218+ # Use height_ratios based on actual inches rather than arbitrary numbers
219+ height_ratio_ref = reference_height / main_plot_height
220+
221+ fig = plt .figure (figsize = figsize )
222+ gs = fig .add_gridspec (2 , 1 , height_ratios = [height_ratio_ref , 1.0 ],
223+ hspace = 0.15 ) # Reduced space between subplots
224+
225+ ax2 = fig .add_subplot (gs [0 ])
226+ ax = fig .add_subplot (gs [1 ])
227+
228+ # Dot plot - scatter each point
196229 for i , idx in enumerate (pval_df .index ):
197230 for j , col in enumerate (pval_df .columns ):
198231 color = np .asarray (cmap (norm (score_df [[col ]].loc [[idx ]].values .item ()))).reshape (1 , - 1 )
@@ -215,57 +248,72 @@ def generate_dot_plot(pval_df, score_df, significance=0.05, xlabel='', ylabel=''
215248 ax .set_yticks (ticks = range (0 , len (pval_df .index )))
216249 ax .set_yticklabels (ylabels ,
217250 fontsize = tick_size ,
218- rotation = 0 , ha = 'right' , va = 'center'
219- )
220-
221- plt .gca ().invert_yaxis ()
222-
223- plt .tick_params (axis = 'both' ,
224- which = 'both' ,
225- bottom = True ,
226- top = False ,
227- right = False ,
228- left = True ,
229- labelleft = True ,
230- labelbottom = True )
251+ rotation = 0 , ha = 'right' , va = 'center' )
252+
253+ # Set explicit axis limits with small padding to eliminate white space
254+ ax .set_xlim (- 0.5 , n_cols - 0.5 )
255+ ax .set_ylim (- 0.5 , n_rows - 0.5 )
256+
257+ ax .invert_yaxis ()
258+
259+ ax .tick_params (axis = 'both' ,
260+ which = 'both' ,
261+ bottom = True ,
262+ top = False ,
263+ right = False ,
264+ left = True ,
265+ labelleft = True ,
266+ labelbottom = True )
231267 ax .set_xlabel (xlabel , fontsize = label_size )
232268 ax .set_ylabel (ylabel , fontsize = label_size )
233269
234270 # Colorbar
235- # create an axes on the top side of ax. The width of cax will be 3%
236- # of ax and the padding between cax and ax will be fixed at 0.21 inch.
237271 divider = make_axes_locatable (ax )
238272 cax = divider .append_axes ("top" , size = "3%" , pad = 0.21 )
239273
240274 cbar = plt .colorbar (mpl .cm .ScalarMappable (norm = norm , cmap = cmap ),
241275 cax = cax ,
242- orientation = 'horizontal'
243- )
276+ orientation = 'horizontal' )
244277 cbar .ax .tick_params (labelsize = tick_size )
245278
246- cax .tick_params (axis = 'x' , # changes apply to the x-axis
247- which = 'both' , # both major and minor ticks are affected
248- bottom = False , # ticks along the bottom edge are off
249- top = True , # ticks along the top edge are off
250- labelbottom = False , # labels along the bottom edge are off
251- labeltop = True
252- )
279+ cax .tick_params (axis = 'x' ,
280+ which = 'both' ,
281+ bottom = False ,
282+ top = True ,
283+ labelbottom = False ,
284+ labeltop = True )
253285 cax .set_title (cbar_title , fontsize = title_size )
254286
255- for i , v in enumerate ([np .min ([np .min (np .min (pval_df )), - 1. * np .log10 (0.99 )]), - 1. * np .log10 (significance + 1e-9 ), 3.0 ]): # old min np.min(np.min(pval_df))
256- ax2 .scatter (i , 0 , s = (max_size (v ) * tick_size * 2 ) ** 2 , c = 'k' )
257- ax2 .scatter (i , 1 , s = 0 , c = 'k' )
258- if v == 3.0 :
287+ # Reference size legend
288+ # Calculate reference p-values to show in legend
289+ min_pval_shown = np .max ([np .min (np .min (pval_df )), - 1. * np .log10 (0.99 )]) # Show at least -log10(0.99)
290+ threshold_pval = - 1. * np .log10 (significance + 1e-9 )
291+ max_pval_shown = 3.0
292+
293+ pval_values = [min_pval_shown , threshold_pval , max_pval_shown ]
294+
295+ for i , v in enumerate (pval_values ):
296+ # Cap at 3 just like in the main plot for consistency
297+ v_capped = np .min ([v , 3.0 ])
298+ size = (max_size (v_capped ) * tick_size * 2 ) ** 2
299+ ax2 .scatter (i , 0 , s = size , c = 'k' )
300+
301+ if v >= 3.0 :
259302 extra = '>='
260303 elif i == 1 :
261304 extra = 'Threshold: '
262305 else :
263306 extra = ''
264- ax2 .annotate (extra + str (np .round (abs (v ), 4 )), (i , 1 ), fontsize = tick_size , horizontalalignment = 'center' )
265- ax2 .set_ylim (- 0.5 , 2 )
307+ ax2 .annotate (extra + str (np .round (abs (v ), 4 )), (i , 0.5 ),
308+ fontsize = tick_size , horizontalalignment = 'center' )
309+
310+ # Set limits for reference plot to minimize white space
311+ ax2 .set_xlim (- 0.5 , 2.5 )
312+ ax2 .set_ylim (- 0.2 , 0.8 )
266313 ax2 .axis ('off' )
267- ax2 .set_title ('-log10(P-value) sizes' , fontsize = title_size )
314+ ax2 .set_title ('-log10(P-value) sizes' , fontsize = title_size , pad = 10 )
268315
269316 if filename is not None :
270317 plt .savefig (filename , dpi = 300 , bbox_inches = 'tight' )
271- return fig
318+
319+ return fig
0 commit comments