@@ -15,13 +15,14 @@ def source_targets(
1515 x : str ,
1616 y : str ,
1717 name : str ,
18- top : int = 5 ,
18+ top : int | None = 5 ,
1919 thr_x : float = 0.0 ,
2020 thr_y : float = 0.0 ,
2121 max_x : float | None = None ,
2222 max_y : float | None = None ,
2323 color_pos : str = "#D62728" ,
2424 color_neg : str = "#1F77B4" ,
25+ kw_scatter : dict | None = None ,
2526 ** kwargs ,
2627) -> None | Figure :
2728 """
@@ -38,7 +39,7 @@ def source_targets(
3839 name
3940 Name of the source to plot.
4041 top
41- Number of top features based on the product of x and y to label.
42+ Number of top features to show labels based on the product of x and y to label. Can be None .
4243 thr_x
4344 Value were to place a baseline for the x-axis.
4445 thr_y
@@ -51,6 +52,8 @@ def source_targets(
5152 Color to plot positively associated features.
5253 color_neg
5354 Color to plot negatively associated features.
55+ kw_scatter
56+ Keyword arguments passed to ``matplotlib.pyplot.scatter``.
5457 %(plot)s
5558
5659 Example
@@ -72,7 +75,9 @@ def source_targets(
7275 assert not pd .api .types .is_numeric_dtype (data .index ), "data index must be features in net"
7376 assert isinstance (net , pd .DataFrame ), f"net must be a pd.DataFrame containing the columns { x } and { y } "
7477 assert isinstance (name , str ), "name must be a str"
75- assert isinstance (top , int ) and top > 0 , "top must be int and > 0"
78+ if top is None :
79+ top = 0
80+ assert isinstance (top , int ) and top >= 0 , "top must be int and >= 0"
7681 assert isinstance (thr_x , int | float ), "thr_x must be numeric"
7782 assert isinstance (thr_y , int | float ), "thr_y must be numeric"
7883 if max_x is None :
@@ -83,6 +88,8 @@ def source_targets(
8388 assert isinstance (max_y , int | float ) and max_y > 0 , "max_y must be None, or numeric and > 0"
8489 assert isinstance (color_pos , str ), "color_pos must be str"
8590 assert isinstance (color_neg , str ), "color_neg must be str"
91+ if kw_scatter is None :
92+ kw_scatter = {}
8693 # Instance
8794 bp = Plotter (** kwargs )
8895 # Extract df
@@ -101,7 +108,7 @@ def source_targets(
101108 df ["color" ] = color_neg
102109 df .loc [pos , "color" ] = color_pos
103110 # Plot
104- df .plot .scatter (x = x , y = y , c = "color" , ax = bp .ax )
111+ df .plot .scatter (x = x , y = y , c = "color" , ax = bp .ax , ** kw_scatter )
105112 # Draw thr lines
106113 bp .ax .axvline (x = thr_x , linestyle = "--" , color = "black" )
107114 bp .ax .axhline (y = thr_y , linestyle = "--" , color = "black" )
@@ -112,10 +119,11 @@ def source_targets(
112119 # Show top features
113120 df ["order" ] = df [x ].abs () * df [y ].abs ()
114121 signs = df .sort_values ("order" , ascending = False )
115- signs = signs .iloc [:top ]
116- texts = []
117- for tx , ty , ts in zip (signs [x ], signs [y ], signs .index , strict = False ):
118- texts .append (bp .ax .text (tx , ty , ts ))
119- if len (texts ) > 0 :
120- at .adjust_text (texts , arrowprops = {"arrowstyle" : "-" , "color" : "black" }, ax = bp .ax )
122+ if top > 0 :
123+ signs = signs .iloc [:top ]
124+ texts = []
125+ for tx , ty , ts in zip (signs [x ], signs [y ], signs .index , strict = False ):
126+ texts .append (bp .ax .text (tx , ty , ts ))
127+ if len (texts ) > 0 :
128+ at .adjust_text (texts , arrowprops = {"arrowstyle" : "-" , "color" : "black" }, ax = bp .ax )
121129 return bp ._return ()
0 commit comments