Skip to content

Commit 61e37b5

Browse files
authored
Merge pull request #186 from Intron7/make-igraph-optinal
make igraph optional to keep BSD license
2 parents da398b9 + c269ea3 commit 61e37b5

2 files changed

Lines changed: 122 additions & 71 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,14 @@ dependencies = [
3131
"xgboost",
3232
"seaborn",
3333
"requests",
34-
"igraph",
3534
"marsilea",
3635
"adjustText",
3736
# for debug logging (referenced from the issue template)
3837
"session-info2",
3938
]
39+
optional-dependencies.plot = [
40+
"igraph",
41+
]
4042
optional-dependencies.dev = [
4143
"pre-commit",
4244
"twine>=4.0.2",
@@ -56,13 +58,15 @@ optional-dependencies.doc = [
5658
"sphinx-tabs",
5759
"sphinxcontrib-bibtex>=1",
5860
"sphinxext-opengraph",
61+
"igraph",
5962
]
6063
optional-dependencies.test = [
6164
"coverage",
6265
"pytest",
6366
"statsmodels",
6467
"gseapy",
6568
"scanpy",
69+
"igraph",
6670
]
6771
# https://docs.pypi.org/project_metadata/#project-urls
6872
urls.Documentation = "https://decoupler.readthedocs.io/"

src/decoupler/pl/_network.py

Lines changed: 117 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,51 @@
1-
from typing import Tuple
1+
from typing import TYPE_CHECKING
22

3-
import pandas as pd
4-
import numpy as np
53
import matplotlib
4+
import matplotlib.cm
5+
import matplotlib.colors
6+
import matplotlib.gridspec
7+
import numpy as np
8+
import pandas as pd
69
from matplotlib.lines import Line2D
7-
import matplotlib.pyplot as plt
8-
import igraph as ig
910

1011
from decoupler._docs import docs
1112
from decoupler._Plotter import Plotter
1213

14+
# Handle optional igraph dependency
15+
try:
16+
import igraph as ig
17+
18+
HAS_IGRAPH = True
19+
if TYPE_CHECKING:
20+
from igraph import Graph
21+
else:
22+
Graph = ig.Graph
23+
except ImportError:
24+
ig = None
25+
HAS_IGRAPH = False
26+
if TYPE_CHECKING:
27+
from typing import Any as Graph
28+
else:
29+
Graph = None
30+
31+
32+
def _check_igraph() -> None:
33+
"""Check if igraph is available and raise informative error if not."""
34+
if not HAS_IGRAPH:
35+
raise ImportError(
36+
"igraph is not installed. Please install it using:\n"
37+
" pip install igraph\n"
38+
"or install decoupler with plotting dependencies:\n"
39+
" pip install 'decoupler[plot]'"
40+
)
41+
1342

1443
def _src_idxs(
1544
score: pd.DataFrame,
1645
sources: int | list | str,
1746
by_abs: bool,
1847
) -> np.ndarray:
19-
assert isinstance(sources, (int, list, str)), \
20-
'sources must be int, list or str'
48+
assert isinstance(sources, (int, list, str)), "sources must be int, list or str"
2149
if isinstance(sources, int):
2250
if by_abs:
2351
s_idx = np.argsort(-abs(score.values[0]))[:sources]
@@ -36,62 +64,62 @@ def _trg_idxs(
3664
targets: int | list | str,
3765
by_abs: bool,
3866
) -> np.ndarray:
39-
assert isinstance(targets, (int, list, str)), \
40-
'targets must be int, list or str'
67+
assert isinstance(targets, (int, list, str)), "targets must be int, list or str"
4168
if isinstance(targets, int):
42-
net['prod'] = [data.iloc[0][t] * w if t in data.columns else 0 for t, w in zip(net['target'], net['weight'])]
69+
net["prod"] = [
70+
data.iloc[0][t] * w if t in data.columns else 0 for t, w in zip(net["target"], net["weight"], strict=False)
71+
]
4372
if by_abs:
44-
net['prod'] = abs(net['prod'])
73+
net["prod"] = abs(net["prod"])
4574
t_idx = (
46-
net
47-
.sort_values(['source', 'prod'], ascending=[True, False])
48-
.groupby(['source'], observed=True)
75+
net.sort_values(["source", "prod"], ascending=[True, False])
76+
.groupby(["source"], observed=True)
4977
.head(targets)
50-
.index
51-
.values
78+
.index.values
5279
)
5380
elif isinstance(targets, list):
54-
t_idx = np.isin(net['target'].astype(str), targets)
81+
t_idx = np.isin(net["target"].astype(str), targets)
5582
else:
56-
t_idx = np.isin(net['target'].astype(str), [targets])
83+
t_idx = np.isin(net["target"].astype(str), [targets])
5784
return t_idx
5885

59-
86+
6087
def _filter(
6188
data: pd.DataFrame,
6289
score: pd.DataFrame,
6390
net: pd.DataFrame,
6491
sources: int,
6592
targets: int,
6693
by_abs: bool,
67-
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
68-
assert isinstance(data, pd.DataFrame), 'data must be pd.DataFrame'
69-
assert isinstance(score, pd.DataFrame), 'score must be pd.DataFrame'
70-
assert np.all(data.index == score.index) and (data.index.size == 1), \
71-
'data and score need to have the same row index.'
72-
assert isinstance(by_abs, bool), 'by_abs must be bool'
94+
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
95+
assert isinstance(data, pd.DataFrame), "data must be pd.DataFrame"
96+
assert isinstance(score, pd.DataFrame), "score must be pd.DataFrame"
97+
assert np.all(data.index == score.index) and (data.index.size == 1), (
98+
"data and score need to have the same row index."
99+
)
100+
assert isinstance(by_abs, bool), "by_abs must be bool"
73101
# Select top sources
74102
s_idx = _src_idxs(score=score, sources=sources, by_abs=by_abs)
75103
# Filter
76104
score = score.iloc[:, s_idx]
77-
net = net.loc[np.isin(net['source'].astype(str), score.columns.astype(str)), :].copy()
78-
if 'weight' not in net.columns:
79-
net['weight'] = 1.
105+
net = net.loc[np.isin(net["source"].astype(str), score.columns.astype(str)), :].copy()
106+
if "weight" not in net.columns:
107+
net["weight"] = 1.0
80108
# Select top targets
81109
t_idx = _trg_idxs(data=data, net=net, targets=targets, by_abs=by_abs)
82110
# Filter
83111
net = net.loc[t_idx]
84112
# Filter unmatched features
85-
data = data.loc[:, np.isin(data.columns.astype(str), net['target'].astype(str))]
86-
net = net.loc[np.isin(net['target'].astype(str), data.columns.astype(str)), :]
113+
data = data.loc[:, np.isin(data.columns.astype(str), net["target"].astype(str))]
114+
net = net.loc[np.isin(net["target"].astype(str), data.columns.astype(str)), :]
87115
return data, score, net
88116

89117

90118
def _norm(
91119
x: np.ndarray,
92120
vcenter: bool,
93121
) -> matplotlib.colors.Normalize:
94-
assert isinstance(vcenter, bool), 'vcenter must be bool'
122+
assert isinstance(vcenter, bool), "vcenter must be bool"
95123
if vcenter:
96124
vmax = np.max(np.abs(x))
97125
norm = matplotlib.colors.Normalize(vmin=-vmax, vmax=vmax)
@@ -105,7 +133,7 @@ def _norm(
105133
def _dict_types(
106134
data: pd.DataFrame,
107135
score: pd.DataFrame,
108-
) -> Tuple[dict, np.ndarray]:
136+
) -> tuple[dict, np.ndarray]:
109137
vs = np.unique(np.hstack([data.columns, score.columns]))
110138
v_dict = {k: i for i, k in enumerate(vs)}
111139
types = (~np.isin(vs, score.columns)) * 1
@@ -118,7 +146,7 @@ def _net_2_elist(
118146
) -> list:
119147
edges = []
120148
for i in net.index:
121-
source, target = net.loc[i, 'source'], net.loc[i, 'target']
149+
source, target = net.loc[i, "source"], net.loc[i, "target"]
122150
edge = [v_dict[source], v_dict[target]]
123151
edges.append(edge)
124152
return edges
@@ -128,8 +156,9 @@ def _net_2_g(
128156
data: pd.DataFrame,
129157
score: pd.DataFrame,
130158
net: pd.DataFrame,
131-
) -> ig.Graph:
159+
) -> Graph:
132160
# Unify network
161+
_check_igraph()
133162
v_dict, types = _dict_types(data=data, score=score)
134163
# Transform net to edges
135164
edges = _net_2_elist(net=net, v_dict=v_dict)
@@ -139,15 +168,15 @@ def _net_2_g(
139168
directed=True,
140169
)
141170
# Update attributes
142-
g.es['weight'] = net['weight'].values
143-
g.vs['type'] = types
144-
g.vs['label'] = list(v_dict.keys())
145-
g.vs['shape'] = np.where(types, 'circle', 'square')
171+
g.es["weight"] = net["weight"].values
172+
g.vs["type"] = types
173+
g.vs["label"] = list(v_dict.keys())
174+
g.vs["shape"] = np.where(types, "circle", "square")
146175
return g
147176

148177

149178
def _gcolors(
150-
g: ig.Graph,
179+
g: Graph,
151180
data: pd.DataFrame,
152181
score: pd.DataFrame,
153182
s_norm: matplotlib.colors.Normalize,
@@ -160,16 +189,16 @@ def _gcolors(
160189
s_cmap = matplotlib.colormaps.get_cmap(s_cmap)
161190
t_cmap = matplotlib.colormaps.get_cmap(t_cmap)
162191
color = []
163-
for i, k in enumerate(g.vs['label']):
164-
if g.vs['type'][i]:
192+
for i, k in enumerate(g.vs["label"]):
193+
if g.vs["type"][i]:
165194
color.append(t_cmap(t_norm(data[k].values[0])))
166195
else:
167196
color.append(s_cmap(s_norm(score[k].values[0])))
168197
is_cmap = True
169198
else:
170-
color = [s_cmap if typ == 0. else t_cmap for typ in g.vs['type']]
199+
color = [s_cmap if typ == 0.0 else t_cmap for typ in g.vs["type"]]
171200
is_cmap = False
172-
g.vs['color'] = color
201+
g.vs["color"] = color
173202
return is_cmap
174203

175204

@@ -183,15 +212,15 @@ def network(
183212
by_abs=True,
184213
size_node=5,
185214
size_label=2.5,
186-
s_cmap='RdBu_r',
187-
t_cmap='viridis',
215+
s_cmap="RdBu_r",
216+
t_cmap="viridis",
188217
vcenter=False,
189-
c_pos_w='darkgreen',
190-
c_neg_w='darkred',
191-
s_label='Enrichment\nscore',
192-
t_label='Gene\nexpression',
193-
layout='kk',
194-
**kwargs
218+
c_pos_w="darkgreen",
219+
c_neg_w="darkred",
220+
s_label="Enrichment\nscore",
221+
t_label="Gene\nexpression",
222+
layout="kk",
223+
**kwargs,
195224
):
196225
"""
197226
Plot results of enrichment analysis as network.
@@ -231,14 +260,16 @@ def network(
231260
Layout to use to order the nodes. Check ``igraph`` documentation for more options.
232261
%(plot)s
233262
"""
234-
assert isinstance(net, pd.DataFrame), 'net must be pd.DataFrame'
235-
assert (data is None) == (score is None), 'data and score must either both be None'
263+
assert isinstance(net, pd.DataFrame), "net must be pd.DataFrame"
264+
assert (data is None) == (score is None), "data and score must either both be None"
265+
if ig is None:
266+
raise ImportError("igraph is not installed. Please install it using `pip install igraph`.")
236267
if data is None:
237-
srcs = net['source'].unique().astype('U')
238-
score = pd.DataFrame(np.ones((1, srcs.size)), index=['0'], columns=srcs)
239-
trgs = net['target'].unique().astype('U')
240-
data = pd.DataFrame(np.ones((1, trgs.size)), index=['0'], columns=trgs)
241-
t_cmap = 'white'
268+
srcs = net["source"].unique().astype("U")
269+
score = pd.DataFrame(np.ones((1, srcs.size)), index=["0"], columns=srcs)
270+
trgs = net["target"].unique().astype("U")
271+
data = pd.DataFrame(np.ones((1, trgs.size)), index=["0"], columns=trgs)
272+
t_cmap = "white"
242273
# Filter
243274
fdata, fscore, fnet = _filter(
244275
data=data,
@@ -253,7 +284,7 @@ def network(
253284
t_norm = _norm(x=fdata, vcenter=vcenter)
254285
# Make graph
255286
g = _net_2_g(data=fdata, score=fscore, net=fnet)
256-
g.es['color'] = [c_pos_w if w > 0 else c_neg_w for w in g.es['weight']]
287+
g.es["color"] = [c_pos_w if w > 0 else c_neg_w for w in g.es["weight"]]
257288
is_cmap = _gcolors(
258289
g=g,
259290
data=data,
@@ -264,7 +295,7 @@ def network(
264295
t_cmap=t_cmap,
265296
)
266297
# Instance
267-
kwargs['ax'] = None
298+
kwargs["ax"] = None
268299
bp = Plotter(**kwargs)
269300
bp.fig.delaxes(bp.ax)
270301
# Plot
@@ -279,7 +310,7 @@ def network(
279310
layout=layout,
280311
vertex_size=(size_node * bp.dpi) / (bp.figsize[0] * bp.figsize[0]),
281312
vertex_size_label=(size_label * bp.dpi) / (bp.figsize[0] * bp.figsize[0]),
282-
bbox_inches='tight',
313+
bbox_inches="tight",
283314
)
284315
if is_cmap:
285316
sm = matplotlib.cm.ScalarMappable(norm=s_norm, cmap=s_cmap)
@@ -290,19 +321,35 @@ def network(
290321
ax2.axis("off")
291322
ax4.axis("off")
292323
# Add legend
293-
square = Line2D([0], [0], marker='s', color='w', label='Source', markerfacecolor='white',
294-
markeredgecolor='black', markersize=10)
295-
circle = Line2D([0], [0], marker='o', color='w', label='Target', markerfacecolor='white',
296-
markeredgecolor='black', markersize=10)
297-
line1 = Line2D((0, 0), (1, 0), color=c_pos_w, lw=2, marker='>',)
298-
line2 = Line2D((0, 0), (1, 0), color=c_neg_w, lw=2, marker='>',)
324+
square = Line2D(
325+
[0], [0], marker="s", color="w", label="Source", markerfacecolor="white", markeredgecolor="black", markersize=10
326+
)
327+
circle = Line2D(
328+
[0], [0], marker="o", color="w", label="Target", markerfacecolor="white", markeredgecolor="black", markersize=10
329+
)
330+
line1 = Line2D(
331+
(0, 0),
332+
(1, 0),
333+
color=c_pos_w,
334+
lw=2,
335+
marker=">",
336+
)
337+
line2 = Line2D(
338+
(0, 0),
339+
(1, 0),
340+
color=c_neg_w,
341+
lw=2,
342+
marker=">",
343+
)
299344
handles = [square, circle, line1, line2]
300-
labels = ['Source', 'Target', 'Positive', 'Negative']
345+
labels = ["Source", "Target", "Positive", "Negative"]
301346
legend = ax3.legend(
302347
handles=[square, circle, line1, line2],
303348
labels=labels,
304349
frameon=False,
305-
loc='center', bbox_to_anchor=(0.5, 0.5), bbox_transform=ax3.transAxes
350+
loc="center",
351+
bbox_to_anchor=(0.5, 0.5),
352+
bbox_transform=ax3.transAxes,
306353
)
307-
ax3.axis('off')
354+
ax3.axis("off")
308355
return bp._return()

0 commit comments

Comments
 (0)