Skip to content

Commit 3ff96b1

Browse files
committed
custom pairwise threaded
1 parent d35954d commit 3ff96b1

File tree

2 files changed

+74
-18
lines changed

2 files changed

+74
-18
lines changed

aeon/distances/_distance.py

+48-18
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
__maintainer__ = []
22

3-
import warnings
43
from enum import Enum
54
from typing import Any, Callable, Optional, TypedDict, Union
65

76
import numpy as np
7+
from joblib import Parallel, delayed
8+
from numba import set_num_threads
89
from typing_extensions import Unpack
910

1011
from aeon.distances._mpdist import mp_distance, mp_pairwise_distance
@@ -84,6 +85,7 @@
8485
squared_pairwise_distance,
8586
)
8687
from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list
88+
from aeon.utils.validation import check_n_jobs
8789
from aeon.utils.validation.collection import _is_numpy_list_multivariate
8890

8991

@@ -264,40 +266,54 @@ def _custom_func_pairwise(
264266
n_jobs: int = 1,
265267
**kwargs: Unpack[DistanceKwargs],
266268
) -> np.ndarray:
269+
n_jobs = check_n_jobs(n_jobs)
270+
set_num_threads(n_jobs)
267271
if dist_func is None:
268272
raise ValueError("dist_func must be a callable")
269273

270-
if n_jobs != 1:
271-
warnings.warn(
272-
"You are using a custom distance function with n_jobs > 1. "
273-
"Aeon does not support parallelization for custom distance "
274-
"functions. If it is an existing aeon distance try using the "
275-
"string name instead.",
276-
UserWarning,
277-
stacklevel=2,
278-
)
279-
280274
multivariate_conversion = _is_numpy_list_multivariate(X, y)
281275
X, _ = _convert_collection_to_numba_list(X, "X", multivariate_conversion)
276+
277+
if n_jobs > 1:
278+
X = np.array(X)
279+
282280
if y is None:
283281
# To self
284-
return _custom_pairwise_distance(X, dist_func, **kwargs)
282+
return _custom_pairwise_distance(X, dist_func, n_jobs=n_jobs, **kwargs)
285283
y, _ = _convert_collection_to_numba_list(y, "y", multivariate_conversion)
286-
return _custom_from_multiple_to_multiple_distance(X, y, dist_func, **kwargs)
284+
if n_jobs > 1:
285+
y = np.array(y)
286+
return _custom_from_multiple_to_multiple_distance(
287+
X, y, dist_func, n_jobs=n_jobs, **kwargs
288+
)
287289

288290

289291
def _custom_pairwise_distance(
290292
X: Union[np.ndarray, list[np.ndarray]],
291293
dist_func: DistanceFunction,
294+
n_jobs: int = 1,
292295
**kwargs: Unpack[DistanceKwargs],
293296
) -> np.ndarray:
294297
n_cases = len(X)
295298
distances = np.zeros((n_cases, n_cases))
296299

297-
for i in range(n_cases):
298-
for j in range(i + 1, n_cases):
300+
def compute_single_distance(i, j):
301+
return i, j, dist_func(X[i], X[j], **kwargs)
302+
303+
indices = [(i, j) for i in range(n_cases) for j in range(i + 1, n_cases)]
304+
305+
if n_jobs == 1:
306+
for i, j in indices:
299307
distances[i, j] = dist_func(X[i], X[j], **kwargs)
300-
distances[j, i] = distances[i, j]
308+
distances[j, i] = distances[i, j] # Mirror for symmetry
309+
else:
310+
results = Parallel(n_jobs=n_jobs)(
311+
delayed(compute_single_distance)(i, j) for i, j in indices
312+
)
313+
314+
for i, j, dist in results:
315+
distances[i, j] = dist
316+
distances[j, i] = dist # Mirror for symmetry
301317

302318
return distances
303319

@@ -306,15 +322,29 @@ def _custom_from_multiple_to_multiple_distance(
306322
x: Union[np.ndarray, list[np.ndarray]],
307323
y: Union[np.ndarray, list[np.ndarray]],
308324
dist_func: DistanceFunction,
325+
n_jobs: int = 1,
309326
**kwargs: Unpack[DistanceKwargs],
310327
) -> np.ndarray:
311328
n_cases = len(x)
312329
m_cases = len(y)
313330
distances = np.zeros((n_cases, m_cases))
314331

315-
for i in range(n_cases):
316-
for j in range(m_cases):
332+
def compute_single_distance(i, j):
333+
return i, j, dist_func(x[i], y[j], **kwargs)
334+
335+
indices = [(i, j) for i in range(n_cases) for j in range(m_cases)]
336+
337+
if n_jobs == 1:
338+
for i, j in indices:
317339
distances[i, j] = dist_func(x[i], y[j], **kwargs)
340+
else:
341+
results = Parallel(n_jobs=n_jobs)(
342+
delayed(compute_single_distance)(i, j) for i, j in indices
343+
)
344+
345+
for i, j, dist in results:
346+
distances[i, j] = dist
347+
318348
return distances
319349

320350

aeon/utils/numba/general.py

+26
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,29 @@ def get_all_subsequences(X: np.ndarray, length: int, dilation: int) -> np.ndarra
772772
out_shape = (n_timestamps - (length - 1) * dilation, n_features, np.int64(length))
773773
strides = (s1, s0, s1 * dilation)
774774
return np.lib.stride_tricks.as_strided(X, shape=out_shape, strides=strides)
775+
776+
777+
def is_numba_function(func) -> bool:
778+
"""Determine if a function is compiled with Numba.
779+
780+
Parameters
781+
----------
782+
func : callable
783+
The function to check.
784+
785+
Returns
786+
-------
787+
bool
788+
True if the function is compiled with Numba.
789+
"""
790+
if hasattr(func, "nopython_signatures") or hasattr(func, "__numba__"):
791+
return True
792+
793+
if hasattr(func, "_numba_type_"):
794+
return True
795+
796+
module_name = getattr(func, "__module__", "")
797+
if module_name and module_name.startswith("numba."):
798+
return True
799+
800+
return False

0 commit comments

Comments
 (0)