@@ -36,6 +36,9 @@ class ScatterStyle:
3636 alpha : float = 0.25
3737 linewidths : float = 0.0
3838
39+ display_legend : bool = True
40+ legend_outside : bool = False
41+
3942 legend_title : Optional [str ] = None
4043 legend_loc : str = "lower left"
4144 legend_frameon : bool = False
@@ -132,21 +135,40 @@ def scatter_plot_base(
132135 ax .set_xlabel ("" )
133136 ax .set_ylabel ("" )
134137
135- legend_title = style .legend_title if style .legend_title is not None else label_col
136- handles = _build_legend_handles (
137- legend_labels ,
138- palette ,
139- markersize = style .legend_markersize ,
140- alpha = style .legend_alpha ,
141- )
138+ # ---- legend (optional + outside option) ----
139+ if style .display_legend :
140+ legend_title = style .legend_title if style .legend_title is not None else label_col
141+ handles = _build_legend_handles (
142+ legend_labels ,
143+ palette ,
144+ markersize = style .legend_markersize ,
145+ alpha = style .legend_alpha ,
146+ )
142147
143- ax .legend (
144- handles = handles ,
145- title = legend_title ,
146- loc = style .legend_loc ,
147- frameon = style .legend_frameon ,
148- ncol = style .legend_ncol ,
149- )
148+ if style .legend_outside :
149+ # Put legend outside right; loc controls anchor point of legend box itself.
150+ ax .legend (
151+ handles = handles ,
152+ title = legend_title ,
153+ loc = "center left" ,
154+ bbox_to_anchor = (1.02 , 0.5 ),
155+ frameon = style .legend_frameon ,
156+ ncol = style .legend_ncol ,
157+ borderaxespad = 0.0 ,
158+ )
159+ # Leave room on the right so legend isn't clipped
160+ fig .tight_layout (rect = (0 , 0 , 0.85 , 1 ))
161+ else :
162+ ax .legend (
163+ handles = handles ,
164+ title = legend_title ,
165+ loc = style .legend_loc ,
166+ frameon = style .legend_frameon ,
167+ ncol = style .legend_ncol ,
168+ )
169+ fig .tight_layout ()
170+ else :
171+ fig .tight_layout ()
150172
151173 fig .tight_layout ()
152174 return fig , ax
@@ -174,6 +196,8 @@ def scatter_plot_all_classes(
174196 s : float = 5.0 ,
175197 alpha : float = 0.25 ,
176198 linewidths : float = 0.0 ,
199+ display_legend : bool = True ,
200+ legend_outside : bool = False ,
177201 legend_title : Optional [str ] = None ,
178202 legend_loc : str = "lower left" ,
179203 legend_frameon : bool = False ,
@@ -243,6 +267,8 @@ def scatter_plot_all_classes(
243267 s = s ,
244268 alpha = alpha ,
245269 linewidths = linewidths ,
270+ display_legend = display_legend ,
271+ legend_outside = legend_outside ,
246272 legend_title = legend_title if legend_title is not None else subclass_col ,
247273 legend_loc = legend_loc ,
248274 legend_frameon = legend_frameon ,
@@ -300,6 +326,8 @@ def scatter_plot_hierarchical_labels(
300326 s : float = 2.0 ,
301327 alpha : float = 0.2 ,
302328 linewidths : float = 0.0 ,
329+ display_legend : bool = True ,
330+ legend_outside : bool = False ,
303331 legend_title : str = "Class / Superclass" ,
304332 legend_loc : str = "lower left" ,
305333 legend_frameon : bool = False ,
@@ -398,6 +426,8 @@ def scatter_plot_hierarchical_labels(
398426 s = s ,
399427 alpha = alpha ,
400428 linewidths = linewidths ,
429+ display_legend = display_legend ,
430+ legend_outside = legend_outside ,
401431 legend_title = legend_title ,
402432 legend_loc = legend_loc ,
403433 legend_frameon = legend_frameon ,
0 commit comments