1
1
__maintainer__ = []
2
2
3
- import warnings
4
3
from enum import Enum
5
4
from typing import Any , Callable , Optional , TypedDict , Union
6
5
7
6
import numpy as np
7
+ from joblib import Parallel , delayed
8
+ from numba import set_num_threads
8
9
from typing_extensions import Unpack
9
10
10
11
from aeon .distances ._mpdist import mp_distance , mp_pairwise_distance
84
85
squared_pairwise_distance ,
85
86
)
86
87
from aeon .utils .conversion ._convert_collection import _convert_collection_to_numba_list
88
+ from aeon .utils .validation import check_n_jobs
87
89
from aeon .utils .validation .collection import _is_numpy_list_multivariate
88
90
89
91
@@ -264,40 +266,54 @@ def _custom_func_pairwise(
264
266
n_jobs : int = 1 ,
265
267
** kwargs : Unpack [DistanceKwargs ],
266
268
) -> np .ndarray :
269
+ n_jobs = check_n_jobs (n_jobs )
270
+ set_num_threads (n_jobs )
267
271
if dist_func is None :
268
272
raise ValueError ("dist_func must be a callable" )
269
273
270
- if n_jobs != 1 :
271
- warnings .warn (
272
- "You are using a custom distance function with n_jobs > 1. "
273
- "Aeon does not support parallelization for custom distance "
274
- "functions. If it is an existing aeon distance try using the "
275
- "string name instead." ,
276
- UserWarning ,
277
- stacklevel = 2 ,
278
- )
279
-
280
274
multivariate_conversion = _is_numpy_list_multivariate (X , y )
281
275
X , _ = _convert_collection_to_numba_list (X , "X" , multivariate_conversion )
276
+
277
+ if n_jobs > 1 :
278
+ X = np .array (X )
279
+
282
280
if y is None :
283
281
# To self
284
- return _custom_pairwise_distance (X , dist_func , ** kwargs )
282
+ return _custom_pairwise_distance (X , dist_func , n_jobs = n_jobs , ** kwargs )
285
283
y , _ = _convert_collection_to_numba_list (y , "y" , multivariate_conversion )
286
- return _custom_from_multiple_to_multiple_distance (X , y , dist_func , ** kwargs )
284
+ if n_jobs > 1 :
285
+ y = np .array (y )
286
+ return _custom_from_multiple_to_multiple_distance (
287
+ X , y , dist_func , n_jobs = n_jobs , ** kwargs
288
+ )
287
289
288
290
289
291
def _custom_pairwise_distance(
    X: Union[np.ndarray, list[np.ndarray]],
    dist_func: DistanceFunction,
    n_jobs: int = 1,
    **kwargs: Unpack[DistanceKwargs],
) -> np.ndarray:
    """Compute the pairwise distance matrix of a collection with itself.

    Only the strict upper triangle is evaluated. Every value is mirrored into
    the lower triangle and the diagonal is left at zero, so ``dist_func`` is
    treated as symmetric.

    Parameters
    ----------
    X : np.ndarray or list of np.ndarray
        Collection of cases. It only needs to support ``len(X)`` and integer
        indexing ``X[i]``.
    dist_func : DistanceFunction
        Callable invoked as ``dist_func(X[i], X[j], **kwargs)`` for ``i < j``.
    n_jobs : int, default=1
        ``1`` runs a plain serial loop. Any other value is forwarded unchanged
        to ``joblib.Parallel``.
    **kwargs
        Extra keyword arguments forwarded to every ``dist_func`` call.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_cases, n_cases)`` where entries ``[i, j]`` and
        ``[j, i]`` both hold ``dist_func(X[i], X[j], **kwargs)``.
    """
    n_cases = len(X)
    distances = np.zeros((n_cases, n_cases))

    def _row_distances(i):
        # Distances from case i to every later case (strict upper triangle).
        return [dist_func(X[i], X[j], **kwargs) for j in range(i + 1, n_cases)]

    # The last case has no later partner, so its row is never requested.
    row_ids = range(n_cases - 1)

    if n_jobs == 1:
        # Lazy: rows are computed one at a time as the fill loop consumes them.
        rows = map(_row_distances, row_ids)
    else:
        # One task per row rather than one per pair keeps the number of joblib
        # tasks (and closure serialisations) linear in n_cases.
        rows = Parallel(n_jobs=n_jobs)(delayed(_row_distances)(i) for i in row_ids)

    for i, row in enumerate(rows):
        for j, dist in enumerate(row, start=i + 1):
            distances[i, j] = dist
            distances[j, i] = dist  # Mirror for symmetry

    return distances
303
319
def _custom_from_multiple_to_multiple_distance(
    x: Union[np.ndarray, list[np.ndarray]],
    y: Union[np.ndarray, list[np.ndarray]],
    dist_func: DistanceFunction,
    n_jobs: int = 1,
    **kwargs: Unpack[DistanceKwargs],
) -> np.ndarray:
    """Compute the distance from every case in ``x`` to every case in ``y``.

    Parameters
    ----------
    x : np.ndarray or list of np.ndarray
        First collection of cases. It only needs to support ``len(x)`` and
        integer indexing ``x[i]``.
    y : np.ndarray or list of np.ndarray
        Second collection of cases, with the same requirements as ``x``.
    dist_func : DistanceFunction
        Callable invoked as ``dist_func(x[i], y[j], **kwargs)`` for every pair.
    n_jobs : int, default=1
        ``1`` runs a plain serial loop. Any other value is forwarded unchanged
        to ``joblib.Parallel``.
    **kwargs
        Extra keyword arguments forwarded to every ``dist_func`` call.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_cases, m_cases)`` where entry ``[i, j]`` holds
        ``dist_func(x[i], y[j], **kwargs)``.
    """
    n_cases = len(x)
    m_cases = len(y)
    distances = np.zeros((n_cases, m_cases))

    def _row_distances(i):
        # Distances from case i of x to every case of y.
        return [dist_func(x[i], y[j], **kwargs) for j in range(m_cases)]

    row_ids = range(n_cases)

    if n_jobs == 1:
        # Lazy: rows are computed one at a time as the fill loop consumes them.
        rows = map(_row_distances, row_ids)
    else:
        # One task per row rather than one per pair keeps the number of joblib
        # tasks (and closure serialisations) linear in n_cases.
        rows = Parallel(n_jobs=n_jobs)(delayed(_row_distances)(i) for i in row_ids)

    for i, row in enumerate(rows):
        for j, dist in enumerate(row):
            distances[i, j] = dist

    return distances
319
349
320
350
0 commit comments