1- from typing import Tuple
1+ from typing import TYPE_CHECKING
22
3- import pandas as pd
4- import numpy as np
53import matplotlib
4+ import matplotlib .cm
5+ import matplotlib .colors
6+ import matplotlib .gridspec
7+ import numpy as np
8+ import pandas as pd
69from matplotlib .lines import Line2D
7- import matplotlib .pyplot as plt
8- import igraph as ig
910
1011from decoupler ._docs import docs
1112from decoupler ._Plotter import Plotter
1213
14+ # Handle optional igraph dependency
15+ try :
16+ import igraph as ig
17+
18+ HAS_IGRAPH = True
19+ if TYPE_CHECKING :
20+ from igraph import Graph
21+ else :
22+ Graph = ig .Graph
23+ except ImportError :
24+ ig = None
25+ HAS_IGRAPH = False
26+ if TYPE_CHECKING :
27+ from typing import Any as Graph
28+ else :
29+ Graph = None
30+
31+
def _check_igraph() -> None:
    """
    Fail early when the optional ``igraph`` dependency is missing.

    Reads the module-level ``HAS_IGRAPH`` flag, which the guarded
    ``import igraph`` at the top of this module sets to ``True`` on success
    and to ``False`` when the import raises ``ImportError``.

    Returns
    -------
    None
        When ``HAS_IGRAPH`` is true; nothing else is done.

    Raises
    ------
    ImportError
        When ``HAS_IGRAPH`` is false. The message lists the ``pip`` commands
        that install ``igraph``.
    """
    if not HAS_IGRAPH:
        # The adjacent string literals below are concatenated by the parser
        # into a single multi-line message.
        raise ImportError(
            "igraph is not installed. Please install it using:\n "
            " pip install igraph\n "
            "or install decoupler with plotting dependencies:\n "
            " pip install 'decoupler[plot]'"
        )
41+
1342
1443def _src_idxs (
1544 score : pd .DataFrame ,
1645 sources : int | list | str ,
1746 by_abs : bool ,
1847) -> np .ndarray :
19- assert isinstance (sources , (int , list , str )), \
20- 'sources must be int, list or str'
48+ assert isinstance (sources , (int , list , str )), "sources must be int, list or str"
2149 if isinstance (sources , int ):
2250 if by_abs :
2351 s_idx = np .argsort (- abs (score .values [0 ]))[:sources ]
@@ -36,62 +64,62 @@ def _trg_idxs(
3664 targets : int | list | str ,
3765 by_abs : bool ,
3866) -> np .ndarray :
39- assert isinstance (targets , (int , list , str )), \
40- 'targets must be int, list or str'
67+ assert isinstance (targets , (int , list , str )), "targets must be int, list or str"
4168 if isinstance (targets , int ):
42- net ['prod' ] = [data .iloc [0 ][t ] * w if t in data .columns else 0 for t , w in zip (net ['target' ], net ['weight' ])]
69+ net ["prod" ] = [
70+ data .iloc [0 ][t ] * w if t in data .columns else 0 for t , w in zip (net ["target" ], net ["weight" ], strict = False )
71+ ]
4372 if by_abs :
44- net [' prod' ] = abs (net [' prod' ])
73+ net [" prod" ] = abs (net [" prod" ])
4574 t_idx = (
46- net
47- .sort_values (['source' , 'prod' ], ascending = [True , False ])
48- .groupby (['source' ], observed = True )
75+ net .sort_values (["source" , "prod" ], ascending = [True , False ])
76+ .groupby (["source" ], observed = True )
4977 .head (targets )
50- .index
51- .values
78+ .index .values
5279 )
5380 elif isinstance (targets , list ):
54- t_idx = np .isin (net [' target' ].astype (str ), targets )
81+ t_idx = np .isin (net [" target" ].astype (str ), targets )
5582 else :
56- t_idx = np .isin (net [' target' ].astype (str ), [targets ])
83+ t_idx = np .isin (net [" target" ].astype (str ), [targets ])
5784 return t_idx
5885
59-
86+
6087def _filter (
6188 data : pd .DataFrame ,
6289 score : pd .DataFrame ,
6390 net : pd .DataFrame ,
6491 sources : int ,
6592 targets : int ,
6693 by_abs : bool ,
67- ) -> Tuple [pd .DataFrame , pd .DataFrame , pd .DataFrame ]:
68- assert isinstance (data , pd .DataFrame ), 'data must be pd.DataFrame'
69- assert isinstance (score , pd .DataFrame ), 'score must be pd.DataFrame'
70- assert np .all (data .index == score .index ) and (data .index .size == 1 ), \
71- 'data and score need to have the same row index.'
72- assert isinstance (by_abs , bool ), 'by_abs must be bool'
94+ ) -> tuple [pd .DataFrame , pd .DataFrame , pd .DataFrame ]:
95+ assert isinstance (data , pd .DataFrame ), "data must be pd.DataFrame"
96+ assert isinstance (score , pd .DataFrame ), "score must be pd.DataFrame"
97+ assert np .all (data .index == score .index ) and (data .index .size == 1 ), (
98+ "data and score need to have the same row index."
99+ )
100+ assert isinstance (by_abs , bool ), "by_abs must be bool"
73101 # Select top sources
74102 s_idx = _src_idxs (score = score , sources = sources , by_abs = by_abs )
75103 # Filter
76104 score = score .iloc [:, s_idx ]
77- net = net .loc [np .isin (net [' source' ].astype (str ), score .columns .astype (str )), :].copy ()
78- if ' weight' not in net .columns :
79- net [' weight' ] = 1.
105+ net = net .loc [np .isin (net [" source" ].astype (str ), score .columns .astype (str )), :].copy ()
106+ if " weight" not in net .columns :
107+ net [" weight" ] = 1.0
80108 # Select top targets
81109 t_idx = _trg_idxs (data = data , net = net , targets = targets , by_abs = by_abs )
82110 # Filter
83111 net = net .loc [t_idx ]
84112 # Filter unmatched features
85- data = data .loc [:, np .isin (data .columns .astype (str ), net [' target' ].astype (str ))]
86- net = net .loc [np .isin (net [' target' ].astype (str ), data .columns .astype (str )), :]
113+ data = data .loc [:, np .isin (data .columns .astype (str ), net [" target" ].astype (str ))]
114+ net = net .loc [np .isin (net [" target" ].astype (str ), data .columns .astype (str )), :]
87115 return data , score , net
88116
89117
90118def _norm (
91119 x : np .ndarray ,
92120 vcenter : bool ,
93121) -> matplotlib .colors .Normalize :
94- assert isinstance (vcenter , bool ), ' vcenter must be bool'
122+ assert isinstance (vcenter , bool ), " vcenter must be bool"
95123 if vcenter :
96124 vmax = np .max (np .abs (x ))
97125 norm = matplotlib .colors .Normalize (vmin = - vmax , vmax = vmax )
@@ -105,7 +133,7 @@ def _norm(
105133def _dict_types (
106134 data : pd .DataFrame ,
107135 score : pd .DataFrame ,
108- ) -> Tuple [dict , np .ndarray ]:
136+ ) -> tuple [dict , np .ndarray ]:
109137 vs = np .unique (np .hstack ([data .columns , score .columns ]))
110138 v_dict = {k : i for i , k in enumerate (vs )}
111139 types = (~ np .isin (vs , score .columns )) * 1
@@ -118,7 +146,7 @@ def _net_2_elist(
118146) -> list :
119147 edges = []
120148 for i in net .index :
121- source , target = net .loc [i , ' source' ], net .loc [i , ' target' ]
149+ source , target = net .loc [i , " source" ], net .loc [i , " target" ]
122150 edge = [v_dict [source ], v_dict [target ]]
123151 edges .append (edge )
124152 return edges
@@ -128,8 +156,9 @@ def _net_2_g(
128156 data : pd .DataFrame ,
129157 score : pd .DataFrame ,
130158 net : pd .DataFrame ,
131- ) -> ig . Graph :
159+ ) -> Graph :
132160 # Unify network
161+ _check_igraph ()
133162 v_dict , types = _dict_types (data = data , score = score )
134163 # Transform net to edges
135164 edges = _net_2_elist (net = net , v_dict = v_dict )
@@ -139,15 +168,15 @@ def _net_2_g(
139168 directed = True ,
140169 )
141170 # Update attributes
142- g .es [' weight' ] = net [' weight' ].values
143- g .vs [' type' ] = types
144- g .vs [' label' ] = list (v_dict .keys ())
145- g .vs [' shape' ] = np .where (types , ' circle' , ' square' )
171+ g .es [" weight" ] = net [" weight" ].values
172+ g .vs [" type" ] = types
173+ g .vs [" label" ] = list (v_dict .keys ())
174+ g .vs [" shape" ] = np .where (types , " circle" , " square" )
146175 return g
147176
148177
149178def _gcolors (
150- g : ig . Graph ,
179+ g : Graph ,
151180 data : pd .DataFrame ,
152181 score : pd .DataFrame ,
153182 s_norm : matplotlib .colors .Normalize ,
@@ -160,16 +189,16 @@ def _gcolors(
160189 s_cmap = matplotlib .colormaps .get_cmap (s_cmap )
161190 t_cmap = matplotlib .colormaps .get_cmap (t_cmap )
162191 color = []
163- for i , k in enumerate (g .vs [' label' ]):
164- if g .vs [' type' ][i ]:
192+ for i , k in enumerate (g .vs [" label" ]):
193+ if g .vs [" type" ][i ]:
165194 color .append (t_cmap (t_norm (data [k ].values [0 ])))
166195 else :
167196 color .append (s_cmap (s_norm (score [k ].values [0 ])))
168197 is_cmap = True
169198 else :
170- color = [s_cmap if typ == 0. else t_cmap for typ in g .vs [' type' ]]
199+ color = [s_cmap if typ == 0.0 else t_cmap for typ in g .vs [" type" ]]
171200 is_cmap = False
172- g .vs [' color' ] = color
201+ g .vs [" color" ] = color
173202 return is_cmap
174203
175204
@@ -183,15 +212,15 @@ def network(
183212 by_abs = True ,
184213 size_node = 5 ,
185214 size_label = 2.5 ,
186- s_cmap = ' RdBu_r' ,
187- t_cmap = ' viridis' ,
215+ s_cmap = " RdBu_r" ,
216+ t_cmap = " viridis" ,
188217 vcenter = False ,
189- c_pos_w = ' darkgreen' ,
190- c_neg_w = ' darkred' ,
191- s_label = ' Enrichment\n score' ,
192- t_label = ' Gene\n expression' ,
193- layout = 'kk' ,
194- ** kwargs
218+ c_pos_w = " darkgreen" ,
219+ c_neg_w = " darkred" ,
220+ s_label = " Enrichment\n score" ,
221+ t_label = " Gene\n expression" ,
222+ layout = "kk" ,
223+ ** kwargs ,
195224):
196225 """
197226 Plot results of enrichment analysis as network.
@@ -231,14 +260,16 @@ def network(
231260 Layout to use to order the nodes. Check ``igraph`` documentation for more options.
232261 %(plot)s
233262 """
234- assert isinstance (net , pd .DataFrame ), 'net must be pd.DataFrame'
235- assert (data is None ) == (score is None ), 'data and score must either both be None'
263+ assert isinstance (net , pd .DataFrame ), "net must be pd.DataFrame"
264+ assert (data is None ) == (score is None ), "data and score must either both be None"
265+ if ig is None :
266+ raise ImportError ("igraph is not installed. Please install it using `pip install igraph`." )
236267 if data is None :
237- srcs = net [' source' ].unique ().astype ('U' )
238- score = pd .DataFrame (np .ones ((1 , srcs .size )), index = ['0' ], columns = srcs )
239- trgs = net [' target' ].unique ().astype ('U' )
240- data = pd .DataFrame (np .ones ((1 , trgs .size )), index = ['0' ], columns = trgs )
241- t_cmap = ' white'
268+ srcs = net [" source" ].unique ().astype ("U" )
269+ score = pd .DataFrame (np .ones ((1 , srcs .size )), index = ["0" ], columns = srcs )
270+ trgs = net [" target" ].unique ().astype ("U" )
271+ data = pd .DataFrame (np .ones ((1 , trgs .size )), index = ["0" ], columns = trgs )
272+ t_cmap = " white"
242273 # Filter
243274 fdata , fscore , fnet = _filter (
244275 data = data ,
@@ -253,7 +284,7 @@ def network(
253284 t_norm = _norm (x = fdata , vcenter = vcenter )
254285 # Make graph
255286 g = _net_2_g (data = fdata , score = fscore , net = fnet )
256- g .es [' color' ] = [c_pos_w if w > 0 else c_neg_w for w in g .es [' weight' ]]
287+ g .es [" color" ] = [c_pos_w if w > 0 else c_neg_w for w in g .es [" weight" ]]
257288 is_cmap = _gcolors (
258289 g = g ,
259290 data = data ,
@@ -264,7 +295,7 @@ def network(
264295 t_cmap = t_cmap ,
265296 )
266297 # Instance
267- kwargs ['ax' ] = None
298+ kwargs ["ax" ] = None
268299 bp = Plotter (** kwargs )
269300 bp .fig .delaxes (bp .ax )
270301 # Plot
@@ -279,7 +310,7 @@ def network(
279310 layout = layout ,
280311 vertex_size = (size_node * bp .dpi ) / (bp .figsize [0 ] * bp .figsize [0 ]),
281312 vertex_size_label = (size_label * bp .dpi ) / (bp .figsize [0 ] * bp .figsize [0 ]),
282- bbox_inches = ' tight' ,
313+ bbox_inches = " tight" ,
283314 )
284315 if is_cmap :
285316 sm = matplotlib .cm .ScalarMappable (norm = s_norm , cmap = s_cmap )
@@ -290,19 +321,35 @@ def network(
290321 ax2 .axis ("off" )
291322 ax4 .axis ("off" )
292323 # Add legend
293- square = Line2D ([0 ], [0 ], marker = 's' , color = 'w' , label = 'Source' , markerfacecolor = 'white' ,
294- markeredgecolor = 'black' , markersize = 10 )
295- circle = Line2D ([0 ], [0 ], marker = 'o' , color = 'w' , label = 'Target' , markerfacecolor = 'white' ,
296- markeredgecolor = 'black' , markersize = 10 )
297- line1 = Line2D ((0 , 0 ), (1 , 0 ), color = c_pos_w , lw = 2 , marker = '>' ,)
298- line2 = Line2D ((0 , 0 ), (1 , 0 ), color = c_neg_w , lw = 2 , marker = '>' ,)
324+ square = Line2D (
325+ [0 ], [0 ], marker = "s" , color = "w" , label = "Source" , markerfacecolor = "white" , markeredgecolor = "black" , markersize = 10
326+ )
327+ circle = Line2D (
328+ [0 ], [0 ], marker = "o" , color = "w" , label = "Target" , markerfacecolor = "white" , markeredgecolor = "black" , markersize = 10
329+ )
330+ line1 = Line2D (
331+ (0 , 0 ),
332+ (1 , 0 ),
333+ color = c_pos_w ,
334+ lw = 2 ,
335+ marker = ">" ,
336+ )
337+ line2 = Line2D (
338+ (0 , 0 ),
339+ (1 , 0 ),
340+ color = c_neg_w ,
341+ lw = 2 ,
342+ marker = ">" ,
343+ )
299344 handles = [square , circle , line1 , line2 ]
300- labels = [' Source' , ' Target' , ' Positive' , ' Negative' ]
345+ labels = [" Source" , " Target" , " Positive" , " Negative" ]
301346 legend = ax3 .legend (
302347 handles = [square , circle , line1 , line2 ],
303348 labels = labels ,
304349 frameon = False ,
305- loc = 'center' , bbox_to_anchor = (0.5 , 0.5 ), bbox_transform = ax3 .transAxes
350+ loc = "center" ,
351+ bbox_to_anchor = (0.5 , 0.5 ),
352+ bbox_transform = ax3 .transAxes ,
306353 )
307- ax3 .axis (' off' )
354+ ax3 .axis (" off" )
308355 return bp ._return ()
0 commit comments