Commit 0ede6e3

baraline, patrickzib, and MatthewMiddlehurst authored and committed
[MNT, ENH, DOC] Rework similarity search (#2473)
* WIP remake module structure
* Update _brute_force.py
* Update test__commons.py
* WIP mock and test
* Add test for base subsequence
* Fix subsequence_search tests
* debug brute force mp
* more debug of subsequence tests
* more debug of subsequence tests
* Add functional LSH neighbors
* add notebook for sim search tasks
* Updated series similarity search
* Fix mistake addition in transformers and fix base classes
* Fix registry and api reference
* Update documentation and fix some leftover bugs
* Update documentation and add default test params
* Fix identifiers and test data shape for all_estimators tests
* Fix missing params
* Fix n_jobs params and tags, add some docs
* Fix numba test bug and update testing data for sim search
* Fix imports, testing data tests, and impose predict/_predict interface to all sim search estimators
* Fix args
* Fix extract test
* update docs api and notebooks
* remove notes
* Patrick comments
* Adress comments and clean index code
* Fix Patrick comments
* Fix variable suppression mistake
* Divide base class into task specific
* Fix typo in imports
* Empty commit for CI
* Fix typo again
* Add check_inheritance exception for similarity search
* Revert back to non per type base classes
* Factor check index and typo in test

---------

Co-authored-by: Patrick Schäfer <[email protected]>
Co-authored-by: Matthew Middlehurst <[email protected]>
Co-authored-by: baraline <[email protected]>
1 parent 21c2280 commit 0ede6e3

62 files changed, +3046 -4349 lines changed

aeon/similarity_search/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -1,7 +1,5 @@
 """Similarity search module."""
 
-__all__ = ["BaseSimilaritySearch", "QuerySearch", "SeriesSearch"]
+__all__ = ["BaseSimilaritySearch"]
 
-from aeon.similarity_search.base import BaseSimilaritySearch
-from aeon.similarity_search.query_search import QuerySearch
-from aeon.similarity_search.series_search import SeriesSearch
+from aeon.similarity_search._base import BaseSimilaritySearch

aeon/similarity_search/_base.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+"""Base class for similarity search."""
+
+__maintainer__ = ["baraline"]
+__all__ = [
+    "BaseSimilaritySearch",
+]
+
+
+from abc import abstractmethod
+from typing import Union
+
+import numpy as np
+from numba.typed import List
+
+from aeon.base import BaseAeonEstimator
+
+
+class BaseSimilaritySearch(BaseAeonEstimator):
+    """Base class for similarity search applications."""
+
+    _tags = {
+        "requires_y": False,
+        "fit_is_empty": False,
+    }
+
+    @abstractmethod
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def fit(
+        self,
+        X: Union[np.ndarray, List],
+        y=None,
+    ):
+        """
+        Fit estimator to X.
+
+        State change:
+            Changes state to "fitted".
+
+        Writes to self:
+        _is_fitted : flag is set to True.
+
+        Parameters
+        ----------
+        X : Series or Collection, any supported type
+            Data to fit transform to, of python type as follows:
+                Series: 2D np.ndarray shape (n_channels, n_timepoints)
+                Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints)
+                or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i)
+        y: ignored, exists for API consistency reasons.
+
+        Returns
+        -------
+        self : a fitted instance of the estimator
+        """
+        ...
+
+    @abstractmethod
+    def predict(
+        self,
+        X: Union[np.ndarray, None] = None,
+    ):
+        """
+        Predict method.
+
+        Can either work with new series or with None (for case when predict can be made
+        using the data given in fit against itself) depending on the estimator.
+
+        Parameters
+        ----------
+        X : Series or Collection, any supported type
+            Data to fit transform to, of python type as follows:
+                Series: 2D np.ndarray shape (n_channels, n_timepoints)
+                Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints)
+                or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i
+                None : If None type is accepted, it means that the predict function will
+                work only with the data given in fit. (e.g. self matrix profile instead)
+        """
+        ...
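
To make the `fit`/`predict` contract of the new base class concrete, here is a minimal sketch of a brute-force nearest-neighbour estimator that follows the same interface. It is an illustration under stated assumptions, not code from this commit: `SketchBaseSimilaritySearch` is a hypothetical stand-in for `BaseSimilaritySearch` so the block runs without aeon, numba, or NumPy installed; `BruteForceNearestSeries` and `squared_euclidean` are hypothetical names; series are plain nested lists of shape (n_channels, n_timepoints); and excluding self-matches when `X` is None is only one possible reading of the docstring.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

Series = List[List[float]]  # assumed layout: (n_channels, n_timepoints)


class SketchBaseSimilaritySearch(ABC):
    """Hypothetical stand-in mirroring the fit/predict interface in the diff."""

    @abstractmethod
    def fit(self, X: List[Series], y=None): ...

    @abstractmethod
    def predict(self, X: Optional[Series] = None): ...


def squared_euclidean(a: Series, b: Series) -> float:
    """Sum of squared differences over all channels and time points."""
    return sum(
        (va - vb) ** 2
        for ch_a, ch_b in zip(a, b)
        for va, vb in zip(ch_a, ch_b)
    )


class BruteForceNearestSeries(SketchBaseSimilaritySearch):
    """Hypothetical estimator: finds the closest equal-length fitted series."""

    def fit(self, X, y=None):
        # Store a copy of the collection; y is ignored, as in the base docstring.
        self.X_ = [[list(ch) for ch in case] for case in X]
        self._is_fitted = True
        return self

    def predict(self, X=None):
        if not getattr(self, "_is_fitted", False):
            raise RuntimeError("Call fit before predict.")
        if X is not None:
            # New series: search the whole fitted collection.
            return self._nearest(X, exclude=None)
        # X is None: compare each fitted case against the other fitted cases.
        return [self._nearest(case, exclude=i) for i, case in enumerate(self.X_)]

    def _nearest(self, query: Series, exclude: Optional[int]) -> Tuple[int, float]:
        # Brute force: compute every distance, keep the smallest.
        candidates = [
            (squared_euclidean(query, case), i)
            for i, case in enumerate(self.X_)
            if i != exclude
        ]
        dist, idx = min(candidates)
        return idx, dist


# Usage: three univariate series of length 3.
collection = [[[0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0]], [[5.0, 5.0, 5.0]]]
searcher = BruteForceNearestSeries().fit(collection)
print(searcher.predict([[0.9, 1.0, 1.2]])[0])  # index of the closest case: 1
print(searcher.predict())  # [(1, 3.0), (0, 3.0), (1, 48.0)]
```

The sketch keeps the two call modes separate inside `predict` so that the `X=None` path, which the docstring describes as working only with the data given in `fit`, is explicit rather than implied.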
