Skip to content

Commit 9eccfa6

Browse files
fix: indexing in kdtw (#2826)
1 parent 9a0fe0c commit 9eccfa6

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

aeon/clustering/_kernel_k_means.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,20 @@ def _kdtw_lk(x, y, local_kernel):
2626
diagonal_weights = np.zeros(max(x_timepoints, y_timepoints))
2727

2828
min_timepoints = min(x_timepoints, y_timepoints)
29-
diagonal_weights[1] = 1.0
29+
diagonal_weights[0] = 1.0
3030
for i in range(1, min_timepoints):
3131
diagonal_weights[i] = local_kernel[i - 1, i - 1]
3232

3333
cost_matrix[0, 0] = 1
3434
cumulative_dp_diag[0, 0] = 1
3535

3636
for i in range(1, x_timepoints):
37-
cost_matrix[i, 1] = cost_matrix[i - 1, 1] * local_kernel[i - 1, 2]
38-
cumulative_dp_diag[i, 1] = cumulative_dp_diag[i - 1, 1] * diagonal_weights[i]
37+
cost_matrix[i, 0] = cost_matrix[i - 1, 0] * local_kernel[i - 1, 0]
38+
cumulative_dp_diag[i, 0] = cumulative_dp_diag[i - 1, 0] * diagonal_weights[i]
3939

4040
for j in range(1, y_timepoints):
41-
cost_matrix[1, j] = cost_matrix[1, j - 1] * local_kernel[2, j - 1]
42-
cumulative_dp_diag[1, j] = cumulative_dp_diag[1, j - 1] * diagonal_weights[j]
41+
cost_matrix[0, j] = cost_matrix[0, j - 1] * local_kernel[0, j - 1]
42+
cumulative_dp_diag[0, j] = cumulative_dp_diag[0, j - 1] * diagonal_weights[j]
4343

4444
for i in range(1, x_timepoints):
4545
for j in range(1, y_timepoints):

aeon/clustering/tests/test_kernel_k_means.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55

6-
from aeon.clustering._kernel_k_means import TimeSeriesKernelKMeans
6+
from aeon.clustering._kernel_k_means import TimeSeriesKernelKMeans, _kdtw
77
from aeon.datasets import load_basic_motions
88
from aeon.utils.validation._dependencies import _check_estimator_deps
99

@@ -19,18 +19,18 @@
1919

2020
expected_results_kdtw = [0, 2, 0, 0, 0]
2121

22+
max_train = 5
23+
24+
X_train, y_train = load_basic_motions(split="train")
25+
X_test, y_test = load_basic_motions(split="test")
26+
2227

2328
@pytest.mark.skipif(
2429
not _check_estimator_deps(TimeSeriesKernelKMeans, severity="none"),
2530
reason="skip test if required soft dependencies not available",
2631
)
27-
def test_kernel_k_means():
28-
"""Test implementation of kernel k means."""
29-
max_train = 5
30-
31-
X_train, y_train = load_basic_motions(split="train")
32-
X_test, y_test = load_basic_motions(split="test")
33-
32+
def test_kernel_k_means_gak():
33+
"""Test implementation of kernel k means with GAK kernel."""
3434
kernel_kmeans = TimeSeriesKernelKMeans(random_state=1, n_clusters=3)
3535
kernel_kmeans.fit(X_train[0:max_train])
3636
test_shape_result = kernel_kmeans.predict(X_test[0:max_train])
@@ -44,6 +44,13 @@ def test_kernel_k_means():
4444
for val in proba:
4545
assert np.count_nonzero(val == 1.0) == 1
4646

47+
48+
@pytest.mark.skipif(
49+
not _check_estimator_deps(TimeSeriesKernelKMeans, severity="none"),
50+
reason="skip test if required soft dependencies not available",
51+
)
52+
def test_kernel_k_means_kdtw():
53+
"""Test implementation of kernel k means with KDTW kernel."""
4754
kernel_kmeans_kdtw = TimeSeriesKernelKMeans(
4855
kernel="kdtw",
4956
random_state=1,
@@ -61,3 +68,17 @@ def test_kernel_k_means():
6168

6269
for val in kdtw_proba:
6370
assert np.count_nonzero(val == 1.0) == 1
71+
72+
73+
def test_kdtw_kernel_univariate():
74+
"""Test kdtw kernel for univariate time series."""
75+
# expected value created with the original (Matlab) code from:
76+
# https://people.irisa.fr/Pierre-Francois.Marteau/REDK/KDTW/KDTW.html
77+
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.float64).reshape(-1, 1)
78+
y = np.array([5, 6, 7, 8, 9, 1, 2], dtype=np.float64).reshape(-1, 1)
79+
sigma = 0.125
80+
epsilon = 1e-20
81+
expected_distance = 1.2814e-102
82+
83+
distance = _kdtw(x, y, sigma=sigma, epsilon=epsilon)
84+
np.testing.assert_allclose(expected_distance, distance, rtol=1e-4, atol=1e-106)

0 commit comments

Comments
 (0)