Skip to content

Commit 846abaa

Browse files
committed
Formatting
1 parent 31176a1 commit 846abaa

File tree

10 files changed

+59
-22
lines changed

10 files changed

+59
-22
lines changed

river/bandit/evaluate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def evaluate(
139139
arm = policy_.pull(range(env_.action_space.n)) # type: ignore[attr-defined, arg-type]
140140
observation, reward, terminated, truncated, info = env_.step(arm)
141141
policy_.update(arm, reward)
142-
reward_stat_.update(reward) # type: ignore[arg-type]
142+
reward_stat_.update(reward) # type: ignore[arg-type]
143143

144144
yield {
145145
"episode": episode,

river/base/base.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def _get_params(self) -> dict[str, typing.Any]:
7373

7474
return params
7575

76-
def clone(self, new_params: dict[str, typing.Any] | None = None, include_attributes: bool = False) -> typing_extensions.Self:
76+
def clone(
77+
self, new_params: dict[str, typing.Any] | None = None, include_attributes: bool = False
78+
) -> typing_extensions.Self:
7779
"""Return a fresh estimator with the same parameters.
7880
7981
The clone has the same parameters but has not been updated with any data.
@@ -371,7 +373,7 @@ def _raw_memory_usage(self) -> int:
371373
buffer.extend([k for k in obj.keys()])
372374
buffer.extend([v for v in obj.values()])
373375
elif hasattr(obj, "__dict__"): # Save object contents
374-
contents= vars(obj)
376+
contents = vars(obj)
375377
size += sys.getsizeof(contents)
376378
buffer.extend([k for k in contents.keys()])
377379
buffer.extend([v for v in contents.values()])
@@ -398,7 +400,12 @@ def _memory_usage(self) -> str:
398400
return utils.pretty.humanize_bytes(self._raw_memory_usage)
399401

400402

401-
def _log_method_calls(self: typing.Any, name: str, class_condition: typing.Callable[[typing.Any], bool], method_condition: typing.Callable[[typing.Any], bool]) -> typing.Any:
403+
def _log_method_calls(
404+
self: typing.Any,
405+
name: str,
406+
class_condition: typing.Callable[[typing.Any], bool],
407+
method_condition: typing.Callable[[typing.Any], bool],
408+
) -> typing.Any:
402409
method = object.__getattribute__(self, name)
403410
if (
404411
not name.startswith("_")

river/base/classifier.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.ClfTar
2828
2929
"""
3030

31-
def predict_proba_one(self, x: dict[base.typing.FeatureName, Any], **kwargs: Any) -> dict[base.typing.ClfTarget, float]:
31+
def predict_proba_one(
32+
self, x: dict[base.typing.FeatureName, Any], **kwargs: Any
33+
) -> dict[base.typing.ClfTarget, float]:
3234
"""Predict the probability of each label for a dictionary of features `x`.
3335
3436
Parameters
@@ -48,7 +50,9 @@ def predict_proba_one(self, x: dict[base.typing.FeatureName, Any], **kwargs: Any
4850
# that a classifier does not support predict_proba_one.
4951
raise NotImplementedError
5052

51-
def predict_one(self, x: dict[base.typing.FeatureName, Any], **kwargs: Any) -> base.typing.ClfTarget | None:
53+
def predict_one(
54+
self, x: dict[base.typing.FeatureName, Any], **kwargs: Any
55+
) -> base.typing.ClfTarget | None:
5256
"""Predict the label of a set of features `x`.
5357
5458
Parameters

river/base/estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import annotations
22

33
import abc
4-
from typing import Any, TYPE_CHECKING
54
from collections.abc import Iterator
6-
7-
from . import base
5+
from typing import TYPE_CHECKING, Any
86

97
if TYPE_CHECKING:
108
from river import compose
119

10+
from . import base
11+
1212

1313
class Estimator(base.Base, abc.ABC):
1414
"""An estimator."""

river/base/multi_output.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, bool]
2323
2424
"""
2525

26-
def predict_proba_one(self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any) -> dict[FeatureName, dict[bool, float]]:
26+
def predict_proba_one(
27+
self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any
28+
) -> dict[FeatureName, dict[bool, float]]:
2729
"""Predict the probability of each label appearing given dictionary of features `x`.
2830
2931
Parameters
@@ -40,7 +42,9 @@ def predict_proba_one(self, x: dict[FeatureName, typing.Any], **kwargs: typing.A
4042
# In case the multi-label classifier does not support probabilities
4143
raise NotImplementedError
4244

43-
def predict_one(self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any) -> dict[FeatureName, bool]:
45+
def predict_one(
46+
self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any
47+
) -> dict[FeatureName, bool]:
4448
"""Predict the labels of a set of features `x`.
4549
4650
Parameters
@@ -69,7 +73,12 @@ class MultiTargetRegressor(Estimator, abc.ABC):
6973
"""Multi-target regressor."""
7074

7175
@abc.abstractmethod
72-
def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, RegTarget], **kwargs: typing.Any) -> None:
76+
def learn_one(
77+
self,
78+
x: dict[FeatureName, typing.Any],
79+
y: dict[FeatureName, RegTarget],
80+
**kwargs: typing.Any,
81+
) -> None:
7382
"""Fits to a set of features `x` and a real-valued target `y`.
7483
7584
Parameters

river/base/transformer.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
if typing.TYPE_CHECKING:
1010
import pandas as pd
11+
1112
from river import compose
1213

1314

@@ -24,20 +25,34 @@ def __radd__(self, other: BaseTransformer) -> compose.TransformerUnion:
2425

2526
return compose.TransformerUnion(other, self)
2627

27-
def __mul__(self, other: BaseTransformer | compose.Pipeline | base.typing.FeatureName | list[base.typing.FeatureName]) -> compose.Grouper | compose.TransformerProduct:
28+
def __mul__(
29+
self,
30+
other: BaseTransformer
31+
| compose.Pipeline
32+
| base.typing.FeatureName
33+
| list[base.typing.FeatureName],
34+
) -> compose.Grouper | compose.TransformerProduct:
2835
from river import compose
2936

3037
if isinstance(other, BaseTransformer) or isinstance(other, compose.Pipeline):
3138
return compose.TransformerProduct(self, other)
3239

33-
return compose.Grouper(transformer=self, by=other) # type: ignore[arg-type]
40+
return compose.Grouper(transformer=self, by=other)
3441

35-
def __rmul__(self, other: BaseTransformer | compose.Pipeline | base.typing.FeatureName | list[base.typing.FeatureName]) -> compose.Grouper | compose.TransformerProduct:
42+
def __rmul__(
43+
self,
44+
other: BaseTransformer
45+
| compose.Pipeline
46+
| base.typing.FeatureName
47+
| list[base.typing.FeatureName],
48+
) -> compose.Grouper | compose.TransformerProduct:
3649
"""Creates a Grouper."""
3750
return self * other
3851

3952
@abc.abstractmethod
40-
def transform_one(self, x: dict[base.typing.FeatureName, Any]) -> dict[base.typing.FeatureName, Any]:
53+
def transform_one(
54+
self, x: dict[base.typing.FeatureName, Any]
55+
) -> dict[base.typing.FeatureName, Any]:
4156
"""Transform a set of features `x`.
4257
4358
Parameters

river/base/viz.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

3-
# This import is not cyclic because 'viz' is not exported by 'base'
4-
from river import base, compose
5-
63
import inspect
74
import textwrap
85
from xml.etree import ElementTree as ET
96

7+
# This import is not cyclic because 'viz' is not exported by 'base'
8+
from river import base, compose
9+
1010

1111
def to_html(obj: base.Estimator) -> ET.Element:
1212
if isinstance(obj, compose.Pipeline):

river/forest/adaptive_random_forest.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,9 @@ def _mutable_attributes(self):
663663
def _multiclass(self):
664664
return True
665665

666-
def predict_proba_one(self, x: dict, **kwargs: typing.Any) -> dict[base.typing.ClfTarget, float]:
666+
def predict_proba_one(
667+
self, x: dict, **kwargs: typing.Any
668+
) -> dict[base.typing.ClfTarget, float]:
667669
y_pred: typing.Counter = collections.Counter()
668670

669671
if len(self) == 0:

river/stream/iter_sql.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,4 @@ def iter_sql(
102102
for row in result_proxy:
103103
x = dict(row._mapping.items())
104104
y = x.pop(target_name)
105-
yield x, y # type: ignore[misc]
105+
yield x, y # type: ignore[misc]

river/time_series/snarimax.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import itertools
55
import math
66

7-
from river import base, linear_model, preprocessing, time_series, compose
7+
from river import base, compose, linear_model, preprocessing, time_series
88

99
__all__ = ["SNARIMAX"]
1010

0 commit comments

Comments
 (0)