-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrec.py
More file actions
94 lines (77 loc) · 2.47 KB
/
rec.py
File metadata and controls
94 lines (77 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import viz.colors as colors
from textwrap import wrap
def viz_bar_chart(
    data,
    top_k=5,
    figsize=(10, 4),
    save_file=None,
    max_label_size=100,
    remove_space=False,
    y_label="Feature Set",
    sort_again=True,
    bounds=None,
    **kwargs
):
    """Draw a horizontal bar chart of the ``top_k`` largest-magnitude attributions.

    Args:
        data: Iterable of ``(feature_label, attribution)`` pairs.
        top_k: Maximum number of bars to draw; clamped to the number of pairs.
        figsize: Figure size passed to ``plt.subplots``.
        save_file: If not None, the figure is saved to this path with
            ``bbox_inches="tight"``.
        max_label_size: Column width used to wrap the y-tick labels.
        remove_space: If True, spaces are stripped from each label after
            wrapping.
        y_label: Text for the y-axis label.
        sort_again: If True, the selected bars are re-ordered by signed value,
            ascending from bottom to top. If False, they keep descending
            magnitude order, drawn top to bottom.
        bounds: Symmetric colour-scale limit; bar colours are normalized over
            ``[-bounds, bounds]``. Defaults to the largest absolute
            attribution among the bars shown.
        **kwargs: ``cmap`` selects the bar colours. It may be a colormap
            callable or a registered colormap name, and defaults to
            ``colors.pos_neg_colors()``. Every other keyword is forwarded
            unchanged to ``Axes.barh``.

    Raises:
        ValueError: If ``data`` is empty.
    """
    # ``cmap`` is consumed here. ``Axes.barh`` does not accept it and would
    # raise if it were forwarded along with the remaining kwargs.
    cmap = kwargs.pop("cmap", None)

    pairs = list(data)
    if not pairs:
        raise ValueError(
            "viz_bar_chart requires at least one (label, attribution) pair"
        )
    feature_labels, attributions = zip(*pairs)
    feature_labels = np.array(feature_labels)
    attributions = np.array(attributions)

    top_k = min(top_k, len(attributions))
    # Indices of the top_k largest |attribution| values, biggest first.
    order = np.argsort(-1 * np.abs(attributions))[:top_k]
    feature_labels = feature_labels[order]
    attributions = attributions[order]
    if sort_again:
        # Re-order the selected bars by signed value (most negative first).
        by_value = np.argsort(attributions)
        feature_labels = feature_labels[by_value]
        attributions = attributions[by_value]

    fig, axis = plt.subplots(figsize=figsize)
    if bounds is None:
        bounds = np.max(np.abs(attributions))
    normalizer = mpl.colors.Normalize(vmin=-bounds, vmax=bounds)
    if cmap is None:
        cmap = colors.pos_neg_colors()
    elif isinstance(cmap, str):
        # Resolve a registered colormap name (e.g. "RdBu") to a Colormap.
        cmap = plt.get_cmap(cmap)

    positions = np.arange(top_k)
    axis.barh(
        positions,
        attributions,
        color=[cmap(normalizer(value)) for value in attributions],
        align="center",
        zorder=10,
        **kwargs
    )
    if not sort_again:
        # Bar 0 has the largest magnitude; flip the axis so it is drawn on top.
        axis.invert_yaxis()

    axis.set_xlabel("Attribution", fontsize=18)
    axis.set_ylabel(y_label, fontsize=18)
    axis.set_yticks(positions)
    axis.tick_params(axis="y", which="both", left=False, labelsize=14)
    axis.tick_params(axis="x", which="both", labelsize=14)

    def _format_label(label):
        # Wrap long labels onto several lines; str() lets non-string labels work.
        wrapped = "\n".join(wrap(str(label), max_label_size))
        return wrapped.replace(" ", "") if remove_space else wrapped

    axis.set_yticklabels([_format_label(label) for label in feature_labels])
    axis.grid(axis="x", zorder=0, linewidth=0.2)
    axis.grid(axis="y", zorder=0, linestyle="--", linewidth=1.0)
    # Keep only the bottom spine visible.
    _set_axis_config(axis, linewidths=(0.0, 0.0, 0.0, 1.0))
    if save_file is not None:
        fig.savefig(save_file, bbox_inches="tight")
def _set_axis_config(
axis, linewidths=(0.0, 0.0, 0.0, 0.0), clear_y_ticks=False, clear_x_ticks=False
):
"""
Source: Integrated Hessians Code Repo
"""
axis.spines["right"].set_linewidth(linewidths[0])
axis.spines["top"].set_linewidth(linewidths[1])
axis.spines["left"].set_linewidth(linewidths[2])
axis.spines["bottom"].set_linewidth(linewidths[3])
if clear_x_ticks:
axis.set_xticks([])
if clear_y_ticks:
axis.set_yticks([])