Skip to content

[ENH] DTW AROW implementation #2621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions aeon/distances/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
"wddtw_pairwise_distance",
"wddtw_alignment_path",
"wddtw_cost_matrix",
"dtw_arow_distance",
"dtw_arow_pairwise_distance",
"dtw_arow_cost_matrix",
"dtw_arow_alignment_path",
"lcss_distance",
"lcss_pairwise_distance",
"lcss_alignment_path",
Expand Down Expand Up @@ -113,6 +117,10 @@
ddtw_distance,
ddtw_pairwise_distance,
dtw_alignment_path,
dtw_arow_alignment_path,
dtw_arow_cost_matrix,
dtw_arow_distance,
dtw_arow_pairwise_distance,
dtw_cost_matrix,
dtw_distance,
dtw_gi_alignment_path,
Expand Down
15 changes: 15 additions & 0 deletions aeon/distances/_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
ddtw_distance,
ddtw_pairwise_distance,
dtw_alignment_path,
dtw_arow_alignment_path,
dtw_arow_cost_matrix,
dtw_arow_distance,
dtw_arow_pairwise_distance,
dtw_cost_matrix,
dtw_distance,
dtw_gi_alignment_path,
Expand Down Expand Up @@ -624,6 +628,7 @@ def get_cost_matrix_function(method: str) -> CostMatrixFunction:
'ddtw' distances.ddtw_cost_matrix
'wdtw' distances.wdtw_cost_matrix
'wddtw' distances.wddtw_cost_matrix
'dtw_arow' distances.dtw_arow_cost_matrix
'adtw' distances.adtw_cost_matrix
'erp' distances.erp_cost_matrix
'edr' distances.edr_cost_matrix
Expand Down Expand Up @@ -773,6 +778,16 @@ class DistanceType(Enum):
"symmetric": True,
"unequal_support": True,
},
{
"name": "dtw_arow",
"distance": dtw_arow_distance,
"pairwise_distance": dtw_arow_pairwise_distance,
"cost_matrix": dtw_arow_cost_matrix,
"alignment_path": dtw_arow_alignment_path,
"type": DistanceType.ELASTIC,
"symmetric": True,
"unequal_support": True,
},
{
"name": "lcss",
"distance": lcss_distance,
Expand Down
10 changes: 10 additions & 0 deletions aeon/distances/elastic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
"wddtw_pairwise_distance",
"wddtw_alignment_path",
"wddtw_cost_matrix",
"dtw_arow_distance",
"dtw_arow_pairwise_distance",
"dtw_arow_cost_matrix",
"dtw_arow_alignment_path",
"lcss_distance",
"lcss_pairwise_distance",
"lcss_alignment_path",
Expand Down Expand Up @@ -75,6 +79,12 @@
dtw_distance,
dtw_pairwise_distance,
)
from aeon.distances.elastic._dtw_arow import (
dtw_arow_alignment_path,
dtw_arow_cost_matrix,
dtw_arow_distance,
dtw_arow_pairwise_distance,
)
from aeon.distances.elastic._dtw_gi import (
dtw_gi_alignment_path,
dtw_gi_cost_matrix,
Expand Down
195 changes: 195 additions & 0 deletions aeon/distances/elastic/_dtw_arow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
r"""DTW-AROW distance between two time series that may contain missing values."""

__maintainer__ = []

from typing import Optional, Union

import numpy as np
from numba import njit
from numba.typed import List as NumbaList

from aeon.distances.pointwise._squared import _univariate_squared_distance
from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list
from aeon.utils.validation.collection import _is_numpy_list_multivariate


@njit(cache=True)
def dtw_arow_distance(x: np.ndarray, y: np.ndarray) -> float:
    r"""Compute the DTW-AROW distance between two time series.

    1D inputs are reshaped to a single-row 2D array and then handled exactly
    like 2D inputs by ``_dtw_arow_distance``.

    Parameters
    ----------
    x : np.ndarray
        First time series, 1D ``(n_timepoints,)`` or 2D
        ``(n_channels, n_timepoints)``.
    y : np.ndarray
        Second time series, with the same number of dimensions as ``x``.

    Returns
    -------
    float
        DTW-AROW distance between ``x`` and ``y``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` are not both 1D or both 2D.
    """
    if x.ndim == 2 and y.ndim == 2:
        return _dtw_arow_distance(x, y)
    if x.ndim == 1 and y.ndim == 1:
        # Promote univariate series to a single-channel 2D layout.
        x_2d = x.reshape((1, x.shape[0]))
        y_2d = y.reshape((1, y.shape[0]))
        return _dtw_arow_distance(x_2d, y_2d)
    raise ValueError("x and y must be 1D or 2D")


@njit(cache=True)
def dtw_arow_cost_matrix(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    r"""Compute the DTW-AROW cost matrix between two time series.

    The result is the element-wise square root of the accumulated cost matrix
    built by ``_dtw_arow_cost_matrix``. That matrix carries one extra leading
    row and column, so the output has shape
    ``(n_timepoints_x + 1, n_timepoints_y + 1)``. No rescaling for missing
    values is applied here.

    Parameters
    ----------
    x : np.ndarray
        First time series, 1D ``(n_timepoints,)`` or 2D
        ``(n_channels, n_timepoints)``.
    y : np.ndarray
        Second time series, with the same number of dimensions as ``x``.

    Returns
    -------
    np.ndarray
        Square-rooted accumulated cost matrix.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` are not both 1D or both 2D.
    """
    if x.ndim == 2 and y.ndim == 2:
        return np.sqrt(_dtw_arow_cost_matrix(x, y))
    if x.ndim == 1 and y.ndim == 1:
        # Promote univariate series to a single-channel 2D layout.
        x_2d = x.reshape((1, x.shape[0]))
        y_2d = y.reshape((1, y.shape[0]))
        return np.sqrt(_dtw_arow_cost_matrix(x_2d, y_2d))
    raise ValueError("x and y must be 1D or 2D")


@njit(cache=True)
def _dtw_arow_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Return the rescaled DTW-AROW distance for two 2D series.

    The final accumulated cost is multiplied by
    ``gamma = (len_x + len_y) / (available_x + available_y)`` before the
    square root is taken, where the ``available_*`` counts come from
    ``_total_non_nan``.

    Parameters
    ----------
    x : np.ndarray
        First series; time points are indexed along axis 1.
    y : np.ndarray
        Second series; time points are indexed along axis 1.

    Returns
    -------
    float
        ``sqrt(gamma * accumulated_cost[len_x, len_y])``.
    """
    len_x = x.shape[1]
    len_y = y.shape[1]
    # NOTE(review): if neither series has an available time point this
    # denominator is zero - confirm the intended behaviour for that case.
    n_available = _total_non_nan(x) + _total_non_nan(y)
    gamma = (len_x + len_y) / n_available
    accumulated_cost = _dtw_arow_cost_matrix(x, y)
    return np.sqrt(gamma * accumulated_cost[len_x, len_y])


@njit(cache=True)
def _check_nan(x: np.ndarray) -> bool:
    """Return True when ``x`` contains at least one NaN entry."""
    return np.any(np.isnan(x))


# fastmath is intentionally NOT enabled: fast-math lets the optimiser assume
# that NaNs never occur, which conflicts with NaN-encoded missing values.
@njit(cache=True)
def _cost_function(x: np.ndarray, y: np.ndarray) -> float:
    """Return the pointwise cost between two time points.

    Parameters
    ----------
    x : np.ndarray
        Values of the first series at one time point (one entry per channel).
    y : np.ndarray
        Values of the second series at one time point.

    Returns
    -------
    float
        ``0.0`` when either time point contains a NaN (missing observation),
        otherwise the result of ``_univariate_squared_distance(x, y)``.
    """
    if _check_nan(x) or _check_nan(y):
        # A float literal keeps both return paths the same type.
        return 0.0
    return _univariate_squared_distance(x, y)


@njit(cache=True)
def _total_non_nan(x: np.ndarray) -> int:
    """Count the time points of ``x`` that have no missing value.

    A time point (column of the 2D array) counts as available only when none
    of its rows is NaN. For a single-row series this is simply the number of
    non-NaN entries, so no special case is needed.

    Parameters
    ----------
    x : np.ndarray
        2D series; time points are indexed along axis 1.

    Returns
    -------
    int
        Number of columns of ``x`` without any NaN.
    """
    missing_per_timepoint = np.sum(np.isnan(x), axis=0)
    return np.sum(missing_per_timepoint == 0)


@njit(cache=True)
def _dtw_arow_cost_path_helper(
    x: np.ndarray, y: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Build the accumulated cost matrix and the predecessor matrix.

    Parameters
    ----------
    x : np.ndarray
        First 2D series; time points are indexed along axis 1.
    y : np.ndarray
        Second 2D series; time points are indexed along axis 1.

    Returns
    -------
    cost_matrix : np.ndarray
        Array of shape ``(n + 1, m + 1)`` with the accumulated cost; cell
        ``[i, j]`` refers to ``x[:, i - 1]`` and ``y[:, j - 1]``. Unreachable
        cells hold ``inf``.
    phi : np.ndarray
        Array of shape ``(n + 1, m + 1)`` with the chosen predecessor of each
        cell: 0 = diagonal, 1 = vertical (from ``[i - 1, j]``),
        2 = horizontal (from ``[i, j - 1]``), -1 = nothing recorded.
    """
    n = x.shape[1]
    m = y.shape[1]
    # Row 0 and column 0 form the border; only the origin is reachable.
    cost_matrix = np.full((n + 1, m + 1), np.inf)
    phi = np.full((n + 1, m + 1), -1)

    cost_matrix[0, 0] = 0

    # x_avail[k] is True when time point k of x has no NaN in any row.
    x_avail = np.zeros(n, dtype=np.bool_)
    for k in range(n):
        x_avail[k] = not _check_nan(x[:, k])

    # Same availability mask for y.
    y_avail = np.zeros(m, dtype=np.bool_)
    for k in range(m):
        y_avail[k] = not _check_nan(y[:, k])

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # Vertical step penalty: free only when the current y point and
            # both the current and previous x points are available; otherwise
            # it stays infinite, i.e. the step is forbidden.
            ev = np.inf
            if i > 1:
                if y_avail[j - 1] and x_avail[i - 1] and x_avail[i - 2]:
                    ev = 0

            # Horizontal step penalty, mirrored: needs the current x point and
            # both the current and previous y points.
            eh = np.inf
            if j > 1:
                if x_avail[i - 1] and y_avail[j - 1] and y_avail[j - 2]:
                    eh = 0

            # The diagonal step carries no extra penalty.
            cost_diag = cost_matrix[i - 1, j - 1]
            cost_vert = cost_matrix[i - 1, j] + ev
            cost_horiz = cost_matrix[i, j - 1] + eh

            options = np.array([cost_diag, cost_vert, cost_horiz])
            best_prev = np.min(options)

            if np.isinf(best_prev):
                # No admissible predecessor: the cell is unreachable.
                cost_matrix[i, j] = np.inf
            else:
                current_cost = _cost_function(x[:, i - 1], y[:, j - 1])
                cost_matrix[i, j] = current_cost + best_prev
                # argmin returns the first minimum, so ties prefer the
                # diagonal, then the vertical step.
                # NOTE(review): the nesting of this assignment under `else`
                # was reconstructed from a listing without indentation -
                # confirm against the original file.
                phi[i, j] = np.argmin(options)

    return cost_matrix, phi


@njit(cache=True)
def _dtw_arow_cost_matrix(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return only the accumulated cost matrix from the cost/path helper."""
    accumulated_cost, _ = _dtw_arow_cost_path_helper(x, y)
    return accumulated_cost


def dtw_arow_pairwise_distance(
    X: Union[np.ndarray, list[np.ndarray]],
    y: Optional[Union[np.ndarray, list[np.ndarray]]] = None,
) -> np.ndarray:
    """Compute the DTW-AROW pairwise distance between collections of series.

    Parameters
    ----------
    X : np.ndarray or list of np.ndarray
        Collection of time series, in any layout accepted by
        ``_convert_collection_to_numba_list``.
    y : np.ndarray or list of np.ndarray or None, default=None
        Optional second collection. When None, ``X`` is compared with itself.

    Returns
    -------
    np.ndarray
        Matrix of shape ``(len(X), len(X))`` when ``y`` is None, otherwise
        ``(len(X), len(y))``, measured on the converted collections.
    """
    is_multivariate = _is_numpy_list_multivariate(X, y)
    X_numba, _ = _convert_collection_to_numba_list(X, "X", is_multivariate)

    if y is not None:
        y_numba, _ = _convert_collection_to_numba_list(y, "y", is_multivariate)
        return _dtw_arow_from_multiple_to_multiple_distance(X_numba, y_numba)

    # No second collection supplied: distances from X to itself.
    return _dtw_arow_pairwise_distance(X_numba)


@njit(cache=True)
def _dtw_arow_pairwise_distance(X: NumbaList[np.ndarray]) -> np.ndarray:
    """Compute the symmetric self-distance matrix of a collection.

    Only the upper triangle is evaluated; each value is mirrored into the
    lower triangle and the diagonal is left at zero.
    """
    n_cases = len(X)
    distances = np.zeros((n_cases, n_cases))
    for row in range(n_cases):
        anchor = X[row]
        for col in range(row + 1, n_cases):
            dist = _dtw_arow_distance(anchor, X[col])
            distances[row, col] = dist
            distances[col, row] = dist

    return distances


@njit(cache=True)
def _dtw_arow_from_multiple_to_multiple_distance(
    x: NumbaList[np.ndarray], y: NumbaList[np.ndarray]
) -> np.ndarray:
    """Compute the distance from every series in ``x`` to every series in ``y``.

    Returns an array of shape ``(len(x), len(y))`` whose entry ``[r, c]`` is
    the DTW-AROW distance between ``x[r]`` and ``y[c]``.
    """
    n_rows = len(x)
    n_cols = len(y)
    distances = np.zeros((n_rows, n_cols))

    for r in range(n_rows):
        series_x = x[r]
        for c in range(n_cols):
            distances[r, c] = _dtw_arow_distance(series_x, y[c])
    return distances


@njit(cache=True)
def dtw_arow_alignment_path(
    x: np.ndarray,
    y: np.ndarray,
) -> tuple[list[tuple[int, int]], float]:
    """Compute the DTW-AROW alignment path and distance between two series.

    Parameters
    ----------
    x : np.ndarray
        First time series, 1D ``(n_timepoints,)`` or 2D
        ``(n_channels, n_timepoints)``.
    y : np.ndarray
        Second time series, with the same number of dimensions as ``x``.

    Returns
    -------
    list of tuple of (int, int)
        Pairs of zero-based time indices ``(index_in_x, index_in_y)``, ordered
        from the start of the series to the end.
    float
        DTW-AROW distance between ``x`` and ``y``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` are not both 1D or both 2D.
    """
    if x.ndim == 2 and y.ndim == 2:
        return _dtw_arow_alignment_path(x, y)
    if x.ndim == 1 and y.ndim == 1:
        # Promote univariate series to a single-channel 2D layout.
        x_2d = x.reshape((1, x.shape[0]))
        y_2d = y.reshape((1, y.shape[0]))
        return _dtw_arow_alignment_path(x_2d, y_2d)
    raise ValueError("x and y must be 1D or 2D")


@njit(cache=True)
def _dtw_arow_alignment_path(
    x: np.ndarray, y: np.ndarray
) -> tuple[list[tuple[int, int]], float]:
    """Trace the DTW-AROW alignment path and compute the distance.

    The cost/path helper is evaluated once; both the path and the distance
    are derived from that single result instead of running the dynamic
    programme a second time for the distance.

    Parameters
    ----------
    x : np.ndarray
        First 2D series; time points are indexed along axis 1.
    y : np.ndarray
        Second 2D series; time points are indexed along axis 1.

    Returns
    -------
    list of tuple of (int, int)
        Zero-based index pairs ``(index_in_x, index_in_y)`` ordered from the
        start of the series to the end. The trace stops early at a cell with
        no recorded predecessor, so the path may not reach ``(0, 0)``.
    float
        ``sqrt(gamma * cost_matrix[n, m])`` with
        ``gamma = (n + m) / (available_x + available_y)``, the same value
        ``_dtw_arow_distance`` computes.
    """
    n = x.shape[1]
    m = y.shape[1]
    cost_matrix, phi = _dtw_arow_cost_path_helper(x, y)

    i, j = n, m
    path = []
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        step = phi[i, j]
        if step == 0:
            # Diagonal predecessor.
            i, j = i - 1, j - 1
        elif step == 1:
            # Vertical predecessor: move back along x only.
            i = i - 1
        elif step == 2:
            # Horizontal predecessor: move back along y only.
            j = j - 1
        else:
            # No predecessor recorded for this cell.
            break
    path.reverse()

    # Rescale by the share of available time points, as the distance does.
    gamma = (n + m) / (_total_non_nan(x) + _total_non_nan(y))
    return path, np.sqrt(gamma * cost_matrix[n, m])
2 changes: 1 addition & 1 deletion aeon/distances/elastic/tests/test_cost_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _validate_cost_matrix_result(
assert_almost_equal(cost_matrix_callable_result, cost_matrix_result)
if name == "ddtw" or name == "wddtw":
assert cost_matrix_result.shape == (x.shape[-1] - 2, y.shape[-1] - 2)
elif name == "lcss":
elif name == "lcss" or name == "dtw_arow":
# lcss and dtw_arow cost matrices are one larger than the input
assert cost_matrix_result.shape == (x.shape[-1] + 1, y.shape[-1] + 1)
else:
Expand Down
5 changes: 5 additions & 0 deletions aeon/distances/elastic/tests/test_distance_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aeon.datasets import load_basic_motions, load_unit_test
from aeon.distances import (
ddtw_distance,
dtw_arow_distance,
dtw_distance,
dtw_gi_distance,
edr_distance,
Expand All @@ -34,6 +35,7 @@
"ddtw",
"wddtw",
"twe",
"dtw_arow",
]

distance_parameters = {
Expand Down Expand Up @@ -67,6 +69,7 @@
"squared": 757.25971908652,
"dtw": [757.259719, 330.834497, 330.834497],
"dtw_gi": [259.5333502342899, 310.10738471013804, 310.10738471013804],
"dtw_arow": 18.188856402051858,
"wdtw": [165.41724, 3.308425, 0],
"msm": [70.014828, 89.814828, 268.014828],
"erp": [169.3715, 102.0979, 102.097904],
Expand All @@ -91,6 +94,8 @@ def test_multivariate_correctness():
assert_almost_equal(d, basic_motions_distances["euclidean"], 4)
d = squared_distance(case1, case2)
assert_almost_equal(d, basic_motions_distances["squared"], 4)
d = dtw_arow_distance(case1, case2)
assert_almost_equal(d, basic_motions_distances["dtw_arow"], 4)
for j in range(0, 3):
d = dtw_distance(case1, case2, window=distance_parameters["dtw"][j])
assert_almost_equal(d, basic_motions_distances["dtw"][j], 4)
Expand Down
7 changes: 7 additions & 0 deletions aeon/testing/expected_results/expected_distance_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@
5.893210968537887,
25.0,
],
"dtw_arow": [
0.5869589315413677,
0.5869589315413677,
0.5475954351379372,
2.0247879362835803,
5.0,
],
"ddtw": [
0.2963709096971962,
0.2963709096971962,
Expand Down