Improve gak #514

Open
wants to merge 6 commits into
base: main
13 changes: 11 additions & 2 deletions docs/user_guide/kernel.rst
@@ -52,7 +52,7 @@ Global Alignment Kernel
The Global Alignment Kernel (GAK) is a kernel that operates on time
series.

-It is defined, for a given bandwidth :math:`\sigma`, as:
+The unnormalized GAK is defined, for a given bandwidth :math:`\sigma`, as:

.. math::

@@ -64,6 +64,15 @@ It is defined, for a given bandwidth :math:`\sigma`, as:
where :math:`\mathcal{A}(\mathbf{x}, \mathbf{y})` is the set of all possible
alignments between series :math:`\mathbf{x}` and :math:`\mathbf{y}`.

+Note that the function ``gak`` is normalized in ``tslearn``: it corresponds to the quotient
+
+.. math::
+
+    \text{gak}(\mathbf{x}, \mathbf{y}) = \frac{k(\mathbf{x}, \mathbf{y})}{\sqrt{k(\mathbf{x}, \mathbf{x})k(\mathbf{y}, \mathbf{y})}}
+
+This normalization ensures that :math:`\text{gak}(\mathbf{x}, \mathbf{x})=1` for all :math:`\mathbf{x}`
+and :math:`\text{gak}(\mathbf{x}, \mathbf{y}) \in [0, 1]` for all :math:`\mathbf{x}, \mathbf{y}`.
+
It is advised in [1]_ to set the bandwidth :math:`\sigma` as a multiple of a
simple estimate of the median distance of different points observed in
different time-series of your training set, scaled by the square root of the
@@ -81,7 +90,7 @@ This estimate is made available in ``tslearn`` through
Note however that, on long time series, this estimate can lead to numerical
overflows, which smaller values can avoid.

-Finally, GAK is related to :ref:`softDTW <dtw-softdtw>` [3]_ through the
+Finally, the unnormalized GAK is related to :ref:`softDTW <dtw-softdtw>` [3]_ through the
following formula:

.. math::
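As a rough illustration of the normalization documented above, the following NumPy sketch computes an unnormalized GAK by dynamic programming and then divides by the square roots of the two self-similarities. It assumes a plain Gaussian local kernel, as in the documented formula, and the helper names `unnormalized_gak_sketch` and `gak_sketch` are hypothetical; tslearn's own implementation may differ in its local kernel and backend handling.

```python
import numpy as np


def unnormalized_gak_sketch(x, y, sigma=1.0):
    """Sum over all alignments of products of Gaussian local kernels.

    Hypothetical helper for illustration only, not tslearn's API.
    """
    # Treat each series as an array of shape (length, n_features).
    x = np.asarray(x, dtype=np.float64).reshape(len(x), -1)
    y = np.asarray(y, dtype=np.float64).reshape(len(y), -1)

    # Pairwise squared Euclidean distances between time steps.
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    local = np.exp(-sq_dists / (2.0 * sigma ** 2))

    n, m = local.shape
    cum = np.zeros((n + 1, m + 1), dtype=local.dtype)
    cum[0, 0] = 1.0
    for i in range(n):
        for j in range(m):
            # Each cell accumulates the three admissible predecessor moves.
            cum[i + 1, j + 1] = (
                cum[i, j + 1] + cum[i + 1, j] + cum[i, j]
            ) * local[i, j]
    return cum[n, m]


def gak_sketch(x, y, sigma=1.0):
    """Normalized kernel: k(x, y) / sqrt(k(x, x) * k(y, y))."""
    k_xy = unnormalized_gak_sketch(x, y, sigma)
    k_xx = unnormalized_gak_sketch(x, x, sigma)
    k_yy = unnormalized_gak_sketch(y, y, sigma)
    return k_xy / (np.sqrt(k_xx) * np.sqrt(k_yy))


self_sim = gak_sketch([1.0, 2.0, 2.0, 3.0], [1.0, 2.0, 2.0, 3.0], sigma=2.0)
single_step = gak_sketch([0.0], [1.0], sigma=1.0)
```

With these inputs, `self_sim` should equal 1 up to rounding, and `single_step` reduces to the local kernel value `exp(-0.5)` because both self-similarities are 1 for series of length one.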
8 changes: 3 additions & 5 deletions tslearn/metrics/softdtw_variants.py
@@ -54,7 +54,7 @@ def _gak(gram, be=None):
gram = be.array(gram)
sz1, sz2 = be.shape(gram)

-cum_sum = be.zeros((sz1 + 1, sz2 + 1))
+cum_sum = be.zeros((sz1 + 1, sz2 + 1), dtype=gram.dtype)
cum_sum[0, 0] = 1.0

for i in range(sz1):
@@ -240,10 +240,8 @@ def gak(s1, s2, sigma=1.0, be=None): # TODO: better doc (formula for the kernel
be = instantiate_backend(be, s1, s2)
s1 = be.array(s1)
s2 = be.array(s2)
-denom = be.sqrt(
-unnormalized_gak(s1, s1, sigma=sigma, be=be)
-* unnormalized_gak(s2, s2, sigma=sigma, be=be)
-)
+denom = be.sqrt(unnormalized_gak(s1, s1, sigma=sigma, be=be)) * be.sqrt(
+unnormalized_gak(s2, s2, sigma=sigma, be=be))
return unnormalized_gak(s1, s2, sigma=sigma, be=be) / denom
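
The rewritten denominator appears to be aimed at overflow: for long series each self-similarity can be close to the largest representable float, so their product can overflow to infinity even when the square root of that product would be representable. A minimal NumPy sketch of the effect follows; the magnitudes are made up and the helper names `normalize_product_first` and `normalize_sqrt_first` are hypothetical.

```python
import numpy as np


def normalize_product_first(k_xy, k_xx, k_yy):
    # Multiplies before taking the square root: the product may overflow.
    with np.errstate(over="ignore"):
        return k_xy / np.sqrt(k_xx * k_yy)


def normalize_sqrt_first(k_xy, k_xx, k_yy):
    # Takes each square root first, so intermediate values stay in range.
    return k_xy / (np.sqrt(k_xx) * np.sqrt(k_yy))


# Illustrative magnitude only: large, but well inside the float64 range.
k = np.float64(1e200)

product_first = normalize_product_first(k, k, k)  # k * k overflows to inf
sqrt_first = normalize_sqrt_first(k, k, k)
```

Here `product_first` collapses to 0 because the denominator becomes infinite, while `sqrt_first` stays at 1 up to rounding, which is the value expected when a series is compared with itself.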


5 changes: 3 additions & 2 deletions tslearn/metrics/utils.py
@@ -79,7 +79,7 @@ def _cdist_generic(
if dataset2 is None:
# Inspired from code by @GillesVandewiele:
# https://github.com/rtavenar/tslearn/pull/128#discussion_r314978479
-matrix = be.zeros((len(dataset1), len(dataset1)))
+matrix = be.zeros((len(dataset1), len(dataset1)), dtype=dataset1.dtype)
indices = be.triu_indices(
len(dataset1), k=0 if compute_diagonal else 1, m=len(dataset1)
)
@@ -89,7 +89,8 @@ def _cdist_generic(
delayed(dist_fun)(dataset1[i], dataset1[j], *args, **kwargs)
for i in range(len(dataset1))
for j in range(i if compute_diagonal else i + 1, len(dataset1))
-)
+),
+dtype=matrix.dtype
)

indices = be.tril_indices(len(dataset1), k=-1, m=len(dataset1))
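
The `dtype` arguments added in this file keep the result matrix in the same precision as the input dataset instead of the backend's default. The sketch below shows the mismatch with plain NumPy, whose `zeros` defaults to `float64`; other backends may default to a different precision, and the variable names are hypothetical.

```python
import numpy as np

# Hypothetical dataset stored in single precision.
dataset = np.ones((3, 4), dtype=np.float32)

# Without an explicit dtype, the buffer uses NumPy's default (float64),
# regardless of the precision of the data it will hold.
default_buffer = np.zeros((len(dataset), len(dataset)))

# Passing the dataset's dtype keeps the buffer consistent with the input.
matched_buffer = np.zeros((len(dataset), len(dataset)), dtype=dataset.dtype)
```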
5 changes: 5 additions & 0 deletions tslearn/tests/test_metrics.py
@@ -446,6 +446,11 @@ def test_gak():
for array_type in array_types:
backend = instantiate_backend(be, array_type)
# GAK
+gak_zeros = tslearn.metrics.gak(
+s1=backend.zeros(405, dtype=backend.float64),
+s2=backend.zeros(405, dtype=backend.float64),
+sigma=1.0)
+np.testing.assert_allclose(gak_zeros, desired=1, atol=1e-8)
g = tslearn.metrics.cdist_gak(
cast([[1, 2, 2, 3], [1.0, 2.0, 3.0, 4.0]], array_type), sigma=2.0, be=be
)