Skip to content

Commit a601cf2

Browse files
authored
[ENH] Split validation and conversion in utils (#1134)
* switch test example for pipeline * switch test example for pipeline * revise tests and standardise names * revise tests and standardise names * revise tests and standardise names * make_example_unequal_length * nest univ * move one off functions * cluster tests * docstring * move one off function to test location * move make_forecasting to _series * min length segments * remove test in FeatureUnion_pipeline * remove test in FeatureUnion_pipeline * conversion module * conversion module * move data generators * fix imports * base classifier * convert merge * import * import * notebook * Merge branch 'main' into ajb/data_gen # Conflicts: # aeon/datasets/_data_generators.py # aeon/forecasting/base/tests/test_base.py # aeon/testing/utils/data_gen/__init__.py # aeon/testing/utils/data_gen/_data_generators.py * Merge branch 'main' into ajb/data_gen # Conflicts: # aeon/datasets/_data_generators.py # aeon/forecasting/base/tests/test_base.py # aeon/testing/utils/data_gen/__init__.py # aeon/testing/utils/data_gen/_data_generators.py * Merge branch 'main' into ajb/data_gen # Conflicts: # aeon/datasets/_data_generators.py # aeon/forecasting/base/tests/test_base.py # aeon/testing/utils/data_gen/__init__.py # aeon/testing/utils/data_gen/_data_generators.py * Merge branch 'main' into ajb/data_gen # Conflicts: # aeon/datasets/_data_generators.py # aeon/forecasting/base/tests/test_base.py # aeon/testing/utils/data_gen/__init__.py # aeon/testing/utils/data_gen/_data_generators.py * get_examples * fix imports * fix imports * fix docstrings * fix docstrings * all
1 parent 09285e1 commit a601cf2

25 files changed

+655
-573
lines changed

aeon/base/_base_collection.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
"""Base class for estimators that fit collections of time series."""
22

33
from aeon.base._base import BaseEstimator
4-
from aeon.utils.validation import check_n_jobs
5-
from aeon.utils.validation._dependencies import _check_estimator_deps
6-
from aeon.utils.validation.collection import (
4+
from aeon.utils.conversion import (
75
convert_collection,
6+
resolve_equal_length_inner_type,
7+
resolve_unequal_length_inner_type,
8+
)
9+
from aeon.utils.validation import check_n_jobs
10+
from aeon.utils.validation._check_collection import (
811
get_n_cases,
912
has_missing,
1013
is_equal_length,
1114
is_univariate,
12-
resolve_equal_length_inner_type,
13-
resolve_unequal_length_inner_type,
1415
)
16+
from aeon.utils.validation._dependencies import _check_estimator_deps
1517

1618

1719
class BaseCollectionEstimator(BaseEstimator):
@@ -51,8 +53,8 @@ def _preprocess_collection(self, X):
5153
5254
Parameters
5355
----------
54-
X : data structure
55-
See aeon.utils.validation.collection.COLLECTIONS_DATA_TYPES for details
56+
X : collection
57+
See aeon.utils.conversion.COLLECTIONS_DATA_TYPES for details
5658
on aeon supported data structures.
5759
5860
Returns
@@ -102,7 +104,7 @@ def _check_X(self, X):
102104
Parameters
103105
----------
104106
X : data structure
105-
See aeon.utils.validation.collection.COLLECTIONS_DATA_TYPES for details
107+
See aeon.utils.conversion.COLLECTIONS_DATA_TYPES for details
106108
on aeon supported data structures.
107109
108110
Returns
@@ -164,7 +166,7 @@ def _convert_X(self, X):
164166
Parameters
165167
----------
166168
X : data structure
167-
must be of type aeon.utils.validation.collection.COLLECTIONS_DATA_TYPES.
169+
must be of type aeon.utils.conversion.COLLECTIONS_DATA_TYPES.
168170
169171
Returns
170172
-------
@@ -179,7 +181,7 @@ def _convert_X(self, X):
179181
--------
180182
>>> from aeon.classification.hybrid import HIVECOTEV2
181183
>>> import numpy as np
182-
>>> from aeon.utils.validation.collection import get_type
184+
>>> from aeon.utils.validation import get_type
183185
>>> X = [np.random.random(size=(5,10)), np.random.random(size=(5,10))]
184186
>>> get_type(X)
185187
'np-list'

aeon/base/_base_series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _check_X(self, X):
8080
----------
8181
X : data structure
8282
A valid aeon collection data structure. See
83-
aeon.utils.validation.collection.COLLECTIONS_DATA_TYPES for details
83+
aeon.utils.conversion.COLLECTIONS_DATA_TYPES for details
8484
on aeon supported data structures.
8585
8686
Returns

aeon/base/tests/test_base_collection.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import pytest
33

44
from aeon.base import BaseCollectionEstimator
5-
from aeon.utils.validation.collection import COLLECTIONS_DATA_TYPES
6-
from aeon.utils.validation.tests.test_collection import (
5+
from aeon.testing.utils.data_gen._collection import (
76
EQUAL_LENGTH_UNIVARIATE,
87
UNEQUAL_LENGTH_UNIVARIATE,
9-
get_type,
108
)
9+
from aeon.utils.conversion import COLLECTIONS_DATA_TYPES
10+
from aeon.utils.validation import get_type
1111

1212

1313
@pytest.mark.parametrize("data", COLLECTIONS_DATA_TYPES)

aeon/classification/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ class name: BaseClassifier
3636
from aeon.base import BaseCollectionEstimator
3737
from aeon.base._base import _clone_estimator
3838
from aeon.utils.sklearn import is_sklearn_transformer
39+
from aeon.utils.validation._check_collection import get_n_cases
3940
from aeon.utils.validation._dependencies import _check_estimator_deps
40-
from aeon.utils.validation.collection import get_n_cases
4141

4242

4343
class BaseClassifier(BaseCollectionEstimator, ABC):

aeon/classification/compose/_pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from aeon.classification.base import BaseClassifier
1313
from aeon.transformations.base import BaseTransformer
1414
from aeon.transformations.compose import TransformerPipeline
15+
from aeon.utils.conversion import convert_collection
1516
from aeon.utils.sklearn import is_sklearn_classifier
16-
from aeon.utils.validation.collection import convert_collection
1717

1818

1919
class ClassifierPipeline(_HeterogenousMetaEstimator, BaseClassifier):

aeon/classification/tests/test_all_classifiers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from aeon.testing.test_all_estimators import BaseFixtureGenerator, QuickTester
1616
from aeon.testing.utils.estimator_checks import _assert_array_almost_equal
1717
from aeon.testing.utils.scenarios_classification import ClassifierFitPredict
18-
from aeon.utils.validation.collection import get_n_cases
18+
from aeon.utils.validation import get_n_cases
1919

2020

2121
class ClassifierFixtureGenerator(BaseFixtureGenerator):

aeon/classification/tests/test_base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
MockClassifierFullTags,
1111
MockClassifierPredictProba,
1212
)
13-
from aeon.utils.validation.collection import COLLECTIONS_DATA_TYPES
14-
from aeon.utils.validation.tests.test_collection import (
13+
from aeon.testing.utils.data_gen._collection import (
1514
EQUAL_LENGTH_MULTIVARIATE,
1615
EQUAL_LENGTH_UNIVARIATE,
1716
UNEQUAL_LENGTH_UNIVARIATE,
1817
)
18+
from aeon.utils.conversion._convert_collection import COLLECTIONS_DATA_TYPES
1919

2020
__author__ = ["mloning", "fkiraly", "TonyBagnall", "MatthewMiddlehurst", "achieveordie"]
2121

aeon/datasets/_data_loaders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
get_downloaded_tsf_datasets,
1717
)
1818
from aeon.datasets.tser_data_lists import tser_monash, tser_soton
19-
from aeon.utils.validation.collection import convert_collection
19+
from aeon.utils.conversion import convert_collection
2020

2121
DIRNAME = "data"
2222
MODULE = os.path.join(os.path.dirname(aeon.__file__), "datasets")

aeon/datasets/_dataframe_loaders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import aeon
2020
from aeon.datasets._data_loaders import load_from_arff_file, load_from_tsfile
21-
from aeon.utils.validation.collection import convert_collection
21+
from aeon.utils.conversion import convert_collection
2222

2323
DIRNAME = "data"
2424
MODULE = os.path.join(os.path.dirname(aeon.__file__), "datasets")

aeon/forecasting/base/tests/test_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
make_example_3d_numpy,
1818
make_series,
1919
)
20+
from aeon.utils.conversion import convert_collection
2021
from aeon.utils.index_functions import get_cutoff, get_window
2122
from aeon.utils.validation._dependencies import _check_soft_dependencies
22-
from aeon.utils.validation.collection import convert_collection
2323

2424
COLLECTION_TYPES = ["pd-multiindex", "nested_univ", "numpy3D"]
2525
HIER_TYPES = ["pd_multiindex_hier"]

aeon/regression/compose/_pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from aeon.regression.base import BaseRegressor
77
from aeon.transformations.base import BaseTransformer
88
from aeon.transformations.compose import TransformerPipeline
9+
from aeon.utils.conversion import convert_collection
910
from aeon.utils.sklearn import is_sklearn_regressor
10-
from aeon.utils.validation.collection import convert_collection
1111

1212
__author__ = ["fkiraly"]
1313
__all__ = ["RegressorPipeline", "SklearnRegressorPipeline"]

aeon/regression/tests/test_base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from aeon.datasets import load_covid_3month
88
from aeon.regression._dummy import DummyRegressor
99
from aeon.regression.base import BaseRegressor
10-
from aeon.utils.validation.collection import COLLECTIONS_DATA_TYPES
11-
from aeon.utils.validation.tests.test_collection import (
10+
from aeon.testing.utils.data_gen._collection import (
1211
EQUAL_LENGTH_UNIVARIATE,
1312
UNEQUAL_LENGTH_UNIVARIATE,
1413
)
14+
from aeon.utils.conversion import COLLECTIONS_DATA_TYPES
1515

1616

1717
class _TestRegressor(BaseRegressor):

aeon/testing/utils/data_gen/_collection.py

+69-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77
from sklearn.utils.validation import check_random_state
88

9-
from aeon.utils.validation.collection import convert_collection
9+
from aeon.utils.conversion import convert_collection
1010

1111

1212
def make_example_3d_numpy(
@@ -457,3 +457,71 @@ def _make_nested_from_array(array, n_instances=20, n_columns=1):
457457
[[pd.Series(array) for _ in range(n_columns)] for _ in range(n_instances)],
458458
columns=[f"col{c}" for c in range(n_columns)],
459459
)
460+
461+
462+
np_list = []
463+
for _ in range(10):
464+
np_list.append(np.random.random(size=(1, 20)))
465+
df_list = []
466+
for _ in range(10):
467+
df_list.append(pd.DataFrame(np.random.random(size=(20, 1))))
468+
nested, _ = make_example_nested_dataframe(n_cases=10)
469+
multiindex = make_example_multi_index_dataframe(
470+
n_instances=10, n_channels=1, n_timepoints=20
471+
)
472+
473+
EQUAL_LENGTH_UNIVARIATE = {
474+
"numpy3D": np.random.random(size=(10, 1, 20)),
475+
"np-list": np_list,
476+
"df-list": df_list,
477+
"numpy2D": np.zeros(shape=(10, 20)),
478+
"pd-wide": pd.DataFrame(np.zeros(shape=(10, 20))),
479+
"nested_univ": nested,
480+
"pd-multiindex": multiindex,
481+
}
482+
np_list_uneq = []
483+
for i in range(10):
484+
np_list_uneq.append(np.random.random(size=(1, 20 + i)))
485+
df_list_uneq = []
486+
for i in range(10):
487+
df_list_uneq.append(pd.DataFrame(np.random.random(size=(20 + i, 1))))
488+
489+
nested_univ_uneq = pd.DataFrame(dtype=float)
490+
instance_list = []
491+
for i in range(0, 10):
492+
instance_list.append(pd.Series(np.random.randn(20 + i)))
493+
nested_univ_uneq["channel0"] = instance_list
494+
495+
UNEQUAL_LENGTH_UNIVARIATE = {
496+
"np-list": np_list_uneq,
497+
"df-list": df_list_uneq,
498+
"nested_univ": nested_univ_uneq,
499+
}
500+
np_list_multi = []
501+
for _ in range(10):
502+
np_list_multi.append(np.random.random(size=(2, 20)))
503+
df_list_multi = []
504+
for _ in range(10):
505+
df_list_multi.append(pd.DataFrame(np.random.random(size=(20, 2))))
506+
multi = make_example_multi_index_dataframe(
507+
n_instances=10, n_channels=2, n_timepoints=20
508+
)
509+
510+
nested_univ_multi = pd.DataFrame(dtype=float)
511+
instance_list = []
512+
for _ in range(0, 10):
513+
instance_list.append(pd.Series(np.random.randn(20)))
514+
nested_univ_multi["channel0"] = instance_list
515+
instance_list = []
516+
for _ in range(0, 10):
517+
instance_list.append(pd.Series(np.random.randn(20)))
518+
nested_univ_multi["channel1"] = instance_list
519+
520+
521+
EQUAL_LENGTH_MULTIVARIATE = {
522+
"numpy3D": np.random.random(size=(10, 2, 20)),
523+
"np-list": np_list_multi,
524+
"df-list": df_list_multi,
525+
"nested_univ": nested_univ_multi,
526+
"pd-multiindex": multi,
527+
}

aeon/utils/conversion/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Conversion utilities."""
2+
3+
__all__ = [
4+
"equal_length",
5+
"resolve_equal_length_inner_type",
6+
"resolve_unequal_length_inner_type",
7+
"convert_collection",
8+
"COLLECTIONS_DATA_TYPES",
9+
]
10+
11+
from aeon.utils.conversion._convert_collection import (
12+
COLLECTIONS_DATA_TYPES,
13+
convert_collection,
14+
resolve_equal_length_inner_type,
15+
resolve_unequal_length_inner_type,
16+
)

0 commit comments

Comments
 (0)