-
Notifications
You must be signed in to change notification settings - Fork 208
[ENH] Distance module n_jobs support #2545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
chrisholder
wants to merge
14
commits into
main
Choose a base branch
from
distance-module-n-jobs
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
ac185e2
added numba prange to all pairwise distances
chrisholder c406623
sfa update
chrisholder cbf5e19
changed warning
chrisholder 2b7acac
Merge branch 'main' into distance-module-n-jobs
baraline d35954d
Merge branch 'main' into distance-module-n-jobs
chrisholder 3ff96b1
custom pairwise threaded
chrisholder 2a74496
fixed
chrisholder 4afc8df
added threaded decorator
chrisholder b603ddb
fix
chrisholder 24b46b0
fix
chrisholder 85dfdab
Merge branch 'main' into distance-module-n-jobs
MatthewMiddlehurst 541e7fd
expanded threaded decorator to work with classes
chrisholder add714e
fix
chrisholder 7ee5846
Merge branch 'main' into distance-module-n-jobs
MatthewMiddlehurst File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,16 @@ | |
|
||
__maintainer__ = ["SebastianSchmidl"] | ||
|
||
import warnings | ||
from typing import Optional, Union | ||
|
||
import numpy as np | ||
from numba import njit, objmode | ||
from numba import njit, objmode, prange, set_num_threads | ||
from numba.typed import List as NumbaList | ||
from scipy.signal import correlate | ||
|
||
from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list | ||
from aeon.utils.validation import check_n_jobs | ||
from aeon.utils.validation.collection import _is_numpy_list_multivariate | ||
|
||
|
||
|
@@ -117,6 +119,8 @@ def sbd_pairwise_distance( | |
X: Union[np.ndarray, list[np.ndarray]], | ||
y: Optional[Union[np.ndarray, list[np.ndarray]]] = None, | ||
standardize: bool = True, | ||
n_jobs: int = 1, | ||
**kwargs, | ||
) -> np.ndarray: | ||
""" | ||
Compute the shape-based distance (SBD) between all pairs of time series. | ||
|
@@ -138,6 +142,10 @@ def sbd_pairwise_distance( | |
standardize : bool, default=True | ||
Apply z-score to both input time series for standardization before | ||
computing the distance. This makes SBD scaling invariant. Default is True. | ||
n_jobs : int, default=1 | ||
The number of jobs to run in parallel. If -1, then the number of jobs is set | ||
to the number of CPU cores. If 1, then the function is executed in a single | ||
thread. If greater than 1, then the function is executed in parallel. | ||
|
||
Returns | ||
------- | ||
|
@@ -188,6 +196,17 @@ def sbd_pairwise_distance( | |
[0.36754447, 0. , 0.29289322], | ||
[0.5527864 , 0.29289322, 0. ]]) | ||
""" | ||
n_jobs = check_n_jobs(n_jobs) | ||
set_num_threads(n_jobs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like this should set it back to the original value after? Maybe its possible to use a decorator in these cases similar to https://github.com/aeon-toolkit/aeon/blob/v1.0.0/aeon/testing/utils/output_suppression.py |
||
if n_jobs > 1: | ||
warnings.warn( | ||
"You have set n_jobs > 1. For this distance function " | ||
"unless your data has a large number of time points, it is " | ||
"recommended to use n_jobs=1. If this function is slower than " | ||
"expected try setting n_jobs=1.", | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
multivariate_conversion = _is_numpy_list_multivariate(X, y) | ||
_X, _ = _convert_collection_to_numba_list(X, "", multivariate_conversion) | ||
|
||
|
@@ -199,30 +218,30 @@ def sbd_pairwise_distance( | |
return _sbd_pairwise_distance(_X, _y, standardize) | ||
|
||
|
||
@njit(cache=True, fastmath=True) | ||
@njit(cache=True, fastmath=True, parallel=True) | ||
def _sbd_pairwise_distance_single( | ||
x: NumbaList[np.ndarray], standardize: bool | ||
) -> np.ndarray: | ||
n_cases = len(x) | ||
distances = np.zeros((n_cases, n_cases)) | ||
|
||
for i in range(n_cases): | ||
for i in prange(n_cases): | ||
for j in range(i + 1, n_cases): | ||
distances[i, j] = sbd_distance(x[i], x[j], standardize) | ||
distances[j, i] = distances[i, j] | ||
|
||
return distances | ||
|
||
|
||
@njit(cache=True, fastmath=True) | ||
@njit(cache=True, fastmath=True, parallel=True) | ||
def _sbd_pairwise_distance( | ||
x: NumbaList[np.ndarray], y: NumbaList[np.ndarray], standardize: bool | ||
) -> np.ndarray: | ||
n_cases = len(x) | ||
m_cases = len(y) | ||
distances = np.zeros((n_cases, m_cases)) | ||
|
||
for i in range(n_cases): | ||
for i in prange(n_cases): | ||
for j in range(m_cases): | ||
distances[i, j] = sbd_distance(x[i], y[j], standardize) | ||
return distances | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.