33import numpy as np
44import pytest
55
6- from aeon .clustering ._kernel_k_means import TimeSeriesKernelKMeans
6+ from aeon .clustering ._kernel_k_means import TimeSeriesKernelKMeans , _kdtw
77from aeon .datasets import load_basic_motions
88from aeon .utils .validation ._dependencies import _check_estimator_deps
99
1919
2020expected_results_kdtw = [0 , 2 , 0 , 0 , 0 ]
2121
22+ max_train = 5
23+
24+ X_train , y_train = load_basic_motions (split = "train" )
25+ X_test , y_test = load_basic_motions (split = "test" )
26+
2227
2328@pytest .mark .skipif (
2429 not _check_estimator_deps (TimeSeriesKernelKMeans , severity = "none" ),
2530 reason = "skip test if required soft dependencies not available" ,
2631)
27- def test_kernel_k_means ():
28- """Test implementation of kernel k means."""
29- max_train = 5
30-
31- X_train , y_train = load_basic_motions (split = "train" )
32- X_test , y_test = load_basic_motions (split = "test" )
33-
32+ def test_kernel_k_means_gak ():
33+ """Test implementation of kernel k means with GAK kernel."""
3434 kernel_kmeans = TimeSeriesKernelKMeans (random_state = 1 , n_clusters = 3 )
3535 kernel_kmeans .fit (X_train [0 :max_train ])
3636 test_shape_result = kernel_kmeans .predict (X_test [0 :max_train ])
@@ -44,6 +44,13 @@ def test_kernel_k_means():
4444 for val in proba :
4545 assert np .count_nonzero (val == 1.0 ) == 1
4646
47+
48+ @pytest .mark .skipif (
49+ not _check_estimator_deps (TimeSeriesKernelKMeans , severity = "none" ),
50+ reason = "skip test if required soft dependencies not available" ,
51+ )
52+ def test_kernel_k_means_kdtw ():
53+ """Test implementation of kernel k means with KDTW kernel."""
4754 kernel_kmeans_kdtw = TimeSeriesKernelKMeans (
4855 kernel = "kdtw" ,
4956 random_state = 1 ,
@@ -61,3 +68,17 @@ def test_kernel_k_means():
6168
6269 for val in kdtw_proba :
6370 assert np .count_nonzero (val == 1.0 ) == 1
71+
72+
73+ def test_kdtw_kernel_univariate ():
74+ """Test kdtw kernel for univariate time series."""
75+ # expected value created with the original (Matlab) code from:
76+ # https://people.irisa.fr/Pierre-Francois.Marteau/REDK/KDTW/KDTW.html
77+ x = np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], dtype = np .float64 ).reshape (- 1 , 1 )
78+ y = np .array ([5 , 6 , 7 , 8 , 9 , 1 , 2 ], dtype = np .float64 ).reshape (- 1 , 1 )
79+ sigma = 0.125
80+ epsilon = 1e-20
81+ expected_distance = 1.2814e-102
82+
83+ distance = _kdtw (x , y , sigma = sigma , epsilon = epsilon )
84+ np .testing .assert_allclose (expected_distance , distance , rtol = 1e-4 , atol = 1e-106 )
0 commit comments