Skip to content

Commit 7bd487d

Browse files
committed
reformat
1 parent 8e30f7e commit 7bd487d

File tree

24 files changed

+251
-218
lines changed

24 files changed

+251
-218
lines changed

python/interpret-api/interpret/newapi/component.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,18 @@ def from_fields(cls, fields: List):
1414
class Attribution(Component):
1515
def __init__(self, values, base_values=None, units=None):
1616
self.fields = locals()
17-
del self.fields['self']
17+
del self.fields["self"]
1818

1919

2020
class BinnedData(Component):
21-
def __init__(self, data, data_counts=None, feature_names=None, feature_types=None, feature_indexes=None):
21+
def __init__(
22+
self,
23+
data,
24+
data_counts=None,
25+
feature_names=None,
26+
feature_types=None,
27+
feature_indexes=None,
28+
):
2229
if not isinstance(feature_names, (Obj, Alias, type(None))):
2330
feature_names = Alias(feature_names, 0)
2431

@@ -29,11 +36,13 @@ def __init__(self, data, data_counts=None, feature_names=None, feature_types=Non
2936
feature_indexes = Alias(feature_indexes, 0)
3037

3138
self.fields = locals()
32-
del self.fields['self']
39+
del self.fields["self"]
3340

3441

3542
class TabularData(Component):
36-
def __init__(self, data, feature_names=None, feature_types=None, feature_indexes=None):
43+
def __init__(
44+
self, data, feature_names=None, feature_types=None, feature_indexes=None
45+
):
3746
if not isinstance(feature_names, (Obj, Alias, type(None))):
3847
feature_names = Alias(feature_names, 1)
3948

@@ -44,31 +53,31 @@ def __init__(self, data, feature_names=None, feature_types=None, feature_indexes
4453
feature_indexes = Alias(feature_indexes, 1)
4554

4655
self.fields = locals()
47-
del self.fields['self']
56+
del self.fields["self"]
4857

4958

5059
class Bound(Component):
5160
def __init__(self, lower_bounds, upper_bounds):
5261
self.fields = locals()
53-
del self.fields['self']
62+
del self.fields["self"]
5463

5564

5665
# TODO: Consider separation of concerns for each field.
5766
class Meta(Component):
5867
def __init__(self, source, pivots, dimension_names=None):
5968
self.fields = locals()
60-
del self.fields['self']
69+
del self.fields["self"]
6170

6271

6372
class Extra(Component):
6473
def __init__(
65-
self,
66-
display_data=None,
67-
output_names=None,
68-
output_indexes=None, # Alias?
69-
main_effects=None,
70-
hierarchical_values=None,
71-
clustering=None,
74+
self,
75+
display_data=None,
76+
output_names=None,
77+
output_indexes=None, # Alias?
78+
main_effects=None,
79+
hierarchical_values=None,
80+
clustering=None,
7281
):
7382
self.fields = locals()
74-
del self.fields['self']
83+
del self.fields["self"]

python/interpret-api/interpret/newapi/explanation.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def __init__(self, **kwargs):
2424
# TODO: Needs further discussion at design-level.
2525
def append(self, component):
2626
if not isinstance(component, Component):
27-
raise Exception(f"Can't append object of type {type(component)} to this object.")
27+
raise Exception(
28+
f"Can't append object of type {type(component)} to this object."
29+
)
2830

2931
self.components[type(component)] = component
3032
for field_name, field_value in component.fields.items():
@@ -50,12 +52,18 @@ def __repr__(self):
5052
field_value_str = f"Dim\t{field_name} = {field_value}"
5153
else:
5254
if field_name in self._objects:
53-
field_type = 'O'
54-
field_dim = ','.join(str(x) for x in self._objects[field_name].dim)
55+
field_type = "O"
56+
field_dim = ",".join(
57+
str(x) for x in self._objects[field_name].dim
58+
)
5559
else:
56-
field_type = 'A'
57-
field_dim = ','.join(str(x) for x in self._aliases[field_name].dim)
58-
field_value_str = f"{field_type}{{{field_dim}}}\t{field_name} = {field_value}"
60+
field_type = "A"
61+
field_dim = ",".join(
62+
str(x) for x in self._aliases[field_name].dim
63+
)
64+
field_value_str = (
65+
f"{field_type}{{{field_dim}}}\t{field_name} = {field_value}"
66+
)
5967

6068
if len(field_value_str) > 60:
6169
field_value_str = field_value_str[:57] + "..."
@@ -81,6 +89,7 @@ def from_components(cls, components):
8189

8290
def to_json(self, **kwargs):
8391
from interpret.newapi.serialization import ExplanationJSONEncoder
92+
8493
version = "0.0.1"
8594
di = {
8695
"version": version,
@@ -100,4 +109,3 @@ def __init__(self, attrib, data=None, perf=None, bound=None, meta=None, **kwargs
100109
meta=meta,
101110
**kwargs,
102111
)
103-

python/interpret-api/interpret/newapi/test_explanation.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
def test_explanation_serialize():
1010
data = [[0, 1], [1, 2]]
1111
data_counts = [1, 2]
12-
feature_names = ['f1', 'f2']
12+
feature_names = ["f1", "f2"]
1313
binned = BinnedData(
1414
O(data),
1515
O(data_counts),
@@ -31,7 +31,7 @@ def test_explanation_serialize():
3131

3232
assert deserialized.data == [[0, 1], [1, 2]]
3333
assert deserialized.data_counts == [1, 2]
34-
assert deserialized.feature_names == ['f1', 'f2']
34+
assert deserialized.feature_names == ["f1", "f2"]
3535
assert deserialized.values == [[1, 2], [3, 4]]
3636
assert deserialized.base_values == [0, 1]
3737
assert deserialized.units == "logits"
@@ -40,7 +40,7 @@ def test_explanation_serialize():
4040
def test_explanation():
4141
data = [[0, 1], [1, 2]]
4242
data_counts = [1, 2]
43-
feature_names = ['f1', 'f2']
43+
feature_names = ["f1", "f2"]
4444
binned = BinnedData(
4545
O(data),
4646
O(data_counts),
@@ -60,19 +60,19 @@ def test_explanation():
6060
expl = AttribExplanation.from_json(expl.to_json())
6161
actual = expl[0]
6262
assert actual.data == [0, 1]
63-
assert actual.feature_names == 'f1'
64-
assert actual.units == 'logits'
63+
assert actual.feature_names == "f1"
64+
assert actual.units == "logits"
6565
assert actual.base_values == 0
6666

67-
actual = expl['f1']
67+
actual = expl["f1"]
6868
assert actual.data == [0, 1]
69-
assert actual.feature_names == 'f1'
70-
assert actual.units == 'logits'
69+
assert actual.feature_names == "f1"
70+
assert actual.units == "logits"
7171
assert actual.base_values == 0
7272

7373
expl = AttribExplanation(attrib, extra=binned)
7474
actual = expl[0]
7575
assert actual.data == [0, 1]
76-
assert actual.feature_names == 'f1'
77-
assert actual.units == 'logits'
78-
assert actual.base_values == 0
76+
assert actual.feature_names == "f1"
77+
assert actual.units == "logits"
78+
assert actual.base_values == 0

python/interpret-core/interpret/data/_response.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,9 @@ def visualize(self, key=None):
278278
hovermode="closest",
279279
xaxis2=dict(domain=[do_hi, 1], showgrid=False, zeroline=False),
280280
yaxis2=dict(domain=[do_hi, 1], showgrid=False, zeroline=False),
281-
title="Pearson Correlation: {0:.3f}".format(corr)
282-
if corr is not None
283-
else "",
281+
title=(
282+
"Pearson Correlation: {0:.3f}".format(corr) if corr is not None else ""
283+
),
284284
)
285285
fig = go.Figure(data=data, layout=layout)
286286
return fig

python/interpret-core/interpret/develop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ def dynamic_system_info():
6666
system_info = {
6767
"psutil.virtual_memory": virtual_memory,
6868
"psutil.swap_memory": swap_memory,
69-
"psutil.avg_cpu_percent": None
70-
if cpu_percent is None
71-
else np.mean(cpu_percent),
72-
"psutil.std_cpu_percent": None
73-
if cpu_percent is None
74-
else np.std(cpu_percent),
69+
"psutil.avg_cpu_percent": (
70+
None if cpu_percent is None else np.mean(cpu_percent)
71+
),
72+
"psutil.std_cpu_percent": (
73+
None if cpu_percent is None else np.std(cpu_percent)
74+
),
7575
"psutil.cpu_freq": None if cpu_freq is None else cpu_freq._asdict(),
7676
}
7777
except Exception: # pragma: no cover

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -884,9 +884,11 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): # noqa: C9
884884
noise_scale_boosting,
885885
bin_data_weights,
886886
rngs[idx],
887-
Native.CreateBoosterFlags_DifferentialPrivacy
888-
if is_differential_privacy
889-
else Native.CreateBoosterFlags_Default,
887+
(
888+
Native.CreateBoosterFlags_DifferentialPrivacy
889+
if is_differential_privacy
890+
else Native.CreateBoosterFlags_Default
891+
),
890892
objective,
891893
None,
892894
)
@@ -996,9 +998,11 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): # noqa: C9
996998
Native.CalcInteractionFlags_Default,
997999
max_cardinality,
9981000
min_samples_leaf,
999-
Native.CreateInteractionFlags_DifferentialPrivacy
1000-
if is_differential_privacy
1001-
else Native.CreateInteractionFlags_Default,
1001+
(
1002+
Native.CreateInteractionFlags_DifferentialPrivacy
1003+
if is_differential_privacy
1004+
else Native.CreateInteractionFlags_Default
1005+
),
10021006
objective,
10031007
None,
10041008
)
@@ -1121,9 +1125,11 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): # noqa: C9
11211125
noise_scale_boosting,
11221126
bin_data_weights,
11231127
rngs[idx],
1124-
Native.CreateBoosterFlags_DifferentialPrivacy
1125-
if is_differential_privacy
1126-
else Native.CreateBoosterFlags_Default,
1128+
(
1129+
Native.CreateBoosterFlags_DifferentialPrivacy
1130+
if is_differential_privacy
1131+
else Native.CreateBoosterFlags_Default
1132+
),
11271133
objective,
11281134
None,
11291135
)

python/interpret-core/interpret/glassbox/_ebm/_json.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ def _to_json_inner(ebm, detail="all"):
262262
if histogram_weights is not None:
263263
feature_histogram_weights = histogram_weights[i]
264264
if feature_histogram_weights is not None:
265-
feature[
266-
"histogram_weights"
267-
] = feature_histogram_weights.tolist()
265+
feature["histogram_weights"] = (
266+
feature_histogram_weights.tolist()
267+
)
268268
else:
269269
raise ValueError(f"Unsupported feature type: {feature_type}")
270270

python/interpret-core/interpret/glassbox/_ebm/_research/_group_importance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
66
A term denotes both single features and interactions (pairs).
77
"""
8+
89
import numpy as np
910
import pandas as pd
1011
import plotly.express as px

python/interpret-core/interpret/utils/_measure_interactions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,11 @@ def measure_interactions(
270270
calc_interaction_flags=Native.CalcInteractionFlags_Pure,
271271
max_cardinality=max_cardinality,
272272
min_samples_leaf=min_samples_leaf,
273-
create_interaction_flags=Native.CreateInteractionFlags_DifferentialPrivacy
274-
if is_differential_privacy
275-
else Native.CreateInteractionFlags_Default,
273+
create_interaction_flags=(
274+
Native.CreateInteractionFlags_DifferentialPrivacy
275+
if is_differential_privacy
276+
else Native.CreateInteractionFlags_Default
277+
),
276278
objective=objective,
277279
experimental_params=None,
278280
n_output_interactions=n_output_interactions,

python/interpret-core/interpret/utils/_privacy.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ def validate_eps_delta(eps, delta):
2121

2222
def calc_classic_noise_multi(total_queries, target_epsilon, delta, sensitivity):
2323
variance = (
24-
8
25-
* total_queries
26-
* sensitivity**2
27-
* np.log(np.exp(1) + target_epsilon / delta)
24+
8 * total_queries * sensitivity**2 * np.log(np.exp(1) + target_epsilon / delta)
2825
) / target_epsilon**2
2926
return np.sqrt(variance)
3027

0 commit comments

Comments
 (0)