Merged
Commits
52 commits
1646895
WIP remake module structure
baraline Dec 4, 2024
52b0692
Update _brute_force.py
baraline Dec 5, 2024
7973a30
Update test__commons.py
baraline Dec 5, 2024
7abe221
Merge remote-tracking branch 'origin/main' into 2341-enh-indexsearch-…
baraline Dec 14, 2024
ad02b84
WIP mock and test
baraline Dec 27, 2024
bb2aa33
Add test for base subsequence
baraline Jan 1, 2025
c5c9c28
Merge remote-tracking branch 'origin/main' into 2341-enh-indexsearch-…
baraline Jan 2, 2025
f23c720
Fix subsequence_search tests
baraline Jan 2, 2025
c372969
debug brute force mp
baraline Jan 2, 2025
d7da68b
more debug of subsequence tests
baraline Jan 2, 2025
da2758c
more debug of subsequence tests
baraline Jan 2, 2025
2191ac2
Add functional LSH neighbors
baraline Jan 10, 2025
cd33d0a
add notebook for sim search tasks
baraline Jan 13, 2025
b841b79
Updated series similarity search
baraline Jan 16, 2025
dbe9494
Merge remote-tracking branch 'origin/main' into 2341-enh-indexsearch-…
baraline Jan 16, 2025
57e5e7b
Fix mistake addition in transformers and fix base classes
baraline Jan 16, 2025
2078086
Fix registry and api reference
baraline Jan 16, 2025
9effbd9
Update documentation and fix some leftover bugs
baraline Jan 17, 2025
f51d66a
Update documentation and add default test params
baraline Jan 19, 2025
763bdcf
Fix identifiers and test data shape for all_estimators tests
baraline Jan 19, 2025
85c7174
Fix missing params
baraline Jan 20, 2025
038f844
Merge remote-tracking branch 'origin/main' into 2341-enh-indexsearch-…
baraline Feb 1, 2025
fd7caad
Fix n_jobs params and tags, add some docs
baraline Feb 1, 2025
6e3157b
Fix numba test bug and update testing data for sim search
baraline Feb 2, 2025
e3ccb3f
Fix imports, testing data tests, and impose predict/_predict interfac…
baraline Feb 2, 2025
ee7aa58
Fix args
baraline Feb 2, 2025
bf0c5e8
Fix extract test
baraline Feb 2, 2025
0c2d763
update docs api and notebooks
baraline Feb 2, 2025
db10499
remove notes
baraline Feb 2, 2025
3587de1
Merge branch 'main' into 2341-enh-indexsearch-class-with-attimo-algor…
baraline Feb 2, 2025
9c671f6
Merge branch 'main' into 2341-enh-indexsearch-class-with-attimo-algor…
baraline Feb 14, 2025
1e21767
Merge branch 'main' into 2341-enh-indexsearch-class-with-attimo-algor…
patrickzib Mar 4, 2025
f328779
Patrick comments
baraline Mar 15, 2025
6d8541a
Merge branch 'main' into 2341-enh-indexsearch-class-with-attimo-algor…
baraline Mar 15, 2025
617a927
Merge branch 'main' into 2341-enh-indexsearch-class-with-attimo-algor…
baraline Mar 16, 2025
a34d58c
Merge branch 'main' into 2341-enh-indexsearch-class-with-attimo-algor…
MatthewMiddlehurst Mar 23, 2025
680e8b0
Merge remote-tracking branch 'origin/main' into 2341-enh-indexsearch-…
baraline Mar 27, 2025
0a9434f
Adress comments and clean index code
baraline Mar 27, 2025
537caab
Adress comments and clean index code
baraline Mar 27, 2025
75ae978
Fix Patrick comments
baraline Mar 31, 2025
fa43a10
Fix variable suppression mistake
baraline Mar 31, 2025
68d817f
Merge branch 'main' into 2341-enh-indexsearch-class-with-attimo-algor…
baraline Mar 31, 2025
69eaed2
Fix Base class comments
baraline Apr 16, 2025
f4a6414
Divide base class into task specific
baraline Apr 19, 2025
5c08b43
Fix typo in imports
baraline Apr 19, 2025
40c7e1a
Empty commit for CI
baraline Apr 19, 2025
58dd7b4
Fix typo again
baraline Apr 19, 2025
38e7b55
Merge branch '2341-enh-indexsearch-class-with-attimo-algorithm' of ht…
baraline Apr 19, 2025
6911bd9
Add check_inheritance exception for similarity search
baraline Apr 19, 2025
7fd9899
Revert back to non per type base classes
baraline Apr 19, 2025
470dcf9
Factor check index and typo in test
baraline Apr 19, 2025
d4fc858
Merge branch 'main' into 2341-enh-indexsearch-class-with-attimo-algor…
baraline May 1, 2025
5 changes: 4 additions & 1 deletion aeon/similarity_search/collection/__init__.py
@@ -1,6 +1,9 @@
"""Similarity search for time series collection."""

__all__ = ["BaseCollectionSimilaritySearch", "RandomProjectionIndexANN"]
__all__ = [
"BaseCollectionSimilaritySearch",
"RandomProjectionIndexANN",
]

from aeon.similarity_search.collection._base import BaseCollectionSimilaritySearch
from aeon.similarity_search.collection.neighbors._rp_cosine_lsh import (
19 changes: 3 additions & 16 deletions aeon/similarity_search/collection/_base.py
@@ -57,9 +57,8 @@ def fit(
"""
self.reset()
X = self._preprocess_collection(X)
# Store minimum number of n_timepoints for unequal length collections
self.n_channels_ = X[0].shape[0]
self.n_cases_ = len(X)
self.n_channels_ = self.metadata_["n_channels"]
self.n_cases_ = self.metadata_["n_cases"]
self._fit(X, y=y)
self.is_fitted = True
return self
@@ -89,9 +88,7 @@ def predict(self, X, **kwargs):

"""
self._check_is_fitted()
if X[0].ndim == 1:
X = X[np.newaxis, :, :]
X = self._preprocess_collection(X)
X = self._preprocess_collection(X, store_metadata=False)
self._check_predict_series_format(X)
indexes, distances = self._predict(X, **kwargs)
return indexes, distances
@@ -105,16 +102,6 @@ def _check_predict_series_format(self, X):
X : np.ndarray, shape = (n_channels, n_timepoints)
A series to be used in predict.
"""
if isinstance(X, np.ndarray):
if X[0].ndim != 2:
raise TypeError(
"A np.ndarray given in predict must be 3D"
f"(n_channels, n_timepoints) but found {X.ndim}D."
)
else:
raise TypeError(
"Expected a 3D np.ndarray in predict but found" f" {type(X)}."
)
if self.n_channels_ != X[0].shape[0]:
raise ValueError(
f"Expected X to have {self.n_channels_} channels but"
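For orientation, here is a minimal sketch of the collection-level fit/predict flow as it stands after this change. The estimator choice, the data shapes, and the default constructor call are illustrative assumptions, not taken from this diff.

```python
import numpy as np

from aeon.similarity_search.collection import RandomProjectionIndexANN

rng = np.random.default_rng(0)
# (n_cases, n_channels, n_timepoints); shapes are illustrative
X_train = rng.normal(size=(10, 2, 50))
# a small collection holding a single query series with matching n_channels
queries = rng.normal(size=(1, 2, 50))

searcher = RandomProjectionIndexANN()  # default parameters assumed
searcher.fit(X_train)  # n_channels_ and n_cases_ now come from self.metadata_
indexes, distances = searcher.predict(queries)  # base-class contract: (indexes, distances)
```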
@@ -140,6 +140,7 @@ class RandomProjectionIndexANN(BaseCollectionSimilaritySearch):

_tags = {
"capability:unequal_length": False,
"capability:multivariate": True,
"capability:multithreading": True,
}

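The newly declared multivariate capability is visible through aeon's tag interface; a quick check, assuming the usual `get_tag` accessor on aeon estimators:

```python
from aeon.similarity_search.collection import RandomProjectionIndexANN

est = RandomProjectionIndexANN()
# Tags declared in the class-level _tags dict are readable at the instance level.
print(est.get_tag("capability:multivariate"))    # True with the tag added here
print(est.get_tag("capability:unequal_length"))  # False: unequal-length collections unsupported
```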
49 changes: 6 additions & 43 deletions aeon/similarity_search/collection/tests/test_base.py
@@ -2,55 +2,18 @@

__maintainer__ = ["baraline"]

import pytest

from aeon.testing.mock_estimators._mock_similarity_searchers import (
MockCollectionSimilaritySearch,
)
from aeon.testing.testing_data import (
make_example_1d_numpy,
make_example_2d_numpy_series,
make_example_3d_numpy,
)
from aeon.testing.testing_data import FULL_TEST_DATA_DICT, _get_datatypes_for_estimator


def test_input_shape_fit_predict_collection():
"""Test input shapes."""
estimator = MockCollectionSimilaritySearch()
datatypes = _get_datatypes_for_estimator(estimator)
# dummy data to pass to fit when testing predict/predict_proba
X_3D_uni = make_example_3d_numpy(n_channels=1, return_y=False)
X_3D_multi = make_example_3d_numpy(n_channels=2, return_y=False)
X_2D_uni = make_example_2d_numpy_series(n_channels=1)
X_2D_multi = make_example_2d_numpy_series(n_channels=2)
X_1D = make_example_1d_numpy()

# 2D are converted to 3D
valid_inputs_fit = [
X_3D_uni,
X_3D_multi,
X_2D_uni,
X_2D_multi,
]
# Valid inputs
for _input in valid_inputs_fit:
estimator.fit(_input)

with pytest.raises(ValueError):
estimator.fit(X_1D)

estimator_multi = MockCollectionSimilaritySearch().fit(X_3D_multi)
estimator_uni = MockCollectionSimilaritySearch().fit(X_3D_uni)

estimator_uni.predict(X_2D_uni)
estimator_uni.predict(X_3D_uni)
estimator_multi.predict(X_2D_multi)
estimator_multi.predict(X_3D_multi)

with pytest.raises(ValueError):
estimator_uni.predict(X_2D_multi)
with pytest.raises(ValueError):
estimator_multi.predict(X_2D_uni)
with pytest.raises(ValueError):
estimator_multi.predict(X_3D_uni)
with pytest.raises(ValueError):
estimator_uni.predict(X_3D_multi)
for datatype in datatypes:
X_train, y_train = FULL_TEST_DATA_DICT[datatype]["train"]
X_test, y_test = FULL_TEST_DATA_DICT[datatype]["test"]
estimator.fit(X_train, y_train).predict(X_test)
11 changes: 9 additions & 2 deletions aeon/similarity_search/series/__init__.py
@@ -1,8 +1,15 @@
"""Similarity search for series."""

__all__ = ["BaseSeriesSimilaritySearch", "MassSNN", "StompMotif", "DummySNN"]
__all__ = [
"BaseSeriesSimilaritySearch",
"MassSNN",
"StompMotif",
"DummySNN",
]

from aeon.similarity_search.series._base import BaseSeriesSimilaritySearch
from aeon.similarity_search.series._base import (
BaseSeriesSimilaritySearch,
)
from aeon.similarity_search.series.motifs._stomp import StompMotif
from aeon.similarity_search.series.neighbors._dummy import DummySNN
from aeon.similarity_search.series.neighbors._mass import MassSNN
49 changes: 13 additions & 36 deletions aeon/similarity_search/series/_base.py
@@ -1,5 +1,8 @@
"""Base similiarity search for series."""

__maintainer__ = ["baraline"]
__all__ = ["BaseSeriesSimilaritySearch"]

from abc import abstractmethod
from typing import final

@@ -53,9 +56,9 @@ def fit(
"""
self.reset()
X = self._preprocess_series(X, self.axis, True)
# Store minimum number of n_timepoints for unequal length collections
self.n_channels_ = X.shape[0]
self.n_timepoints_ = X.shape[1]
self.n_channels_ = self.metadata_["n_channels"]
timepoint_idx = 1 if self.axis == 1 else 0
self.n_timepoints_ = X.shape[timepoint_idx]
self.X_ = X
self._fit(X, y=y)
self.is_fitted = True
@@ -69,7 +72,7 @@ def _fit(
): ...

@final
def predict(self, X=None, **kwargs):
def predict(self, X, **kwargs):
"""
Predict function.

@@ -79,7 +82,7 @@ def predict(self, X=None, **kwargs):
Series to predict on.
kwargs : dict, optional
Additional keyword argument as dict or individual keywords args
to pass to use.
to pass to the estimator.

Returns
-------
@@ -90,41 +93,14 @@

"""
self._check_is_fitted()
if X is not None:
X = self._preprocess_series(X, self.axis, False)
self._check_predict_series_format(X)
else:
X = self.X_
X = self._preprocess_series(X, self.axis, False)
self._check_predict_series_format(X)
indexes, distances = self._predict(X, **kwargs)
return indexes, distances

@abstractmethod
def _predict(self, X, **kwargs): ...

def _check_X_index(self, X_index: int):
"""
Check wheter a X_index parameter is correctly formated and is admissible.

Parameters
----------
X_index : int
Index of a timestamp in X_.

"""
if X_index is not None:
if not isinstance(X_index, int):
raise TypeError("Expected an integer for X_index but got {X_index}")

max_timepoints = self.n_timepoints_
if hasattr(self, "length"):
max_timepoints -= self.length
if X_index >= max_timepoints or X_index < 0:
raise ValueError(
"The value of X_index cannot exced the number "
"of timepoint in series given during fit. Expected a value "
f"between [0, {max_timepoints - 1}] but got {X_index}"
)

def _check_predict_series_format(self, X):
"""
Check whether a series X is correctly formatted regarding the series given in fit.
@@ -135,8 +111,9 @@ def _check_predict_series_format(self, X):
A series to be used in predict.

"""
if self.n_channels_ != X.shape[0]:
channel_idx = 0 if self.axis == 1 else 1
if self.n_channels_ != X.shape[channel_idx]:
raise ValueError(
f"Expected X to have {self.n_channels_} channels but"
f" got {X.shape[0]} channels."
f" got {X.shape[channel_idx]} channels."
)
34 changes: 32 additions & 2 deletions aeon/similarity_search/series/_commons.py
@@ -9,6 +9,33 @@
from aeon.utils.numba.general import AEON_NUMBA_STD_THRESHOLD


def _check_X_index(X_index: int, n_timepoints: int, length: int):
"""
Check whether the X_index parameter is correctly formatted and admissible.

Parameters
----------
X_index : int
Index of a timestamp in X_.
n_timepoints : int
Number of timepoints in the series X_.
length : int
Length parameter of the estimator.

"""
if X_index is not None:
if not isinstance(X_index, int):
raise TypeError("Expected an integer for X_index but got {X_index}")

max_timepoints = n_timepoints - length
if X_index >= max_timepoints or X_index < 0:
raise ValueError(
"The value of X_index cannot exced the number "
"of timepoint in series given during fit. Expected a value "
f"between [0, {max_timepoints - 1}] but got {X_index}"
)


def fft_sliding_dot_product(X, q):
"""
Use FFT convolution to calculate the sliding window dot product.
@@ -161,7 +188,10 @@ def _extract_top_k_motifs(MP, IP, k, allow_trivial_matches, exclusion_size):
idx, _ = _extract_top_k_from_dist_profile(
criterion, k, np.inf, allow_trivial_matches, exclusion_size
)
return [MP[i] for i in idx], [IP[i] for i in idx]
return (
[IP[i] for i in idx],
[MP[i] for i in idx],
)


def _extract_top_r_motifs(MP, IP, k, allow_trivial_matches, exclusion_size):
Expand All @@ -175,7 +205,7 @@ def _extract_top_r_motifs(MP, IP, k, allow_trivial_matches, exclusion_size):
allow_trivial_matches,
exclusion_size,
)
return [MP[i] for i in idx], [IP[i] for i in idx]
return [IP[i] for i in idx], [MP[i] for i in idx]


@njit(cache=True, fastmath=True)
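Since `_check_X_index` is now a module-level helper taking the fitted series length and the estimator's subsequence length explicitly, a quick sketch of the expected call pattern (the concrete numbers are illustrative):

```python
from aeon.similarity_search.series._commons import _check_X_index

n_timepoints = 100  # timepoints of the series seen during fit
length = 20         # subsequence length used by the estimator

_check_X_index(5, n_timepoints, length)  # valid: 0 <= 5 < 100 - 20
try:
    _check_X_index(90, n_timepoints, length)  # 90 >= 80, outside the admissible range
except ValueError as err:
    print(err)
```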