Skip to content

Commit 4371856

Browse files
committed
mypy ignore and linting
1 parent c19431f commit 4371856

2 files changed

Lines changed: 21 additions & 17 deletions

File tree

molpipeline/pipeline/_skl_pipeline.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
import numpy as np
1111
import numpy.typing as npt
1212
from loguru import logger
13-
from sklearn.base import _fit_context, clone
13+
from sklearn.base import _fit_context, clone # noqa: PLC2701
1414
from sklearn.pipeline import Pipeline as _Pipeline
15-
from sklearn.pipeline import _final_estimator_has, _fit_transform_one
15+
from sklearn.pipeline import _final_estimator_has, _fit_transform_one # noqa: PLC2701
1616
from sklearn.utils import Bunch
17-
from sklearn.utils._tags import Tags, get_tags
17+
from sklearn.utils._tags import Tags, get_tags # noqa: PLC2701
1818
from sklearn.utils.metadata_routing import (
1919
MetadataRouter,
2020
MethodMapping,
21-
_routing_enabled,
21+
_routing_enabled, # noqa: PLC2701
2222
process_routing,
2323
)
2424
from sklearn.utils.metaestimators import available_if
@@ -94,12 +94,16 @@ def _set_error_resinserter(self) -> None:
9494
error_filter_list = [
9595
n_filter for _, n_filter in self.steps if isinstance(n_filter, ErrorFilter)
9696
]
97-
for step in self.steps:
98-
if isinstance(step[1], PostPredictionWrapper) and isinstance(
99-
step[1].wrapped_estimator,
100-
FilterReinserter,
101-
):
102-
error_replacer_list.append(step[1].wrapped_estimator)
97+
error_replacer_list.extend(
98+
[
99+
step[1].wrapped_estimator
100+
for step in self.steps
101+
if (
102+
isinstance(step[1], PostPredictionWrapper)
103+
and isinstance(step[1].wrapped_estimator, FilterReinserter)
104+
)
105+
],
106+
)
103107
for error_replacer in error_replacer_list:
104108
error_replacer.select_error_filter(error_filter_list)
105109

@@ -231,7 +235,7 @@ def _final_estimator(
231235

232236
# pylint: disable=too-many-locals,too-many-branches
233237
@override
234-
def _fit(
238+
def _fit( # noqa: PLR0912
235239
self,
236240
X: Any,
237241
y: Any = None,
@@ -983,7 +987,7 @@ def classes_(self) -> list[Any] | npt.NDArray[Any]:
983987
return last_step.classes_
984988
raise ValueError("Last step has no classes_ attribute.")
985989

986-
def __sklearn_tags__(self) -> Tags:
990+
def __sklearn_tags__(self) -> Tags: # noqa: PLW3201
987991
"""Return the sklearn tags.
988992
989993
Returns
@@ -1093,7 +1097,7 @@ def get_metadata_routing(self) -> MetadataRouter:
10931097
.add(caller="score", callee="transform")
10941098
)
10951099

1096-
router.add(method_mapping=method_mapping, **{name: trans})
1100+
router.add(method_mapping=method_mapping, **{name: trans}) # type: ignore
10971101

10981102
# Only the _non_post_processing_steps is changed from the original
10991103
# implementation is changed in the following line

tests/test_elements/test_mol2mol/test_mol2mol_filter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_element_filter(self) -> None:
9898
]
9999

100100
for test_params in test_params_list_with_results:
101-
pipeline.set_params(**test_params["params"])
101+
pipeline.set_params(**test_params["params"]) # type: ignore
102102
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
103103
self.assertEqual(filtered_smiles, test_params["result"])
104104

@@ -188,7 +188,7 @@ def test_complex_filter(self) -> None:
188188
]
189189

190190
for test_params in test_params_list_with_results:
191-
pipeline.set_params(**test_params["params"])
191+
pipeline.set_params(**test_params["params"]) # type: ignore
192192
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
193193
self.assertEqual(filtered_smiles, test_params["result"])
194194

@@ -276,7 +276,7 @@ def test_smarts_smiles_filter(self) -> None:
276276
]
277277

278278
for test_params in test_params_list_with_results:
279-
pipeline.set_params(**test_params["params"])
279+
pipeline.set_params(**test_params["params"]) # type: ignore
280280
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
281281
self.assertEqual(filtered_smiles, test_params["result"])
282282

@@ -433,7 +433,7 @@ def test_descriptor_filter(self) -> None:
433433
]
434434

435435
for test_params in test_params_list_with_results:
436-
pipeline.set_params(**test_params["params"])
436+
pipeline.set_params(**test_params["params"]) # type: ignore
437437
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
438438
self.assertEqual(filtered_smiles, test_params["result"])
439439

0 commit comments

Comments
 (0)