diff --git a/docs/releases/unreleased.md b/docs/releases/unreleased.md
index 79e701b844..582cabbcb9 100644
--- a/docs/releases/unreleased.md
+++ b/docs/releases/unreleased.md
@@ -1 +1,6 @@
# Unreleased
+
+## base
+
+- The `base` module is now fully type-annotated. Some type hints have changed, but this does not affect the behaviour of the code. For instance, the regression target (`base.typing.RegTarget`) is now annotated as `float` instead of `numbers.Number`.
+- The `_tags` property and the `_more_tags` method of `base.Estimator` now both return a set of strings.
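To make both changes concrete, here is a minimal sketch (ours, not part of the patch) of an estimator written against the new hints; the tag string is a hypothetical example, River's own tag constants live in `river.base.tags`.

```python
from __future__ import annotations

from river import base


class RunningMeanRegressor(base.Regressor):
    """A toy regressor: `y` is now a plain float, and tags are a set of strings."""

    def __init__(self) -> None:
        self.mean = 0.0
        self.n = 0

    def learn_one(self, x: dict, y: base.typing.RegTarget) -> None:
        # RegTarget is now an alias for float, so this arithmetic type-checks.
        self.n += 1
        self.mean += (y - self.mean) / self.n

    def predict_one(self, x: dict) -> base.typing.RegTarget:
        return self.mean

    def _more_tags(self) -> set[str]:
        return {"positive input"}  # hypothetical tag, for illustration only


model = RunningMeanRegressor()
model.learn_one({"x": 1.0}, 3.0)
print(model.predict_one({"x": 2.0}))  # 3.0
print(model._tags)  # {'positive input'}, merged over the MRO via set union
```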
diff --git a/pyproject.toml b/pyproject.toml
index 346c94489a..88e1056af9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -174,7 +174,6 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
# Disable strict mode for all non-fully-typed modules
module = [
- "river.base.*",
"river.metrics.*",
"river.utils.*",
"river.stats.*",
diff --git a/river/anomaly/pad.py b/river/anomaly/pad.py
index 0ddd3a403f..d5126c5913 100644
--- a/river/anomaly/pad.py
+++ b/river/anomaly/pad.py
@@ -130,7 +130,7 @@ def learn_one(self, x: dict | None, y: base.typing.Target | float):
else:
self.predictive_model.learn_one(y=y, x=x)
else:
- self.predictive_model.learn_one(x=x, y=y)
+ self.predictive_model.learn_one(x=x, y=y) # type:ignore[attr-defined]
def score_one(self, x: dict, y: base.typing.Target):
# Return the predicted value of x from the predictive model, first by checking whether
@@ -138,7 +138,7 @@ def score_one(self, x: dict, y: base.typing.Target):
if isinstance(self.predictive_model, time_series.base.Forecaster):
y_pred = self.predictive_model.forecast(self.horizon)[0]
else:
- y_pred = self.predictive_model.predict_one(x)
+ y_pred = self.predictive_model.predict_one(x) # type:ignore[attr-defined]
# Calculate the squared error
squared_error = (y_pred - y) ** 2
diff --git a/river/bandit/evaluate.py b/river/bandit/evaluate.py
index 0079079c8e..caf4ea46dd 100644
--- a/river/bandit/evaluate.py
+++ b/river/bandit/evaluate.py
@@ -136,10 +136,10 @@ def evaluate(
if done[policy_idx]:
continue
- arm = policy_.pull(range(env_.action_space.n)) # type: ignore[attr-defined]
+ arm = policy_.pull(range(env_.action_space.n)) # type: ignore[attr-defined, arg-type]
observation, reward, terminated, truncated, info = env_.step(arm)
policy_.update(arm, reward)
- reward_stat_.update(reward)
+ reward_stat_.update(reward) # type: ignore[arg-type]
yield {
"episode": episode,
diff --git a/river/base/__init__.py b/river/base/__init__.py
index 0aaa521934..f54493087f 100644
--- a/river/base/__init__.py
+++ b/river/base/__init__.py
@@ -29,6 +29,7 @@
from .multi_output import MultiLabelClassifier, MultiTargetRegressor
from .regressor import MiniBatchRegressor, Regressor
from .transformer import (
+ BaseTransformer,
MiniBatchSupervisedTransformer,
MiniBatchTransformer,
SupervisedTransformer,
@@ -38,6 +39,7 @@
__all__ = [
"Base",
+ "BaseTransformer",
"BinaryDriftDetector",
"BinaryDriftAndWarningDetector",
"Classifier",
diff --git a/river/base/base.py b/river/base/base.py
index 16e4f829b3..0db98c8759 100644
--- a/river/base/base.py
+++ b/river/base/base.py
@@ -10,6 +10,8 @@
import types
import typing
+import typing_extensions
+
class Base:
"""Base class that is inherited by the majority of classes in River.
@@ -22,14 +24,14 @@ class Base:
"""
- def __str__(self):
+ def __str__(self) -> str:
return self.__class__.__name__
- def __repr__(self):
+ def __repr__(self) -> str:
return _repr_obj(obj=self)
@classmethod
- def _unit_test_params(cls):
+ def _unit_test_params(cls) -> collections.abc.Iterator[dict[str, typing.Any]]:
"""Instantiates an object with default arguments.
Most parameters of each object have a default value. However, this isn't always the case,
@@ -71,7 +73,9 @@ def _get_params(self) -> dict[str, typing.Any]:
return params
- def clone(self, new_params: dict | None = None, include_attributes=False):
+ def clone(
+ self, new_params: dict[str, typing.Any] | None = None, include_attributes: bool = False
+ ) -> typing_extensions.Self:
"""Return a fresh estimator with the same parameters.
The clone has the same parameters but has not been updated with any data.
@@ -167,7 +171,7 @@ def clone(self, new_params: dict | None = None, include_attributes=False):
"""
- def is_class_param(param):
+ def is_class_param(param: typing.Any) -> bool:
# See expand_param_grid to understand why this is necessary
return (
isinstance(param, tuple)
@@ -202,10 +206,10 @@ def is_class_param(param):
return clone
@property
- def _mutable_attributes(self) -> set:
+ def _mutable_attributes(self) -> set[str]:
return set()
- def mutate(self, new_attrs: dict):
+ def mutate(self, new_attrs: dict[str, typing.Any]) -> None:
"""Modify attributes.
This changes parameters inplace. Although you can change attributes yourself, this is the
@@ -296,8 +300,8 @@ def mutate(self, new_attrs: dict):
"""
- def _mutate(obj, new_attrs):
- def is_class_attr(name, attr):
+ def _mutate(obj: typing.Any, new_attrs: dict[str, typing.Any]) -> None:
+ def is_class_attr(name: str, attr: typing.Any) -> bool:
return hasattr(getattr(obj, name), "mutate") and isinstance(attr, dict)
for name, attr in new_attrs.items():
@@ -318,7 +322,7 @@ def is_class_attr(name, attr):
_mutate(obj=self, new_attrs=new_attrs)
@property
- def _is_stochastic(self):
+ def _is_stochastic(self) -> bool:
"""Indicates if the model contains an unset seed parameter.
The convention in River is to control randomness by exposing a seed parameter. This seed
@@ -329,14 +333,14 @@ def _is_stochastic(self):
"""
- def is_class_param(param):
+ def is_class_param(param: typing.Any) -> bool:
return (
isinstance(param, tuple)
and inspect.isclass(param[0])
and isinstance(param[1], dict)
)
- def find(params):
+ def find(params: dict[str, typing.Any]) -> bool:
if not isinstance(params, dict):
return False
for name, param in params.items():
@@ -354,7 +358,7 @@ def _raw_memory_usage(self) -> int:
import numpy as np
- buffer = collections.deque([self])
+ buffer: collections.deque[typing.Any] = collections.deque([self])
seen = set()
size = 0
while len(buffer) > 0:
@@ -369,7 +373,7 @@ def _raw_memory_usage(self) -> int:
buffer.extend([k for k in obj.keys()])
buffer.extend([v for v in obj.values()])
elif hasattr(obj, "__dict__"): # Save object contents
- contents: dict = vars(obj)
+ contents = vars(obj)
size += sys.getsizeof(contents)
buffer.extend([k for k in contents.keys()])
buffer.extend([v for v in contents.values()])
@@ -384,7 +388,7 @@ def _raw_memory_usage(self) -> int:
elif hasattr(obj, "__iter__") and not (
isinstance(obj, str) or isinstance(obj, bytes) or isinstance(obj, bytearray)
):
- buffer.extend([i for i in obj]) # type: ignore
+ buffer.extend([i for i in obj])
return size
@@ -396,7 +400,12 @@ def _memory_usage(self) -> str:
return utils.pretty.humanize_bytes(self._raw_memory_usage)
-def _log_method_calls(self, name, class_condition, method_condition):
+def _log_method_calls(
+ self: typing.Any,
+ name: str,
+ class_condition: typing.Callable[[typing.Any], bool],
+ method_condition: typing.Callable[[typing.Any], bool],
+) -> typing.Any:
method = object.__getattribute__(self, name)
if (
not name.startswith("_")
@@ -412,7 +421,7 @@ def _log_method_calls(self, name, class_condition, method_condition):
def log_method_calls(
class_condition: typing.Callable[[typing.Any], bool] | None = None,
method_condition: typing.Callable[[typing.Any], bool] | None = None,
-):
+) -> collections.abc.Iterator[None]:
"""A context manager to log method calls.
All method calls will be logged by default. This behavior can be overridden by passing filtering
@@ -477,7 +486,7 @@ def log_method_calls(
Base.__getattribute__ = old # type: ignore
-def _repr_obj(obj, show_modules: bool = False, depth: int = 0) -> str:
+def _repr_obj(obj: typing.Any, show_modules: bool = False, depth: int = 0) -> str:
"""Return a pretty representation of an object."""
rep = f"{obj.__class__.__name__} ("
@@ -487,7 +496,7 @@ def _repr_obj(obj, show_modules: bool = False, depth: int = 0) -> str:
params = {
name: getattr(obj, name)
- for name, param in inspect.signature(obj.__init__).parameters.items() # type: ignore
+ for name, param in inspect.signature(obj.__init__).parameters.items()
if not (
param.name == "args"
and param.kind == param.VAR_POSITIONAL
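As an aside on the `typing_extensions.Self` return type added to `clone` above: a type checker can now infer that a clone has the same concrete type as the original, which the patch's own test (`assert type(new) is type(obj)`) exercises at runtime. A minimal sketch (ours):

```python
from river import linear_model

model = linear_model.LinearRegression(l2=42)
clone = model.clone({"l2": 21})

# With `clone` annotated as returning `Self`, mypy infers `clone` to be a
# LinearRegression, so the parameter access below needs no cast.
assert type(clone) is type(model)
print(clone.l2, model.l2)  # 21 42
```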
diff --git a/river/base/classifier.py b/river/base/classifier.py
index 876bef4e13..01f6b8b9d4 100644
--- a/river/base/classifier.py
+++ b/river/base/classifier.py
@@ -2,6 +2,7 @@
import abc
import typing
+from typing import Any
from river import base
@@ -15,7 +16,7 @@ class Classifier(estimator.Estimator):
"""A classifier."""
@abc.abstractmethod
- def learn_one(self, x: dict, y: base.typing.ClfTarget) -> None:
+ def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.ClfTarget) -> None:
"""Update the model with a set of features `x` and a label `y`.
Parameters
@@ -27,7 +28,9 @@ def learn_one(self, x: dict, y: base.typing.ClfTarget) -> None:
"""
- def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]:
+ def predict_proba_one(
+ self, x: dict[base.typing.FeatureName, Any], **kwargs: Any
+ ) -> dict[base.typing.ClfTarget, float]:
"""Predict the probability of each label for a dictionary of features `x`.
Parameters
@@ -47,7 +50,9 @@ def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]:
# that a classifier does not support predict_proba_one.
raise NotImplementedError
- def predict_one(self, x: dict, **kwargs) -> base.typing.ClfTarget | None:
+ def predict_one(
+ self, x: dict[base.typing.FeatureName, Any], **kwargs: Any
+ ) -> base.typing.ClfTarget | None:
"""Predict the label of a set of features `x`.
Parameters
@@ -69,11 +74,11 @@ def predict_one(self, x: dict, **kwargs) -> base.typing.ClfTarget | None:
return None
@property
- def _multiclass(self):
+ def _multiclass(self) -> bool:
return False
@property
- def _supervised(self):
+ def _supervised(self) -> bool:
return True
diff --git a/river/base/clusterer.py b/river/base/clusterer.py
index 35a73a71ad..ea314a65d4 100644
--- a/river/base/clusterer.py
+++ b/river/base/clusterer.py
@@ -1,19 +1,20 @@
from __future__ import annotations
import abc
+from typing import Any
-from . import estimator
+from . import estimator, typing
class Clusterer(estimator.Estimator):
"""A clustering model."""
@property
- def _supervised(self):
+ def _supervised(self) -> bool:
return False
@abc.abstractmethod
- def learn_one(self, x: dict) -> None:
+ def learn_one(self, x: dict[typing.FeatureName, Any]) -> None:
"""Update the model with a set of features `x`.
Parameters
@@ -24,7 +25,7 @@ def learn_one(self, x: dict) -> None:
"""
@abc.abstractmethod
- def predict_one(self, x: dict) -> int:
+ def predict_one(self, x: dict[typing.FeatureName, Any]) -> int:
"""Predicts the cluster number for a set of features `x`.
Parameters
diff --git a/river/base/drift_detector.py b/river/base/drift_detector.py
index a350c40d4b..ab76036e64 100644
--- a/river/base/drift_detector.py
+++ b/river/base/drift_detector.py
@@ -20,7 +20,7 @@ class _BaseDriftDetector(base.Base):
"""
- def __init__(self):
+ def __init__(self) -> None:
self._drift_detected = False
def _reset(self) -> None:
@@ -40,16 +40,16 @@ class _BaseDriftAndWarningDetector(_BaseDriftDetector):
"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self._warning_detected = False
- def _reset(self):
+ def _reset(self) -> None:
super()._reset()
self._warning_detected = False
@property
- def warning_detected(self):
+ def warning_detected(self) -> bool:
"""Whether or not a drift is detected following the last update."""
return self._warning_detected
diff --git a/river/base/ensemble.py b/river/base/ensemble.py
index c88f5181ef..f0a4e4fdf0 100644
--- a/river/base/ensemble.py
+++ b/river/base/ensemble.py
@@ -8,7 +8,7 @@
from .wrapper import Wrapper
-class Ensemble(UserList):
+class Ensemble(UserList[Estimator]):
"""An ensemble is a model which is composed of a list of models.
Parameters
@@ -17,7 +17,7 @@ class Ensemble(UserList):
"""
- def __init__(self, models: Iterator[Estimator]):
+ def __init__(self, models: Iterator[Estimator]) -> None:
super().__init__(models)
if len(self) < self._min_number_of_models:
@@ -27,11 +27,11 @@ def __init__(self, models: Iterator[Estimator]):
)
@property
- def _min_number_of_models(self):
+ def _min_number_of_models(self) -> int:
return 2
@property
- def models(self):
+ def models(self) -> list[Estimator]:
return self.data
@@ -49,7 +49,7 @@ class WrapperEnsemble(Ensemble, Wrapper):
"""
- def __init__(self, model, n_models, seed):
+ def __init__(self, model: Estimator, n_models: int, seed: int | None) -> None:
super().__init__(model.clone() for _ in range(n_models))
self.model = model
self.n_models = n_models
@@ -57,5 +57,5 @@ def __init__(self, model, n_models, seed):
self._rng = Random(seed)
@property
- def _wrapped_model(self):
+ def _wrapped_model(self) -> Estimator:
return self.model
diff --git a/river/base/estimator.py b/river/base/estimator.py
index 630b76d640..77504c69f2 100644
--- a/river/base/estimator.py
+++ b/river/base/estimator.py
@@ -1,6 +1,11 @@
from __future__ import annotations
import abc
+from collections.abc import Iterator
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from river import compose
from . import base
@@ -9,7 +14,7 @@ class Estimator(base.Base, abc.ABC):
"""An estimator."""
@property
- def _supervised(self):
+ def _supervised(self) -> bool:
"""Indicates whether or not the estimator is supervised or not.
This is useful internally for determining if an estimator expects to be provided with a `y`
@@ -19,7 +24,7 @@ def _supervised(self):
"""
return True
- def __or__(self, other):
+ def __or__(self, other: Estimator | compose.Pipeline) -> compose.Pipeline:
"""Merge with another Transformer into a Pipeline."""
from river import compose
@@ -27,7 +32,7 @@ def __or__(self, other):
return other.__ror__(self)
return compose.Pipeline(self, other)
- def __ror__(self, other):
+ def __ror__(self, other: Estimator | compose.Pipeline) -> compose.Pipeline:
"""Merge with another Transformer into a Pipeline."""
from river import compose
@@ -35,7 +40,7 @@ def __ror__(self, other):
return other.__or__(self)
return compose.Pipeline(other, self)
- def _repr_html_(self):
+ def _repr_html_(self) -> str:
from xml.etree import ElementTree as ET
from river.base import viz
@@ -44,11 +49,11 @@ def _repr_html_(self):
div_str = ET.tostring(div, encoding="unicode")
return f"
{div_str}
"
- def _more_tags(self):
+ def _more_tags(self) -> set[str]:
return set()
@property
- def _tags(self) -> dict[str, bool]:
+ def _tags(self) -> set[str]:
"""Return the estimator's tags.
Tags can be used to specify what kind of inputs an estimator is able to process. For
@@ -64,14 +69,14 @@ def _tags(self) -> dict[str, bool]:
for parent in self.__class__.__mro__:
try:
- tags |= parent._more_tags(self) # type: ignore
+ tags |= parent._more_tags(self) # type: ignore[attr-defined]
except AttributeError:
pass
return tags
@classmethod
- def _unit_test_params(self):
+ def _unit_test_params(self) -> Iterator[dict[str, Any]]:
"""Indicates which parameters to use during unit testing.
Most estimators have a default value for each of their parameters. However, in some cases,
@@ -84,7 +89,7 @@ def _unit_test_params(self):
"""
yield {}
- def _unit_test_skips(self):
+ def _unit_test_skips(self) -> set[str]:
"""Indicates which checks to skip during unit testing.
Most estimators pass the full test suite. However, in some cases, some estimators might not
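For reference, the `__or__`/`__ror__` annotations above describe the usual pipeline-building idiom; a small sketch (ours) of what the return type now captures:

```python
from river import compose, linear_model, preprocessing

model = preprocessing.StandardScaler() | linear_model.LinearRegression()

# `Estimator.__or__` is now annotated as returning a compose.Pipeline,
# so the isinstance check below is also what a type checker infers.
assert isinstance(model, compose.Pipeline)
print(list(model.steps))  # ['StandardScaler', 'LinearRegression']
```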
diff --git a/river/base/multi_output.py b/river/base/multi_output.py
index 078ed1a362..68cf013afc 100644
--- a/river/base/multi_output.py
+++ b/river/base/multi_output.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import abc
+import typing
from .estimator import Estimator
from .typing import FeatureName, RegTarget
@@ -10,7 +11,7 @@ class MultiLabelClassifier(Estimator, abc.ABC):
"""Multi-label classifier."""
@abc.abstractmethod
- def learn_one(self, x: dict, y: dict[FeatureName, bool]) -> None:
+ def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, bool]) -> None:
"""Update the model with a set of features `x` and the labels `y`.
Parameters
@@ -22,7 +23,9 @@ def learn_one(self, x: dict, y: dict[FeatureName, bool]) -> None:
"""
- def predict_proba_one(self, x: dict, **kwargs) -> dict[FeatureName, dict[bool, float]]:
+ def predict_proba_one(
+ self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any
+ ) -> dict[FeatureName, dict[bool, float]]:
"""Predict the probability of each label appearing given dictionary of features `x`.
Parameters
@@ -39,7 +42,9 @@ def predict_proba_one(self, x: dict, **kwargs) -> dict[FeatureName, dict[bool, f
# In case the multi-label classifier does not support probabilities
raise NotImplementedError
- def predict_one(self, x: dict, **kwargs) -> dict[FeatureName, bool]:
+ def predict_one(
+ self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any
+ ) -> dict[FeatureName, bool]:
"""Predict the labels of a set of features `x`.
Parameters
@@ -68,7 +73,12 @@ class MultiTargetRegressor(Estimator, abc.ABC):
"""Multi-target regressor."""
@abc.abstractmethod
- def learn_one(self, x: dict, y: dict[FeatureName, RegTarget], **kwargs) -> None:
+ def learn_one(
+ self,
+ x: dict[FeatureName, typing.Any],
+ y: dict[FeatureName, RegTarget],
+ **kwargs: typing.Any,
+ ) -> None:
"""Fits to a set of features `x` and a real-valued target `y`.
Parameters
@@ -81,7 +91,7 @@ def learn_one(self, x: dict, y: dict[FeatureName, RegTarget], **kwargs) -> None:
"""
@abc.abstractmethod
- def predict_one(self, x: dict) -> dict[FeatureName, RegTarget]:
+ def predict_one(self, x: dict[FeatureName, typing.Any]) -> dict[FeatureName, RegTarget]:
"""Predict the outputs of features `x`.
Parameters
diff --git a/river/base/regressor.py b/river/base/regressor.py
index 09abacb2b6..88c8a351a2 100644
--- a/river/base/regressor.py
+++ b/river/base/regressor.py
@@ -2,6 +2,7 @@
import abc
import typing
+from typing import Any
from river import base
@@ -15,7 +16,7 @@ class Regressor(estimator.Estimator):
"""A regressor."""
@abc.abstractmethod
- def learn_one(self, x: dict, y: base.typing.RegTarget) -> None:
+ def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.RegTarget) -> None:
"""Fits to a set of features `x` and a real-valued target `y`.
Parameters
@@ -28,7 +29,7 @@ def learn_one(self, x: dict, y: base.typing.RegTarget) -> None:
"""
@abc.abstractmethod
- def predict_one(self, x: dict) -> base.typing.RegTarget:
+ def predict_one(self, x: dict[base.typing.FeatureName, Any]) -> base.typing.RegTarget:
"""Predict the output of features `x`.
Parameters
diff --git a/river/base/test_base.py b/river/base/test_base.py
index 412303b93b..82d818d4af 100644
--- a/river/base/test_base.py
+++ b/river/base/test_base.py
@@ -3,18 +3,19 @@
from river import compose, datasets, linear_model, optim, preprocessing, stats, time_series
-def test_clone_estimator():
+def test_clone_estimator() -> None:
obj = linear_model.LinearRegression(l2=42)
obj.learn_one({"x": 3}, 6)
new = obj.clone({"l2": 21})
+ assert type(new) is type(obj)
assert new.l2 == 21
assert obj.l2 == 42
assert new.weights == {}
assert new.weights != obj.weights
-def test_clone_include_attributes():
+def test_clone_include_attributes() -> None:
var = stats.Var()
var.update(1)
var.update(2)
@@ -25,7 +26,7 @@ def test_clone_include_attributes():
assert var.clone(include_attributes=True)._S == 2
-def test_clone_pipeline():
+def test_clone_pipeline() -> None:
obj = preprocessing.StandardScaler() | linear_model.LinearRegression(l2=42)
obj.learn_one({"x": 3}, 6)
@@ -37,7 +38,7 @@ def test_clone_pipeline():
assert new["LinearRegression"].weights != obj["LinearRegression"].weights
-def test_clone_idempotent():
+def test_clone_idempotent() -> None:
model = preprocessing.StandardScaler() | linear_model.LogisticRegression(
optimizer=optim.Adam(), l2=0.1
)
@@ -53,7 +54,7 @@ def test_clone_idempotent():
clone.learn_one(x, y)
-def test_memory_usage():
+def test_memory_usage() -> None:
model = preprocessing.StandardScaler() | linear_model.LogisticRegression()
# We can't test the exact value because it depends on the platform and the Python version
@@ -61,7 +62,7 @@ def test_memory_usage():
assert isinstance(model._memory_usage, str)
-def test_mutate():
+def test_mutate() -> None:
"""
>>> from river import datasets, linear_model, optim, preprocessing
@@ -114,13 +115,13 @@ def test_mutate():
"""
-def test_clone_positional_args():
+def test_clone_positional_args() -> None:
assert compose.Select(1, 2, 3).clone().keys == {1, 2, 3}
assert compose.Discard("a", "b", "c").clone().keys == {"a", "b", "c"}
assert compose.SelectType(float, int).clone().types == (float, int)
-def test_clone_nested_pipeline():
+def test_clone_nested_pipeline() -> None:
model = time_series.SNARIMAX(
p=2,
d=1,
diff --git a/river/base/transformer.py b/river/base/transformer.py
index 16f9aba276..b41aa78f32 100644
--- a/river/base/transformer.py
+++ b/river/base/transformer.py
@@ -2,40 +2,57 @@
import abc
import typing
+from typing import Any
from river import base
if typing.TYPE_CHECKING:
import pandas as pd
+ from river import compose
+
class BaseTransformer:
- def __add__(self, other):
+ def __add__(self, other: BaseTransformer) -> compose.TransformerUnion:
"""Fuses with another Transformer into a TransformerUnion."""
from river import compose
return compose.TransformerUnion(self, other)
- def __radd__(self, other):
+ def __radd__(self, other: BaseTransformer) -> compose.TransformerUnion:
"""Fuses with another Transformer into a TransformerUnion."""
from river import compose
return compose.TransformerUnion(other, self)
- def __mul__(self, other):
+ def __mul__(
+ self,
+ other: BaseTransformer
+ | compose.Pipeline
+ | base.typing.FeatureName
+ | list[base.typing.FeatureName],
+ ) -> compose.Grouper | compose.TransformerProduct:
from river import compose
- if isinstance(other, Transformer) or isinstance(other, compose.Pipeline):
+ if isinstance(other, BaseTransformer) or isinstance(other, compose.Pipeline):
return compose.TransformerProduct(self, other)
return compose.Grouper(transformer=self, by=other)
- def __rmul__(self, other):
+ def __rmul__(
+ self,
+ other: BaseTransformer
+ | compose.Pipeline
+ | base.typing.FeatureName
+ | list[base.typing.FeatureName],
+ ) -> compose.Grouper | compose.TransformerProduct:
"""Creates a Grouper."""
return self * other
@abc.abstractmethod
- def transform_one(self, x: dict) -> dict:
+ def transform_one(
+ self, x: dict[base.typing.FeatureName, Any]
+ ) -> dict[base.typing.FeatureName, Any]:
"""Transform a set of features `x`.
Parameters
@@ -54,10 +71,10 @@ class Transformer(base.Estimator, BaseTransformer):
"""A transformer."""
@property
- def _supervised(self):
+ def _supervised(self) -> bool:
return False
- def learn_one(self, x: dict) -> None:
+ def learn_one(self, x: dict[base.typing.FeatureName, Any]) -> None:
"""Update with a set of features `x`.
A lot of transformers don't actually have to do anything during the `learn_one` step
@@ -78,10 +95,10 @@ class SupervisedTransformer(base.Estimator, BaseTransformer):
"""A supervised transformer."""
@property
- def _supervised(self):
+ def _supervised(self) -> bool:
return True
- def learn_one(self, x: dict, y: base.typing.Target) -> None:
+ def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.Target) -> None:
"""Update with a set of features `x` and a target `y`.
Parameters
@@ -134,7 +151,7 @@ class MiniBatchSupervisedTransformer(Transformer):
"""A supervised transformer that can operate on mini-batches."""
@property
- def _supervised(self):
+ def _supervised(self) -> bool:
return True
@abc.abstractmethod
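The operator overloads annotated above are easiest to read through an example; a quick sketch (ours) of the two return types that `__add__` and `__mul__` now declare:

```python
from river import compose, preprocessing

# `+` fuses two transformers into a TransformerUnion...
union = preprocessing.StandardScaler() + preprocessing.MinMaxScaler()
assert isinstance(union, compose.TransformerUnion)

# ...while `*` with a feature name wraps the transformer in a Grouper,
# and `*` with another transformer would yield a TransformerProduct.
grouped = preprocessing.StandardScaler() * "country"
assert isinstance(grouped, compose.Grouper)
```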
diff --git a/river/base/typing.py b/river/base/typing.py
index 762526b5ff..05d3ec9f84 100644
--- a/river/base/typing.py
+++ b/river/base/typing.py
@@ -1,11 +1,10 @@
from __future__ import annotations
-import numbers
import typing
FeatureName = typing.Hashable
-RegTarget = numbers.Number
+RegTarget = float
ClfTarget = typing.Union[bool, str, int] # noqa: UP007
Target = typing.Union[ClfTarget, RegTarget] # noqa: UP007
-Dataset = typing.Iterable[typing.Tuple[dict, typing.Any]] # noqa: UP006
-Stream = typing.Iterator[typing.Tuple[dict, typing.Any]] # noqa: UP006
+Dataset = typing.Iterable[typing.Tuple[dict[FeatureName, typing.Any], typing.Any]] # noqa: UP006
+Stream = typing.Iterator[typing.Tuple[dict[FeatureName, typing.Any], typing.Any]] # noqa: UP006
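Since `Dataset` and `Stream` now spell out the feature-dict type, a function annotated with them reads as follows; this helper is our illustration, not River API:

```python
from __future__ import annotations

from river import base


def head(dataset: base.typing.Dataset, n: int) -> list[tuple[dict, float]]:
    """Collect the first `n` (features, target) pairs of a dataset."""
    pairs = []
    for i, (x, y) in enumerate(dataset):
        if i >= n:
            break
        pairs.append((x, y))
    return pairs


# Any iterable of (dict, target) pairs satisfies the alias, e.g. a plain list.
print(head([({"x": 1.0}, 2.0), ({"x": 2.0}, 4.0)], 1))
```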
diff --git a/river/base/viz.py b/river/base/viz.py
index b99c272ff2..8c7201cca3 100644
--- a/river/base/viz.py
+++ b/river/base/viz.py
@@ -4,10 +4,11 @@
import textwrap
from xml.etree import ElementTree as ET
+# This import is not cyclic because 'viz' is not exported by 'base'
+from river import base, compose
-def to_html(obj) -> ET.Element:
- from river import base, compose
+def to_html(obj: base.Estimator) -> ET.Element:
if isinstance(obj, compose.Pipeline):
return pipeline_to_html(obj)
if isinstance(obj, compose.TransformerUnion):
@@ -17,9 +18,7 @@ def to_html(obj) -> ET.Element:
return estimator_to_html(obj)
-def estimator_to_html(estimator) -> ET.Element:
- from river import compose
-
+def estimator_to_html(estimator: base.Estimator) -> ET.Element:
details = ET.Element("details", attrib={"class": "river-component river-estimator"})
summary = ET.Element("summary", attrib={"class": "river-summary"})
@@ -45,7 +44,7 @@ def estimator_to_html(estimator) -> ET.Element:
return details
-def pipeline_to_html(pipeline) -> ET.Element:
+def pipeline_to_html(pipeline: compose.Pipeline) -> ET.Element:
div = ET.Element("div", attrib={"class": "river-component river-pipeline"})
for step in pipeline.steps.values():
@@ -54,7 +53,7 @@ def pipeline_to_html(pipeline) -> ET.Element:
return div
-def union_to_html(union) -> ET.Element:
+def union_to_html(union: compose.TransformerUnion) -> ET.Element:
div = ET.Element("div", attrib={"class": "river-component river-union"})
for transformer in union.transformers.values():
@@ -63,7 +62,7 @@ def union_to_html(union) -> ET.Element:
return div
-def wrapper_to_html(wrapper) -> ET.Element:
+def wrapper_to_html(wrapper: base.Wrapper) -> ET.Element:
div = ET.Element("div", attrib={"class": "river-component river-wrapper"})
details = ET.Element("details", attrib={"class": "river-details"})
diff --git a/river/base/wrapper.py b/river/base/wrapper.py
index d59bf1e232..b1f484d710 100644
--- a/river/base/wrapper.py
+++ b/river/base/wrapper.py
@@ -2,30 +2,32 @@
from abc import ABC, abstractmethod
+from river import base
+
class Wrapper(ABC):
"""A wrapper model."""
@property
@abstractmethod
- def _wrapped_model(self):
+ def _wrapped_model(self) -> base.Estimator:
"""Provides access to the wrapped model."""
@property
- def _labelloc(self):
+ def _labelloc(self) -> str:
"""Indicates location of the wrapper name when drawing pipelines."""
return "t" # for top
- def __str__(self):
+ def __str__(self) -> str:
return f"{type(self).__name__}({self._wrapped_model})"
- def _more_tags(self):
+ def _more_tags(self) -> set[str]:
return self._wrapped_model._tags
@property
- def _supervised(self):
+ def _supervised(self) -> bool:
return self._wrapped_model._supervised
@property
- def _multiclass(self):
- return self._wrapped_model._multiclass
+ def _multiclass(self) -> bool:
+ return isinstance(self._wrapped_model, base.Classifier) and self._wrapped_model._multiclass
diff --git a/river/compose/grouper.py b/river/compose/grouper.py
index 7d19f50ebf..b2bb853f94 100644
--- a/river/compose/grouper.py
+++ b/river/compose/grouper.py
@@ -28,7 +28,7 @@ class Grouper(base.Transformer):
def __init__(
self,
- transformer: base.Transformer,
+ transformer: base.BaseTransformer,
by: base.typing.FeatureName | list[base.typing.FeatureName],
):
self.transformer = transformer
diff --git a/river/compose/pipeline.py b/river/compose/pipeline.py
index 2c894ede04..be80adfe5a 100644
--- a/river/compose/pipeline.py
+++ b/river/compose/pipeline.py
@@ -274,8 +274,8 @@ class Pipeline(base.Estimator):
_LEARN_UNSUPERVISED_DURING_PREDICT = False
- def __init__(self, *steps):
- self.steps = collections.OrderedDict()
+ def __init__(self, *steps) -> None:
+ self.steps: collections.OrderedDict = collections.OrderedDict()
for step in steps:
self |= step
@@ -289,12 +289,12 @@ def __len__(self):
"""Just for convenience."""
return len(self.steps)
- def __or__(self, other):
+ def __or__(self, other) -> Pipeline:
"""Insert a step at the end of the pipeline."""
self._add_step(other, at_start=False)
return self
- def __ror__(self, other):
+ def __ror__(self, other) -> Pipeline:
"""Insert a step at the start of the pipeline."""
self._add_step(other, at_start=True)
return self
diff --git a/river/compose/select.py b/river/compose/select.py
index 088bfe714e..cbfa28f04c 100644
--- a/river/compose/select.py
+++ b/river/compose/select.py
@@ -42,7 +42,7 @@ class Discard(base.Transformer):
"""
- def __init__(self, *keys: tuple[base.typing.FeatureName]):
+ def __init__(self, *keys: base.typing.FeatureName):
self.keys = set(keys)
def transform_one(self, x):
@@ -124,7 +124,7 @@ class Select(base.MiniBatchTransformer):
"""
- def __init__(self, *keys: tuple[base.typing.FeatureName]):
+ def __init__(self, *keys: base.typing.FeatureName):
self.keys = set(keys)
def transform_one(self, x):
@@ -173,7 +173,7 @@ class SelectType(base.Transformer):
"""
- def __init__(self, *types: tuple[type]):
+ def __init__(self, *types: type):
self.types = types
def transform_one(self, x):
diff --git a/river/compose/union.py b/river/compose/union.py
index 0210745a1e..b7b600125b 100644
--- a/river/compose/union.py
+++ b/river/compose/union.py
@@ -156,8 +156,8 @@ class TransformerUnion(base.MiniBatchTransformer):
"""
- def __init__(self, *transformers):
- self.transformers = {}
+ def __init__(self, *transformers) -> None:
+ self.transformers: dict = {}
for transformer in transformers:
if transformer.__class__ == self.__class__:
for t in transformer:
diff --git a/river/datasets/phishing.py b/river/datasets/phishing.py
index d76131edcc..ef84d4e486 100644
--- a/river/datasets/phishing.py
+++ b/river/datasets/phishing.py
@@ -16,7 +16,7 @@ class Phishing(base.FileDataset):
"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(
n_samples=1_250,
n_features=9,
diff --git a/river/ensemble/streaming_random_patches.py b/river/ensemble/streaming_random_patches.py
index 415b30be10..21d94e77e3 100644
--- a/river/ensemble/streaming_random_patches.py
+++ b/river/ensemble/streaming_random_patches.py
@@ -93,11 +93,11 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
for model in self:
# Get prediction for instance
- y_pred = model.predict_one(x)
+ y_pred = model.predict_one(x) # type:ignore[attr-defined]
# Update performance evaluator
if y_pred is not None:
- model.metric.update(y_true=y, y_pred=y_pred)
+ model.metric.update(y_true=y, y_pred=y_pred) # type: ignore[attr-defined] # BaseSRPEstimator has a metric field
# Train using random subspaces without resampling,
# i.e. all instances are used for training.
@@ -109,7 +109,7 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
k = poisson(rate=self.lam, rng=self._rng)
if k == 0:
continue
- model.learn_one(x=x, y=y, w=k, n_samples_seen=self._n_samples_seen)
+ model.learn_one(x=x, y=y, w=k, n_samples_seen=self._n_samples_seen) # type:ignore[attr-defined]
def _generate_subspaces(self, features: list):
n_features = len(features)
@@ -543,7 +543,7 @@ def learn_one(
# TODO Find a way to verify if the model natively supports sample_weight (w)
for _ in range(int(w)):
- self.model.learn_one(x=x_subset, y=y, **kwargs)
+ self.model.learn_one(x=x_subset, y=y, **kwargs) # type:ignore[attr-defined]
if self._background_learner:
# Train the background learner
@@ -557,7 +557,7 @@ def learn_one(
)
if not self.disable_drift_detector and not self.is_background_learner:
- correctly_classifies = self.model.predict_one(x_subset) == y
+ correctly_classifies = self.model.predict_one(x_subset) == y # type:ignore[attr-defined]
# Check for warnings only if the background learner is active
if not self.disable_background_learner:
# Update the warning detection method
@@ -845,10 +845,10 @@ def learn_one(
# TODO Find a way to verify if the model natively supports sample_weight (w)
for _ in range(int(w)):
- self.model.learn_one(x=x_subset, y=y, **kwargs)
+ self.model.learn_one(x=x_subset, y=y, **kwargs) # type:ignore[attr-defined]
# Drift detection input
- y_pred = self.model.predict_one(x_subset)
+ y_pred = self.model.predict_one(x_subset) # type:ignore[attr-defined]
if self.drift_detection_criteria == "error":
# Track absolute error
drift_detector_input = abs(y_pred - y)
diff --git a/river/forest/adaptive_random_forest.py b/river/forest/adaptive_random_forest.py
index 791e614d88..9ed1df53dd 100644
--- a/river/forest/adaptive_random_forest.py
+++ b/river/forest/adaptive_random_forest.py
@@ -155,13 +155,13 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
self._init_ensemble(sorted(x.keys()))
for i, model in enumerate(self):
- y_pred = model.predict_one(x)
+ y_pred = model.predict_one(x) # type:ignore[attr-defined]
# Update performance evaluator
self._metrics[i].update(
- y_true=y,
+ y_true=y, # type:ignore[arg-type]
y_pred=(
- model.predict_proba_one(x)
+ model.predict_proba_one(x) # type:ignore[attr-defined]
if isinstance(self.metric, metrics.base.ClassificationMetric)
and not self.metric.requires_labels
else y_pred
@@ -173,7 +173,7 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
if not self._warning_detection_disabled and self._background[i] is not None:
self._background[i].learn_one(x=x, y=y, w=k) # type: ignore
- model.learn_one(x=x, y=y, w=k)
+ model.learn_one(x=x, y=y, w=k) # type:ignore[attr-defined]
drift_input = None
if not self._warning_detection_disabled:
@@ -198,7 +198,7 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
if self._drift_detectors[i].drift_detected:
if not self._warning_detection_disabled and self._background[i] is not None:
- self.data[i] = self._background[i]
+ self.data[i] = self._background[i] # type:ignore[assignment]
self._background[i] = None
self._warning_detectors[i] = self.warning_detector.clone()
self._drift_detectors[i] = self.drift_detector.clone()
@@ -663,7 +663,9 @@ def _mutable_attributes(self):
def _multiclass(self):
return True
- def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]:
+ def predict_proba_one(
+ self, x: dict, **kwargs: typing.Any
+ ) -> dict[base.typing.ClfTarget, float]:
y_pred: typing.Counter = collections.Counter()
if len(self) == 0:
@@ -671,7 +673,7 @@ def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]:
return y_pred # type: ignore
for i, model in enumerate(self):
- y_proba_temp = model.predict_proba_one(x)
+ y_proba_temp = model.predict_proba_one(x) # type:ignore[attr-defined]
metric_value = self._metrics[i].get()
if not self.disable_weighted_vote and metric_value > 0.0:
y_proba_temp = {k: val * metric_value for k, val in y_proba_temp.items()}
@@ -952,7 +954,7 @@ def predict_one(self, x: dict) -> base.typing.RegTarget:
weights = np.zeros(self.n_models)
sum_weights = 0.0
for i, model in enumerate(self):
- y_pred[i] = model.predict_one(x)
+ y_pred[i] = model.predict_one(x) # type:ignore[attr-defined]
weights[i] = self._metrics[i].get()
sum_weights += weights[i]
@@ -964,7 +966,7 @@ def predict_one(self, x: dict) -> base.typing.RegTarget:
y_pred *= weights
else:
for i, model in enumerate(self):
- y_pred[i] = model.predict_one(x)
+ y_pred[i] = model.predict_one(x) # type:ignore[attr-defined]
if self.aggregation_method == self._MEAN:
y_pred = y_pred.mean()
diff --git a/river/forest/aggregated_mondrian_forest.py b/river/forest/aggregated_mondrian_forest.py
index 12601d0cd5..2e79b0667c 100644
--- a/river/forest/aggregated_mondrian_forest.py
+++ b/river/forest/aggregated_mondrian_forest.py
@@ -170,7 +170,7 @@ def __init__(
self._classes: set[base.typing.ClfTarget] = set()
def _initialize_trees(self) -> None:
- self.data: list[MondrianTreeClassifier] = []
+ self.data: list[MondrianTreeClassifier] = [] # type:ignore[assignment]
for _ in range(self.n_estimators):
tree = MondrianTreeClassifier(
self.step,
@@ -290,7 +290,7 @@ def __init__(
def _initialize_trees(self) -> None:
"""Initialize the forest."""
- self.data: list[MondrianTreeRegressor] = []
+ self.data: list[MondrianTreeRegressor] = [] # type:ignore[assignment]
for _ in range(self.n_estimators):
# We don't want to have the same stochastic scheme for each tree, or it'll break the randomness
# Hence we introduce a new seed for each, that is derived of the given seed by a deterministic process
diff --git a/river/forest/online_extra_trees.py b/river/forest/online_extra_trees.py
index ee361007eb..7c8808d401 100644
--- a/river/forest/online_extra_trees.py
+++ b/river/forest/online_extra_trees.py
@@ -720,7 +720,7 @@ def predict_one(self, x: dict) -> base.typing.RegTarget:
weights = []
for perf, model in zip(self._perfs, self.models):
- preds.append(model.predict_one(x))
+ preds.append(model.predict_one(x)) # type:ignore[attr-defined]
weights.append(perf.get())
sum_weights = sum(weights)
@@ -733,6 +733,6 @@ def predict_one(self, x: dict) -> base.typing.RegTarget:
preds = [(w / sum_weights) * pred for w, pred in zip(weights, preds)]
return sum(preds)
else:
- preds = [model.predict_one(x) for model in self.models]
+ preds = [model.predict_one(x) for model in self.models] # type:ignore[attr-defined]
return sum(preds) / len(preds)
diff --git a/river/metrics/base.py b/river/metrics/base.py
index 788de42d3f..7110c66844 100644
--- a/river/metrics/base.py
+++ b/river/metrics/base.py
@@ -2,7 +2,6 @@
import abc
import collections
-import numbers
import operator
from river import base, stats, utils
@@ -190,11 +189,11 @@ class RegressionMetric(Metric):
_fmt = ",.6f" # use commas to separate big numbers and show 6 decimals
@abc.abstractmethod
- def update(self, y_true: numbers.Number, y_pred: numbers.Number) -> None:
+ def update(self, y_true: float, y_pred: float) -> None:
"""Update the metric."""
@abc.abstractmethod
- def revert(self, y_true: numbers.Number, y_pred: numbers.Number) -> None:
+ def revert(self, y_true: float, y_pred: float) -> None:
"""Revert the metric."""
@property
diff --git a/river/multioutput/chain.py b/river/multioutput/chain.py
index e4516fd708..091a2bef3c 100644
--- a/river/multioutput/chain.py
+++ b/river/multioutput/chain.py
@@ -34,7 +34,7 @@ def __getitem__(self, key):
return self[key]
-class ClassifierChain(BaseChain, base.MultiLabelClassifier):
+class ClassifierChain(BaseChain, base.MultiLabelClassifier): # type:ignore[misc]
"""A multi-output model that arranges classifiers into a chain.
This will create one model per output. The prediction of the first output will be used as a
@@ -165,7 +165,7 @@ def predict_proba_one(self, x, **kwargs):
return y_pred
-class RegressorChain(BaseChain, base.MultiTargetRegressor):
+class RegressorChain(BaseChain, base.MultiTargetRegressor): # type:ignore[misc]
"""A multi-output model that arranges regressors into a chain.
This will create one model per output. The prediction of the first output will be used as a
diff --git a/river/optim/adam.py b/river/optim/adam.py
index f78afe05ec..748445e119 100644
--- a/river/optim/adam.py
+++ b/river/optim/adam.py
@@ -51,7 +51,7 @@ class Adam(optim.base.Optimizer):
"""
- def __init__(self, lr=0.1, beta_1=0.9, beta_2=0.999, eps=1e-8):
+ def __init__(self, lr=0.1, beta_1=0.9, beta_2=0.999, eps=1e-8) -> None:
super().__init__(lr)
self.beta_1 = beta_1
self.beta_2 = beta_2
diff --git a/river/optim/sgd.py b/river/optim/sgd.py
index 0e05e8b98a..6d8e9e7321 100644
--- a/river/optim/sgd.py
+++ b/river/optim/sgd.py
@@ -39,7 +39,7 @@ class SGD(optim.base.Optimizer):
"""
- def __init__(self, lr=0.01):
+ def __init__(self, lr=0.01) -> None:
super().__init__(lr)
def _step_with_dict(self, w, g):
diff --git a/river/preprocessing/scale.py b/river/preprocessing/scale.py
index eab0a68966..654075e15b 100644
--- a/river/preprocessing/scale.py
+++ b/river/preprocessing/scale.py
@@ -152,11 +152,11 @@ class StandardScaler(base.MiniBatchTransformer):
"""
- def __init__(self, with_std=True):
+ def __init__(self, with_std=True) -> None:
self.with_std = with_std
- self.counts = collections.Counter()
- self.means = collections.defaultdict(float)
- self.vars = collections.defaultdict(float)
+ self.counts: collections.Counter = collections.Counter()
+ self.means: collections.defaultdict = collections.defaultdict(float)
+ self.vars: collections.defaultdict = collections.defaultdict(float)
def learn_one(self, x):
for i, xi in x.items():
diff --git a/river/rules/amrules.py b/river/rules/amrules.py
index 245f71c1f4..602dc81592 100644
--- a/river/rules/amrules.py
+++ b/river/rules/amrules.py
@@ -334,6 +334,7 @@ def n_drifts_detected(self) -> int:
return self._n_drifts_detected
def _new_rule(self) -> RegRule:
+ predictor: base.Regressor
if self.pred_type == self._PRED_MEAN:
predictor = MeanRegressor()
elif self.pred_type == self._PRED_MODEL:
diff --git a/river/stats/var.py b/river/stats/var.py
index 86fa412061..cfdc2b6fbf 100644
--- a/river/stats/var.py
+++ b/river/stats/var.py
@@ -70,7 +70,7 @@ class Var(stats.base.Univariate):
"""
- def __init__(self, ddof=1):
+ def __init__(self, ddof=1) -> None:
self.ddof = ddof
self.mean = stats.Mean()
self._S = 0
@@ -79,7 +79,7 @@ def __init__(self, ddof=1):
def n(self):
return self.mean.n
- def update(self, x, w=1.0):
+ def update(self, x, w=1.0) -> None:
mean_old = self.mean.get()
self.mean.update(x, w)
mean_new = self.mean.get()
diff --git a/river/stream/iter_sql.py b/river/stream/iter_sql.py
index 13462315e0..453cd0c33e 100644
--- a/river/stream/iter_sql.py
+++ b/river/stream/iter_sql.py
@@ -102,4 +102,4 @@ def iter_sql(
for row in result_proxy:
x = dict(row._mapping.items())
y = x.pop(target_name)
- yield x, y
+ yield x, y # type: ignore[misc]
diff --git a/river/time_series/snarimax.py b/river/time_series/snarimax.py
index 150f15c2bb..144d34fb63 100644
--- a/river/time_series/snarimax.py
+++ b/river/time_series/snarimax.py
@@ -4,7 +4,7 @@
import itertools
import math
-from river import base, linear_model, preprocessing, time_series
+from river import base, compose, linear_model, preprocessing, time_series
__all__ = ["SNARIMAX"]
@@ -280,7 +280,7 @@ def __init__(
sp: int = 0,
sd: int = 0,
sq: int = 0,
- regressor: base.Regressor | None = None,
+ regressor: base.Regressor | compose.Pipeline | None = None,
):
self.p = p
self.d = d
diff --git a/river/tree/stochastic_gradient_tree.py b/river/tree/stochastic_gradient_tree.py
index 8ea1d76371..01dfb3dcce 100644
--- a/river/tree/stochastic_gradient_tree.py
+++ b/river/tree/stochastic_gradient_tree.py
@@ -2,6 +2,7 @@
import abc
import sys
+from typing import Any
from scipy.stats import f as f_dist
@@ -291,7 +292,7 @@ def __init__(
def _target_transform(self, y):
return float(y)
- def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]:
+ def predict_proba_one(self, x: dict, **kwargs: Any) -> dict[base.typing.ClfTarget, float]:
if isinstance(self._root, DTBranch):
leaf = self._root.traverse(x, until_leaf=True)
else:
diff --git a/river/utils/pretty.py b/river/utils/pretty.py
index df9d7c926a..8b53d0a83b 100644
--- a/river/utils/pretty.py
+++ b/river/utils/pretty.py
@@ -56,7 +56,7 @@ def print_table(
return table
-def humanize_bytes(n_bytes: int):
+def humanize_bytes(n_bytes: int) -> str:
"""Returns a human-friendly byte size.
Parameters