Skip to content

Commit 5f8974f

Browse files
[BUG] LearningShapeletClassifier fixes (#1370)
* ls fixes * ls fixes * comment extra functions * tensorflow dep * skip test * changelog
1 parent dea395a commit 5f8974f

File tree

4 files changed

+70
-68
lines changed

4 files changed

+70
-68
lines changed

aeon/classification/shapelet_based/_ls.py

+63-61
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
Learning shapelet classifier that simply wraps the LearningShapelet class from tslearn.
44
"""
55

6+
__maintainer__ = ["MatthewMiddlehurst"]
67
__all__ = ["LearningShapeletClassifier"]
78

9+
810
import numpy as np
911

1012
from aeon.classification.base import BaseClassifier
@@ -31,67 +33,66 @@ class LearningShapeletClassifier(BaseClassifier):
3133
3234
Parameters
3335
----------
34-
n_shapelets_per_size: dict (default: None)
36+
n_shapelets_per_size: dict, default=None
3537
Dictionary giving, for each shapelet size (key),
3638
the number of such shapelets to be trained (value).
3739
If None, `grabocka_params_to_shapelet_size_dict` is used and the
3840
size used to compute is that of the shortest time series passed at fit
3941
time.
40-
41-
max_iter: int (default: 10,000)
42+
max_iter: int, default=10000
4243
Number of training epochs.
43-
44-
batch_size: int (default: 256)
44+
batch_size: int, default=256
4545
Batch size to be used.
46-
47-
verbose: {0, 1, 2} (default: 0)
46+
verbose: {0, 1, 2}, default=0
4847
`keras` verbose level.
49-
50-
optimizer: str or keras.optimizers.Optimizer (default: "sgd")
48+
optimizer: str or keras.optimizers.Optimizer, default="sgd"
5149
`keras` optimizer to use for training.
52-
53-
weight_regularizer: float or None (default: 0.)
50+
weight_regularizer: float or None, default=0.0
5451
Strength of the L2 regularizer to use for training the classification
5552
(softmax) layer. If 0, no regularization is performed.
56-
57-
shapelet_length: float (default: 0.15)
53+
shapelet_length: float, default=0.15
5854
The length of the shapelets, expressed as a fraction of the time
5955
series length.
6056
Used only if `n_shapelets_per_size` is None.
61-
62-
total_lengths: int (default: 3)
57+
total_lengths: int, default=3
6358
The number of different shapelet lengths. Will extract shapelets of
6459
length i * shapelet_length for i in [1, total_lengths]
6560
Used only if `n_shapelets_per_size` is None.
66-
67-
max_size: int or None (default: None)
61+
max_size: int or None, default=None
6862
Maximum size for time series to be fed to the model. If None, it is
6963
set to the size (number of timestamps) of the training time series.
70-
71-
scale: bool (default: False)
64+
scale: bool, default=False
7265
Whether input data should be scaled for each feature of each time
73-
series to lie in the [0-1] interval.
74-
Default for this parameter is set to `False` in version 0.4 to ensure
75-
backward compatibility, but is likely to change in a future version.
76-
77-
random_state : int or None, optional (default: None)
66+
series to lie in the [0-1] interval. Default for this parameter is set to
67+
`False`.
68+
random_state : int or None, default=None
7869
The seed of the pseudo random number generator to use when shuffling
7970
the data. If int, random_state is the seed used by the random number
8071
generator; If None, the random number generator is the RandomState
8172
instance used by `np.random`.
8273
8374
References
8475
----------
85-
.. [1] J. Grabocka et al. Learning Time-Series Shapelets. SIGKDD 2014.
86-
76+
.. Grabocka, J., Schilling, N., Wistuba, M. and Schmidt-Thieme, L., 2014, August.
77+
Learning time-series shapelets. In Proceedings of the 20th ACM SIGKDD
78+
international conference on Knowledge discovery and data mining (pp. 392-401).
79+
80+
Examples
81+
--------
82+
>>> from aeon.classification.shapelet_based import LearningShapeletClassifier
83+
>>> from aeon.testing.utils.data_gen import make_example_3d_numpy
84+
>>> X, y = make_example_3d_numpy(random_state=0)
85+
>>> clf = LearningShapeletClassifier(max_iter=50, random_state=0) # doctest: +SKIP
86+
>>> clf.fit(X, y) # doctest: +SKIP
87+
MrSQMClassifier(...)
88+
>>> clf.predict(X) # doctest: +SKIP
8789
"""
8890

8991
_tags = {
9092
"capability:multivariate": True,
9193
"algorithm_type": "shapelet",
9294
"cant-pickle": True,
93-
"python_dependencies": "tensorflow",
94-
"python_version": "<3.10",
95+
"python_dependencies": ["tslearn", "tensorflow"],
9596
}
9697

9798
def __init__(
@@ -108,7 +109,6 @@ def __init__(
108109
scale=False,
109110
random_state=None,
110111
):
111-
super().__init__()
112112
self.n_shapelets_per_size = n_shapelets_per_size
113113
self.max_iter = max_iter
114114
self.batch_size = batch_size
@@ -121,6 +121,8 @@ def __init__(
121121
self.scale = scale
122122
self.random_state = random_state
123123

124+
super().__init__()
125+
124126
def _fit(self, X, y):
125127
from tslearn.shapelets import LearningShapelets
126128

@@ -149,37 +151,37 @@ def _predict_proba(self, X) -> np.ndarray:
149151
_X_transformed = _X_transformed_tslearn(X)
150152
return self.clf_.predict_proba(_X_transformed)
151153

152-
def transform(self, X):
153-
"""Generate shapelet transform for a set of time series.
154-
155-
Parameters
156-
----------
157-
X : array-like of shape=(n_ts, sz, d)
158-
Time series dataset.
159-
160-
Returns
161-
-------
162-
array of shape=(n_ts, n_shapelets)
163-
Shapelet-Transform of the provided time series.
164-
"""
165-
_X_transformed = _X_transformed_tslearn(X)
166-
return self.clf_.transform(_X_transformed)
167-
168-
def locate(self, X):
169-
"""Compute shapelet match location for a set of time series.
170-
171-
Parameters
172-
----------
173-
X : array-like of shape=(n_ts, sz, d)
174-
Time series dataset.
175-
176-
Returns
177-
-------
178-
array of shape=(n_ts, n_shapelets)
179-
Location of the shapelet matches for the provided time series.
180-
"""
181-
_X_transformed = _X_transformed_tslearn(X)
182-
return self.clf_.locate(_X_transformed)
154+
# def transform(self, X):
155+
# """Generate shapelet transform for a set of time series.
156+
#
157+
# Parameters
158+
# ----------
159+
# X : array-like of shape=(n_ts, sz, d)
160+
# Time series dataset.
161+
#
162+
# Returns
163+
# -------
164+
# array of shape=(n_ts, n_shapelets)
165+
# Shapelet-Transform of the provided time series.
166+
# """
167+
# _X_transformed = _X_transformed_tslearn(X)
168+
# return self.clf_.transform(_X_transformed)
169+
#
170+
# def locate(self, X):
171+
# """Compute shapelet match location for a set of time series.
172+
#
173+
# Parameters
174+
# ----------
175+
# X : array-like of shape=(n_ts, sz, d)
176+
# Time series dataset.
177+
#
178+
# Returns
179+
# -------
180+
# array of shape=(n_ts, n_shapelets)
181+
# Location of the shapelet matches for the provided time series.
182+
# """
183+
# _X_transformed = _X_transformed_tslearn(X)
184+
# return self.clf_.locate(_X_transformed)
183185

184186
@classmethod
185187
def get_test_params(cls, parameter_set="default"):
@@ -203,4 +205,4 @@ def get_test_params(cls, parameter_set="default"):
203205
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
204206
`create_test_instance` uses the first (or only) dictionary in `params`.
205207
"""
206-
return {"n_shapelets_per_size": None, "shapelet_length": 0.15, "max_iter": 50}
208+
return {"max_iter": 50, "batch_size": 10}

aeon/testing/test_config.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
"Differencer": ["test_transform_inverse_transform_equivalent"],
5050
# Test fails, see https://github.com/aeon-toolkit/aeon/issues/1067
5151
"MockUnivariateForecasterLogger": ["test_non_state_changing_method_contract"],
52+
# has a keras fail, unknown reason
53+
"LearningShapeletClassifier": ["test_fit_deterministic"],
5254
}
5355

5456
# We use estimator tags in addition to class hierarchies to further distinguish

docs/changelogs/v0.8.md

+4-6
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,8 @@ April 2024
2727

2828
- [BUG] Fix random state for deep learning models in classification/regression and clustering ({pr}`1271`) {user}`hadifawaz1999`
2929
- [BUG] ElasticEnsemble with `euclidean` and `twe` distance measures ({pr}`1288`) {user}`itsdivya1309`
30-
- [ENH] fixed sqr error ({pr}`1240`) {user}`AnonymousCodes911`
31-
32-
### Enhancements
33-
34-
- [ENH] Learning Shapelet Classifier ({pr}`1247`) {user}`itsdivya1309`
30+
- [BUG] fixed sqr error ({pr}`1240`) {user}`AnonymousCodes911`
31+
- [BUG] LearningShapeletClassifier fixes ({pr}`1370`) {user}`MatthewMiddlehurst`
3532

3633
### Deprecation
3734

@@ -49,6 +46,7 @@ April 2024
4946
- [ENH] Update regression pipeline ({pr}`1279`) {user}`MatthewMiddlehurst`
5047
- [ENH] Loading unequal length, no missing values classification problems ({pr}`1157`) {user}`TonyBagnall`
5148
- [ENH] Tidy dummy estimators for classification and regression ({pr}`1281`) {user}`MatthewMiddlehurst`
49+
- [ENH] Learning Shapelet Classifier ({pr}`1247`) {user}`itsdivya1309`
5250

5351
### Maintenance
5452

@@ -202,7 +200,7 @@ April 2024
202200

203201
## Contributors
204202

205-
The following have contributed to this release through a collective 57 GitHub Pull Requests:
203+
The following have contributed to this release through a collective 58 GitHub Pull Requests:
206204

207205
{user}`aadya940`,
208206
{user}`AnonymousCodes911`,

docs/installation.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Installation
22

3-
`aeon` currently supports Python versions 3.8, 3.9, 3.10 and 3.11. Prior to these
3+
`aeon` currently supports Python versions 3.8, 3.9, 3.10, 3.11 and 3.12. Prior to these
44
instructions, please ensure you have a compatible version of Python installed
55
(i.e. from https://www.python.org).
66

0 commit comments

Comments
 (0)