Skip to content

[ENH] Distance module n_jobs support #2545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
25 changes: 22 additions & 3 deletions aeon/distances/_distance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__maintainer__ = []

import warnings
from enum import Enum
from typing import Any, Callable, Optional, TypedDict, Union

Expand Down Expand Up @@ -173,6 +174,7 @@ def pairwise_distance(
y: Optional[np.ndarray] = None,
method: Union[str, DistanceFunction, None] = None,
symmetric: bool = True,
n_jobs: int = 1,
**kwargs: Unpack[DistanceKwargs],
) -> np.ndarray:
"""Compute the pairwise distance matrix between two time series.
Expand All @@ -197,6 +199,10 @@ def pairwise_distance(
function is provided as the "method" parameter, then it will compute an
asymmetric distance matrix, and the entire matrix (including both upper and
lower triangles) is returned.
n_jobs : int, default=1
The number of jobs to run in parallel. If -1, then the number of jobs is set
to the number of CPU cores. If 1, then the function is executed in a single
thread. If greater than 1, then the function is executed in parallel.
kwargs : Any
Extra arguments for distance. Refer to each distance documentation for a list of
possible arguments.
Expand Down Expand Up @@ -240,11 +246,13 @@ def pairwise_distance(
[ 48.]])
"""
if method in PAIRWISE_DISTANCE:
return DISTANCES_DICT[method]["pairwise_distance"](x, y, **kwargs)
return DISTANCES_DICT[method]["pairwise_distance"](
x, y, n_jobs=n_jobs, **kwargs
)
elif isinstance(method, Callable):
if y is None and not symmetric:
return _custom_func_pairwise(x, x, method, **kwargs)
return _custom_func_pairwise(x, y, method, **kwargs)
return _custom_func_pairwise(x, x, method, n_jobs=n_jobs, **kwargs)
return _custom_func_pairwise(x, y, method, n_jobs=n_jobs, **kwargs)
else:
raise ValueError("Method must be one of the supported strings or a callable")

Expand All @@ -253,11 +261,22 @@ def _custom_func_pairwise(
X: Optional[Union[np.ndarray, list[np.ndarray]]],
y: Optional[Union[np.ndarray, list[np.ndarray]]] = None,
dist_func: Union[DistanceFunction, None] = None,
n_jobs: int = 1,
**kwargs: Unpack[DistanceKwargs],
) -> np.ndarray:
if dist_func is None:
raise ValueError("dist_func must be a callable")

if n_jobs != 1:
warnings.warn(
"You are using a custom distance function with n_jobs > 1. "
"Aeon does not support parallelization for custom distance "
"functions. If it is an existing aeon distance try using the "
"string name instead.",
UserWarning,
stacklevel=2,
)

multivariate_conversion = _is_numpy_list_multivariate(X, y)
X, _ = _convert_collection_to_numba_list(X, "X", multivariate_conversion)
if y is None:
Expand Down
14 changes: 11 additions & 3 deletions aeon/distances/_mpdist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Matrix Profile Distances."""

import warnings
from typing import Optional, Union

import numpy as np
Expand Down Expand Up @@ -287,6 +288,7 @@ def mp_pairwise_distance(
X: Union[np.ndarray, list[np.ndarray]],
y: Optional[Union[np.ndarray, list[np.ndarray]]] = None,
m: int = 0,
**kwargs,
) -> np.ndarray:
"""Compute the mpdist pairwise distance between a set of time series.

Expand Down Expand Up @@ -339,14 +341,20 @@ def mp_pairwise_distance(
[2.82842712],
[2.82842712]])
"""
if "n_jobs" in kwargs:
warnings.warn(
"n_jobs is not supported for the mpdist distance method and will be "
"ignored.",
UserWarning,
stacklevel=2,
)
if m == 0:
m = int(X.shape[2] / 4)
multivariate_conversion = _is_numpy_list_multivariate(X, y)
_X, unequal_length = _convert_collection_to_numba_list(
X, "X", multivariate_conversion
)

if m == 0:
m = int(_X.shape[2] / 4)

if y is None:
return _mpdist_pairwise_distance_single(_X, m)

Expand Down
29 changes: 24 additions & 5 deletions aeon/distances/_sbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

__maintainer__ = ["SebastianSchmidl"]

import warnings
from typing import Optional, Union

import numpy as np
from numba import njit, objmode
from numba import njit, objmode, prange, set_num_threads
from numba.typed import List as NumbaList
from scipy.signal import correlate

from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list
from aeon.utils.validation import check_n_jobs
from aeon.utils.validation.collection import _is_numpy_list_multivariate


Expand Down Expand Up @@ -117,6 +119,8 @@ def sbd_pairwise_distance(
X: Union[np.ndarray, list[np.ndarray]],
y: Optional[Union[np.ndarray, list[np.ndarray]]] = None,
standardize: bool = True,
n_jobs: int = 1,
**kwargs,
) -> np.ndarray:
"""
Compute the shape-based distance (SBD) between all pairs of time series.
Expand All @@ -138,6 +142,10 @@ def sbd_pairwise_distance(
standardize : bool, default=True
Apply z-score to both input time series for standardization before
computing the distance. This makes SBD scaling invariant. Default is True.
n_jobs : int, default=1
The number of jobs to run in parallel. If -1, then the number of jobs is set
to the number of CPU cores. If 1, then the function is executed in a single
thread. If greater than 1, then the function is executed in parallel.

Returns
-------
Expand Down Expand Up @@ -188,6 +196,17 @@ def sbd_pairwise_distance(
[0.36754447, 0. , 0.29289322],
[0.5527864 , 0.29289322, 0. ]])
"""
n_jobs = check_n_jobs(n_jobs)
set_num_threads(n_jobs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this should set it back to the original value after?

Maybe its possible to use a decorator in these cases similar to https://github.com/aeon-toolkit/aeon/blob/v1.0.0/aeon/testing/utils/output_suppression.py

if n_jobs > 1:
warnings.warn(
"You have set n_jobs > 1. For this distance function "
"unless your data has a large number of time points, it is "
"recommended to use n_jobs=1. If this function is slower than "
"expected try setting n_jobs=1.",
UserWarning,
stacklevel=2,
)
multivariate_conversion = _is_numpy_list_multivariate(X, y)
_X, _ = _convert_collection_to_numba_list(X, "", multivariate_conversion)

Expand All @@ -199,30 +218,30 @@ def sbd_pairwise_distance(
return _sbd_pairwise_distance(_X, _y, standardize)


@njit(cache=True, fastmath=True)
@njit(cache=True, fastmath=True, parallel=True)
def _sbd_pairwise_distance_single(
x: NumbaList[np.ndarray], standardize: bool
) -> np.ndarray:
n_cases = len(x)
distances = np.zeros((n_cases, n_cases))

for i in range(n_cases):
for i in prange(n_cases):
for j in range(i + 1, n_cases):
distances[i, j] = sbd_distance(x[i], x[j], standardize)
distances[j, i] = distances[i, j]

return distances


@njit(cache=True, fastmath=True)
@njit(cache=True, fastmath=True, parallel=True)
def _sbd_pairwise_distance(
x: NumbaList[np.ndarray], y: NumbaList[np.ndarray], standardize: bool
) -> np.ndarray:
n_cases = len(x)
m_cases = len(y)
distances = np.zeros((n_cases, m_cases))

for i in range(n_cases):
for i in prange(n_cases):
for j in range(m_cases):
distances[i, j] = sbd_distance(x[i], y[j], standardize)
return distances
Expand Down
15 changes: 12 additions & 3 deletions aeon/distances/_shift_scale_invariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from typing import Optional, Union

import numpy as np
from numba import njit
from numba import njit, prange, set_num_threads
from numba.typed import List as NumbaList

from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list
from aeon.utils.validation import check_n_jobs
from aeon.utils.validation.collection import _is_numpy_list_multivariate


Expand Down Expand Up @@ -160,6 +161,8 @@ def shift_scale_invariant_pairwise_distance(
X: Union[np.ndarray, list[np.ndarray]],
y: Optional[Union[np.ndarray, list[np.ndarray]]] = None,
max_shift: Optional[int] = None,
n_jobs: int = 1,
**kwargs,
) -> np.ndarray:
r"""Compute the shift-scale invariant pairwise distance between time series.

Expand Down Expand Up @@ -193,6 +196,10 @@ def shift_scale_invariant_pairwise_distance(
Maximum shift allowed in the alignment path. If None, then max_shift is set
to min(X.shape[-1], y.shape[-1]) or if y is None, max_shift is set to
X.shape[-1].
n_jobs : int, default=1
The number of jobs to run in parallel. If -1, then the number of jobs is set
to the number of CPU cores. If 1, then the function is executed in a single
thread. If greater than 1, then the function is executed in parallel.

Returns
-------
Expand Down Expand Up @@ -223,6 +230,8 @@ def shift_scale_invariant_pairwise_distance(
>>> y_univariate = np.array([11., 12., 13.])
>>> single_pw =shift_scale_invariant_pairwise_distance(X, y_univariate)
"""
n_jobs = check_n_jobs(n_jobs)
set_num_threads(n_jobs)
if max_shift is None:
if y is None:
max_shift = X.shape[-1]
Expand Down Expand Up @@ -308,15 +317,15 @@ def shift_scale_invariant_best_shift(
raise ValueError("x and y must be 1D or 2D")


@njit(cache=True, fastmath=True)
@njit(cache=True, fastmath=True, parallel=True)
def _shift_invariant_pairwise_distance(
x: NumbaList[np.ndarray], y: NumbaList[np.ndarray], max_shift: int
) -> np.ndarray:
n_cases = len(x)
m_cases = len(y)
distances = np.zeros((n_cases, m_cases))

for i in range(n_cases):
for i in prange(n_cases):
for j in range(m_cases):
distances[i, j] = shift_scale_invariant_distance(x[i], y[j], max_shift)
return distances
19 changes: 14 additions & 5 deletions aeon/distances/elastic/_adtw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from typing import Optional, Union

import numpy as np
from numba import njit
from numba import njit, prange, set_num_threads
from numba.typed import List as NumbaList

from aeon.distances.elastic._alignment_paths import compute_min_return_path
from aeon.distances.elastic._bounding_matrix import create_bounding_matrix
from aeon.distances.pointwise._squared import _univariate_squared_distance
from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list
from aeon.utils.validation import check_n_jobs
from aeon.utils.validation.collection import _is_numpy_list_multivariate


Expand Down Expand Up @@ -203,6 +204,8 @@ def adtw_pairwise_distance(
window: Optional[float] = None,
itakura_max_slope: Optional[float] = None,
warp_penalty: float = 1.0,
n_jobs: int = 1,
**kwargs,
) -> np.ndarray:
r"""Compute the ADTW pairwise distance between a set of time series.

Expand All @@ -226,6 +229,10 @@ def adtw_pairwise_distance(
Penalty for warping. A high value will mean less warping.
warp less and if value is low then will encourage algorithm to warp
more.
n_jobs : int, default=1
The number of jobs to run in parallel. If -1, then the number of jobs is set
to the number of CPU cores. If 1, then the function is executed in a single
thread. If greater than 1, then the function is executed in parallel.

Returns
-------
Expand Down Expand Up @@ -272,6 +279,8 @@ def adtw_pairwise_distance(
[ 44., 0., 87.],
[294., 87., 0.]])
"""
n_jobs = check_n_jobs(n_jobs)
set_num_threads(n_jobs)
multivariate_conversion = _is_numpy_list_multivariate(X, y)
_X, unequal_length = _convert_collection_to_numba_list(
X, "X", multivariate_conversion
Expand All @@ -290,7 +299,7 @@ def adtw_pairwise_distance(
)


@njit(cache=True, fastmath=True)
@njit(cache=True, fastmath=True, parallel=True)
def _adtw_pairwise_distance(
X: NumbaList[np.ndarray],
window: Optional[float],
Expand All @@ -306,7 +315,7 @@ def _adtw_pairwise_distance(
bounding_matrix = create_bounding_matrix(
n_timepoints, n_timepoints, window, itakura_max_slope
)
for i in range(n_cases):
for i in prange(n_cases):
for j in range(i + 1, n_cases):
x1, x2 = X[i], X[j]
if unequal_length:
Expand All @@ -319,7 +328,7 @@ def _adtw_pairwise_distance(
return distances


@njit(cache=True, fastmath=True)
@njit(cache=True, fastmath=True, parallel=True)
def _adtw_from_multiple_to_multiple_distance(
x: NumbaList[np.ndarray],
y: NumbaList[np.ndarray],
Expand All @@ -336,7 +345,7 @@ def _adtw_from_multiple_to_multiple_distance(
bounding_matrix = create_bounding_matrix(
x[0].shape[1], y[0].shape[1], window, itakura_max_slope
)
for i in range(n_cases):
for i in prange(n_cases):
for j in range(m_cases):
x1, y1 = x[i], y[j]
if unequal_length:
Expand Down
Loading