Skip to content

Commit 2b67a49

Browse files
committed
Finish API tests
1 parent ef16c84 commit 2b67a49

File tree

4 files changed

+58
-32
lines changed

4 files changed

+58
-32
lines changed

h2o-py/h2o/model/model_base.py

+10
Original file line numberDiff line numberDiff line change
@@ -1986,6 +1986,16 @@ def _replace_empty_str(row):
19861986
)
19871987
return varimp
19881988

1989+
def distances(self):
1990+
"""
1991+
Obtain the distances frame for a KNN model.
1992+
1993+
:return: H2OFrame
1994+
"""
1995+
if self._model_json["algo"] != "knn":
1996+
raise H2OValueError("This function is available for KNN models only")
1997+
return h2o.get_frame(self._model_json["output"]["distances"])
1998+
19891999
# --------------------------------
19902000
# ModelBase representation methods
19912001
# --------------------------------

h2o-py/tests/testdir_algos/knn/pyunit_knn_compare_sklearn.py

+36-24
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
sys.path.insert(1, os.path.join("..", "..", ".."))
44
import h2o
5-
from tests import pyunit_utils, assert_equals
5+
from tests import pyunit_utils
66
from h2o.estimators.knn import H2OKnnEstimator
77
import numpy as np
8-
from sklearn.neighbors import KNeighborsClassifier
98
from sklearn.neighbors import kneighbors_graph
109
import pandas as pd
1110

@@ -15,35 +14,48 @@ def knn_sklearn_compare():
1514
id_column = "id"
1615
response_column = "class"
1716
x_names = ["sepal_len", "sepal_wid", "petal_len", "petal_wid"]
17+
k = 3
18+
metrics = ["euclidean", "manhattan", "cosine"]
1819

1920
train = pd.read_csv(pyunit_utils.locate("smalldata/iris/iris_wheader.csv"))
20-
21-
knn = KNeighborsClassifier(n_neighbors=3)
22-
knn.fit(train[x_names], train[response_column])
23-
print(knn)
24-
knn_score = knn.score(train[x_names], train[response_column])
25-
print(knn_score)
26-
27-
knn_graph = kneighbors_graph(train[x_names], 3, mode='connectivity', include_self=False, metric="euclidean")
28-
print(knn_graph)
29-
21+
3022
train_h2o = h2o.H2OFrame(train)
3123
train_h2o[response_column] = train_h2o[response_column].asfactor()
3224
train_h2o[id_column] = h2o.H2OFrame(np.arange(0, train_h2o.shape[0]))
25+
26+
for metric in metrics:
27+
print("Check results for "+metric+" metric.")
28+
sklearn_knn_graph = kneighbors_graph(train[x_names],
29+
k,
30+
mode='connectivity',
31+
include_self=True,
32+
metric=metric)
3333

34-
h2o_knn = H2OKnnEstimator(
35-
k=3,
36-
id_column=id_column,
37-
distance="euclidean",
38-
seed=seed,
39-
auc_type="macroovr"
40-
)
34+
h2o_knn = H2OKnnEstimator(k=k,
35+
id_column=id_column,
36+
distance=metric,
37+
seed=seed)
38+
39+
h2o_knn.train(y=response_column, x=x_names, training_frame=train_h2o)
40+
41+
distances_frame = h2o_knn.distances().as_data_frame()
42+
assert distances_frame is not None
4143

42-
h2o_knn.train(y=response_column, x=x_names, training_frame=train_h2o)
43-
distances_key = h2o_knn._model_json["output"]["distances"]
44-
print(distances_key)
45-
distances_frame = h2o.get_frame(distances_key)
46-
print(distances_frame)
44+
diff = 0
45+
allowed_diff = 20
46+
for i in range(train.shape[0]):
47+
sklearn_neighbours = sklearn_knn_graph[i].nonzero()[1]
48+
for j in range(k):
49+
sklearn_n = sklearn_neighbours[j]
50+
h2o_n = distances_frame["id_"+str(j+1)][i]
51+
if sklearn_n != h2o_n:
52+
print(distances_frame.loc[[i]])
53+
print("["+str(i)+","+str(j)+"] sklearn:h2o "+str(sklearn_n)+" == "+str(h2o_n))
54+
diff += 1
55+
56+
# Some neighbours may come back in a different order than in sklearn due to parallelization, so a limited number of mismatches is tolerated
57+
print("Number of different neighbours: "+str(diff))
58+
assert diff < allowed_diff
4759

4860

4961
if __name__ == "__main__":

h2o-py/tests/testdir_algos/knn/pyunit_knn_api_test.py renamed to h2o-py/tests/testdir_algos/knn/pyunit_knn_smoke.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ def knn_api_smoke():
1616
train_h2o = h2o.upload_file(pyunit_utils.locate("smalldata/iris/iris_wheader.csv"))
1717
train_h2o[response_column] = train_h2o[response_column].asfactor()
1818
train_h2o[id_column] = h2o.H2OFrame(np.arange(0, train_h2o.shape[0]))
19-
20-
19+
2120
model = H2OKnnEstimator(
2221
k=3,
2322
id_column=id_column,
@@ -35,6 +34,9 @@ def knn_api_smoke():
3534
assert_equals(perf.mse(), model.mse())
3635
assert_equals(perf.multinomial_auc_table(), model.multinomial_auc_table())
3736

37+
distances = model.distances()
38+
assert distances is not None
39+
3840

3941
if __name__ == "__main__":
4042
pyunit_utils.standalone_test(knn_api_smoke)

h2o-r/tests/testdir_algos/knn/runit_knn_smoke.R

+8-6
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,22 @@ source("../../../scripts/h2o-r-test-setup.R")
55

66
knn.smoke <- function() {
77
iris.hex <- h2o.uploadFile( locate("smalldata/iris/iris.csv"))
8-
iris.knn <- h2o.knn(x=1:4, y=5, training_frame=iris.hex, k=3 , distance="euclidean", seed=1234)
8+
9+
iris.hex$id <- as.h2o(1:nrow(iris.hex))
10+
iris.knn <- h2o.knn(x=1:4, y=5, training_frame=iris.hex, id_column = "id", k=3 , distance="euclidean", seed=1234, auc_type="WEIGHTED_OVO")
911

1012
# Score the data with a different default auc_type (the previous one was "NONE", so no AUC was calculated)
11-
perf <- h2o.performance(iris.knn, test.hex, auc_type="WEIGHTED_OVO")
13+
perf <- h2o.performance(iris.knn, iris.hex, auc_type="WEIGHTED_OVO")
1214

1315
# Check default AUC is set correctly
1416
auc_table <- h2o.multinomial_auc_table(perf)
1517
default_auc <- h2o.auc(perf)
16-
weighted_ovo_auc <- auc_table[32, 4] # weighted ovo AUC is the last number in the table
17-
18+
weighted_ovo_auc <- auc_table[10, 4] # weighted ovo AUC is the last number in the table
19+
1820
expect_equal(default_auc, weighted_ovo_auc)
1921

20-
distances <- iris.knn@model$distances
21-
print(distances)
22+
distances <- h2o.getFrame(iris.knn@model$distances)
23+
expect_equal(is.null(distances), FALSE)
2224
}
2325

2426
doTest("KNN Test: Check model is running.", knn.smoke)

0 commit comments

Comments
 (0)