Skip to content

Commit bdb56a1

Browse files
committed
Fixed missing value to throw error on unify_data. Black reformat.
1 parent e7920ac commit bdb56a1

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

python/interpret-core/interpret/utils/all.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,10 @@ def unify_data(data, labels=None, feature_names=None, feature_types=None):
298298

299299
# NOTE: Until missing handling is introduced, all methods will fail at data unification stage if present.
300300
new_data_has_na = (
301-
True if new_data is not None and np.isnan(new_data).any() else False
301+
True if new_data is not None and pd.isnull(new_data).any() else False
302302
)
303303
new_labels_has_na = (
304-
True if new_labels is not None and np.isnan(new_labels).any() else False
304+
True if new_labels is not None and pd.isnull(new_labels).any() else False
305305
)
306306

307307
if new_data_has_na or new_labels_has_na:

python/interpret-core/interpret/utils/environment.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,15 @@ def _detect_databricks():
9090

9191

9292
def is_cloud_env(detected):
93-
cloud_env = ["databricks", "azure", "azureml_vm", "kaggle", "sagemaker", "binder", "colab"]
93+
cloud_env = [
94+
"databricks",
95+
"azure",
96+
"azureml_vm",
97+
"kaggle",
98+
"sagemaker",
99+
"binder",
100+
"colab",
101+
]
94102
if len(set(cloud_env).intersection(detected)) != 0:
95103
return True
96104
else:

python/interpret-core/interpret/utils/test/test_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
import numpy as np
6+
import pandas as pd
67
from .. import gen_feat_val_list, gen_name_from_class
78
from .. import reverse_map, unify_data
89

@@ -20,6 +21,20 @@ def test_unify_fails_on_missing():
2021
unify_data(orig_data, orig_labels)
2122

2223

24+
def test_unify_dataframe_smoke():
25+
df = pd.DataFrame()
26+
df["f1"] = [1.5, "a"]
27+
df["f2"] = [3, "b"]
28+
df["label"] = [0, 1]
29+
30+
train_cols = df.columns[0:-1]
31+
label = df.columns[-1]
32+
X = df[train_cols]
33+
y = df[label]
34+
35+
unify_data(X, y)
36+
37+
2338
def test_unify_list_data():
2439
orig_data = [[1, 2], [3, 4]]
2540
orig_labels = [0, 0]

0 commit comments

Comments
 (0)