|
10 | 10 | import numpy as np |
11 | 11 | import numpy.typing as npt |
12 | 12 | from loguru import logger |
13 | | -from sklearn.base import _fit_context, clone |
| 13 | +from sklearn.base import _fit_context, clone # noqa: PLC2701 |
14 | 14 | from sklearn.pipeline import Pipeline as _Pipeline |
15 | | -from sklearn.pipeline import _final_estimator_has, _fit_transform_one |
| 15 | +from sklearn.pipeline import _final_estimator_has, _fit_transform_one # noqa: PLC2701 |
16 | 16 | from sklearn.utils import Bunch |
17 | | -from sklearn.utils._tags import Tags, get_tags |
| 17 | +from sklearn.utils._tags import Tags, get_tags # noqa: PLC2701 |
18 | 18 | from sklearn.utils.metadata_routing import ( |
19 | 19 | MetadataRouter, |
20 | 20 | MethodMapping, |
21 | | - _routing_enabled, |
| 21 | + _routing_enabled, # noqa: PLC2701 |
22 | 22 | process_routing, |
23 | 23 | ) |
24 | 24 | from sklearn.utils.metaestimators import available_if |
@@ -94,12 +94,16 @@ def _set_error_resinserter(self) -> None: |
94 | 94 | error_filter_list = [ |
95 | 95 | n_filter for _, n_filter in self.steps if isinstance(n_filter, ErrorFilter) |
96 | 96 | ] |
97 | | - for step in self.steps: |
98 | | - if isinstance(step[1], PostPredictionWrapper) and isinstance( |
99 | | - step[1].wrapped_estimator, |
100 | | - FilterReinserter, |
101 | | - ): |
102 | | - error_replacer_list.append(step[1].wrapped_estimator) |
| 97 | + error_replacer_list.extend( |
| 98 | + [ |
| 99 | + step[1].wrapped_estimator |
| 100 | + for step in self.steps |
| 101 | + if ( |
| 102 | + isinstance(step[1], PostPredictionWrapper) |
| 103 | + and isinstance(step[1].wrapped_estimator, FilterReinserter) |
| 104 | + ) |
| 105 | + ], |
| 106 | + ) |
103 | 107 | for error_replacer in error_replacer_list: |
104 | 108 | error_replacer.select_error_filter(error_filter_list) |
105 | 109 |
|
@@ -231,7 +235,7 @@ def _final_estimator( |
231 | 235 |
|
232 | 236 | # pylint: disable=too-many-locals,too-many-branches |
233 | 237 | @override |
234 | | - def _fit( |
| 238 | + def _fit( # noqa: PLR0912 |
235 | 239 | self, |
236 | 240 | X: Any, |
237 | 241 | y: Any = None, |
@@ -983,7 +987,7 @@ def classes_(self) -> list[Any] | npt.NDArray[Any]: |
983 | 987 | return last_step.classes_ |
984 | 988 | raise ValueError("Last step has no classes_ attribute.") |
985 | 989 |
|
986 | | - def __sklearn_tags__(self) -> Tags: |
| 990 | + def __sklearn_tags__(self) -> Tags: # noqa: PLW3201 |
987 | 991 | """Return the sklearn tags. |
988 | 992 |
|
989 | 993 | Returns |
@@ -1093,7 +1097,7 @@ def get_metadata_routing(self) -> MetadataRouter: |
1093 | 1097 | .add(caller="score", callee="transform") |
1094 | 1098 | ) |
1095 | 1099 |
|
1096 | | - router.add(method_mapping=method_mapping, **{name: trans}) |
| 1100 | + router.add(method_mapping=method_mapping, **{name: trans}) # type: ignore |
1097 | 1101 |
|
1098 | 1102 | # Only the _non_post_processing_steps is changed from the original |
1099 | 1103 | # implementation is changed in the following line |
|
0 commit comments