-
Notifications
You must be signed in to change notification settings - Fork 208
[Doc] A fix to inconsistent ticks in Base module #2734
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
a78b713
835ce46
56704c2
c2d39fb
d4d4d3e
a18406b
2ce1b09
54fda16
b18877c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
"""Base class template for aeon estimators.""" | ||
"""Base class template for ``aeon`` estimators.""" | ||
|
||
__maintainer__ = ["MatthewMiddlehurst", "TonyBagnall"] | ||
__all__ = ["BaseAeonEstimator"] | ||
|
@@ -17,23 +17,24 @@ | |
|
||
class BaseAeonEstimator(BaseEstimator, ABC): | ||
""" | ||
Base class for defining estimators in aeon. | ||
Base class for defining estimators in ``aeon``. | ||
Contains the following methods: | ||
- reset estimator to post-init - reset(keep) | ||
- clone stimator (copy) - clone(random_state) | ||
- inspect tags (class method) - get_class_tags() | ||
- inspect tags (one tag, class) - get_class_tag(tag_name, tag_value_default, | ||
raise_error) | ||
- inspect tags (all) - get_tags() | ||
- inspect tags (one tag) - get_tag(tag_name, tag_value_default, raise_error) | ||
- setting dynamic tags - set_tags(**tag_dict) | ||
- get fitted parameters - get_fitted_params(deep) | ||
- reset estimator to post-init - ``reset(keep)`` | ||
- clone stimator (copy) - ``clone(random_state)`` | ||
- inspect tags (class method) - ``get_class_tags()`` | ||
- inspect tags (one tag, class) - ``get_class_tag(tag_name, tag_value_default | ||
, raise_error)`` | ||
- inspect tags (all) - ``get_tags()`` | ||
- inspect tags (one tag) - ``get_tag(tag_name, tag_value_default | ||
, raise_error)`` | ||
- setting dynamic tags - ``set_tags(**tag_dict)`` | ||
- get fitted parameters - ``get_fitted_params(deep)`` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use the previous indentation please |
||
All estimators have the attribute: | ||
- fitted state flag - is_fitted | ||
- fitted state flag - `is_fitted` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not code style? |
||
""" | ||
|
||
_tags = { | ||
|
@@ -59,28 +60,28 @@ def reset(self, keep=None): | |
""" | ||
Reset the object to a clean post-init state. | ||
After a ``self.reset()`` call, self is equal or similar in value to | ||
After a ``self.reset()`` call, `self` is equal or similar in value to | ||
``type(self)(**self.get_params(deep=False))``, assuming no other attributes | ||
were kept using ``keep``. | ||
were kept using `keep`. | ||
Detailed behaviour: | ||
removes any object attributes, except: | ||
hyper-parameters (arguments of ``__init__``) | ||
object attributes containing double-underscores, i.e., the string "__" | ||
object attributes containing double-underscores, i.e., the string `"__"` | ||
runs ``__init__`` with current values of hyperparameters (result of | ||
``get_params``) | ||
Not affected by the reset are: | ||
object attributes containing double-underscores | ||
class and object methods, class attributes | ||
any attributes specified in the ``keep`` argument | ||
any attributes specified in the `keep` argument | ||
Parameters | ||
---------- | ||
keep : None, str, or list of str, default=None | ||
If None, all attributes are removed except hyperparameters. | ||
If str, only the attribute with this name is kept. | ||
If list of str, only the attributes with these names are kept. | ||
If ``None``, all attributes are removed except hyperparameters. | ||
If ``str``, only the attribute with this name is kept. | ||
If ``list`` of ``str``, only the attributes with these names are kept. | ||
Returns | ||
------- | ||
|
@@ -125,15 +126,16 @@ def clone(self, random_state=None): | |
Obtain a clone of the object with the same hyperparameters. | ||
A clone is a different object without shared references, in post-init state. | ||
This function is equivalent to returning ``sklearn.clone`` of self. | ||
This function is equivalent to returning ``sklearn.clone`` of `self`. | ||
Equal in value to ``type(self)(**self.get_params(deep=False))``. | ||
Parameters | ||
---------- | ||
random_state : int, RandomState instance, or None, default=None | ||
Sets the random state of the clone. If None, the random state is not set. | ||
If int, random_state is the seed used by the random number generator. | ||
If RandomState instance, random_state is the random number generator. | ||
Sets the random state of the clone. If ``None``, the random state is not | ||
set. | ||
If ``int``, `random_state` is the seed used by the random number generator. | ||
If ``RandomState`` instance, `random_state` is the random number generator. | ||
Returns | ||
------- | ||
|
@@ -187,21 +189,21 @@ def get_class_tag( | |
tag_name : str | ||
Name of tag value. | ||
raise_error : bool, default=True | ||
Whether a ValueError is raised when the tag is not found. | ||
Whether a ``ValueError`` is raised when the tag is not found. | ||
tag_value_default : any type, default=None | ||
Default/fallback value if tag is not found and error is not raised. | ||
Returns | ||
------- | ||
tag_value | ||
Value of the ``tag_name`` tag in cls. | ||
If not found, returns an error if ``raise_error`` is True, otherwise it | ||
If not found, returns an error if ``raise_error`` is ``True``, otherwise it | ||
returns ``tag_value_default``. | ||
Raises | ||
------ | ||
ValueError | ||
if ``raise_error`` is True and ``tag_name`` is not in | ||
if ``raise_error`` is ``True`` and ``tag_name`` is not in | ||
``self.get_tags().keys()`` | ||
Examples | ||
|
@@ -247,15 +249,15 @@ def get_tag(self, tag_name, raise_error=True, tag_value_default=None): | |
tag_name : str | ||
Name of tag to be retrieved. | ||
raise_error : bool, default=True | ||
Whether a ValueError is raised when the tag is not found. | ||
Whether a ``ValueError`` is raised when the tag is not found. | ||
tag_value_default : any type, default=None | ||
Default/fallback value if tag is not found and error is not raised. | ||
Returns | ||
------- | ||
tag_value | ||
Value of the ``tag_name`` tag in self. | ||
If not found, returns an error if ``raise_error`` is True, otherwise it | ||
If not found, returns an error if ``raise_error`` is ``True``, otherwise it | ||
returns ``tag_value_default``. | ||
Raises | ||
|
@@ -292,7 +294,7 @@ def set_tags(self, **tag_dict): | |
Returns | ||
------- | ||
self : object | ||
Reference to self. | ||
Reference to `self`. | ||
""" | ||
tag_update = deepcopy(tag_dict) | ||
self._tags_dynamic.update(tag_update) | ||
|
@@ -307,7 +309,7 @@ def get_fitted_params(self, deep=True): | |
Parameters | ||
---------- | ||
deep : bool, default=True | ||
If True, will return the fitted parameters for this estimator and | ||
If ``True``, will return the fitted parameters for this estimator and | ||
contained subobjects that are estimators. | ||
Returns | ||
|
@@ -371,9 +373,10 @@ def _get_test_params(cls, parameter_set="default"): | |
Returns | ||
------- | ||
params : dict or list of dict, default = {} | ||
Parameters to create testing instances of the class. Each dict are | ||
parameters to construct an "interesting" test instance, i.e., | ||
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. | ||
Parameters to create testing instances of the class. Each ``dict`` are | ||
parameters to construct an `"interesting"` test instance, i.e., | ||
``MyClass(**params)`` or ``MyClass(**params[i])`` creates a valid test | ||
instance. | ||
""" | ||
# default parameters = empty dict | ||
return {} | ||
|
@@ -383,23 +386,23 @@ def _create_test_instance(cls, parameter_set="default", return_first=True): | |
""" | ||
Construct Estimator instance if possible. | ||
Calls the `_get_test_params` method and returns an instance or list of instances | ||
using the returned dict or list of dict. | ||
Calls the ``_get_test_params`` method and returns an instance or ``list`` | ||
of instances using the returned ``dict`` or list of ``dict``. | ||
Parameters | ||
---------- | ||
parameter_set : str, default="default" | ||
Name of the set of test parameters to return, for use in tests. If no | ||
special parameters are defined for a value, will return `"default"` set. | ||
return_first : bool, default=True | ||
If True, return the first instance of the list of instances. | ||
If False, return the list of instances. | ||
If ``True``, return the first instance of the list of instances. | ||
If ``False``, return the list of instances. | ||
Returns | ||
------- | ||
instance : BaseAeonEstimator or list of BaseAeonEstimator | ||
Instance of the class with default parameters. If return_first | ||
is False, returns list of instances. | ||
Instance of the class with default parameters. If `return_first` | ||
is ``False``, returns list of instances. | ||
""" | ||
params = cls._get_test_params(parameter_set=parameter_set) | ||
|
||
|
@@ -421,7 +424,7 @@ def __sklearn_is_fitted__(self): | |
return self.is_fitted | ||
|
||
def __sklearn_tags__(self): | ||
"""Return sklearn style tags for the estimator.""" | ||
"""Return ``sklearn`` style tags for the estimator.""" | ||
aeon_tags = self.get_tags() | ||
sklearn_tags = super().__sklearn_tags__() | ||
sklearn_tags.non_deterministic = aeon_tags.get("non_deterministic", False) | ||
|
@@ -433,13 +436,13 @@ def __sklearn_tags__(self): | |
return sklearn_tags | ||
|
||
def _validate_data(self, **kwargs): | ||
"""Sklearn data validation.""" | ||
"""``Sklearn`` data validation.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
raise NotImplementedError( | ||
"aeon estimators do not have a _validate_data method." | ||
) | ||
|
||
def get_metadata_routing(self): | ||
"""Sklearn metadata routing. | ||
"""``Sklearn`` metadata routing. | ||
Not supported by ``aeon`` estimators. | ||
""" | ||
|
@@ -449,7 +452,7 @@ def get_metadata_routing(self): | |
|
||
@classmethod | ||
def _get_default_requests(cls): | ||
"""Sklearn metadata request defaults.""" | ||
"""``Sklearn`` metadata request defaults.""" | ||
from sklearn.utils._metadata_requests import MetadataRequest | ||
|
||
return MetadataRequest(None) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious, do other packages i.e. scikit-learn do this for their package name?