Skip to content

Commit 1d6bcda

Browse files
authored
Adapt Unittests to new RDKit SMILES + Linting (#175)
* update smiles + linting
1 parent d858edd commit 1d6bcda

1 file changed

Lines changed: 33 additions & 20 deletions

File tree

tests/test_elements/test_mol2mol/test_mol2mol_filter.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
"""Test MolFilter, which invalidate molecules based on criteria defined in the respective filter."""
1+
"""Unittest for MolFilters functionality.
2+
3+
MolFilters flag Molecules as invalid based on the criteria defined in the filter.
4+
5+
"""
26

37
import json
48
import tempfile
59
import unittest
610
from pathlib import Path
11+
from typing import TYPE_CHECKING
712

813
from molpipeline import ErrorFilter, FilterReinserter, Pipeline
914
from molpipeline.any2mol import SmilesToMol
@@ -19,14 +24,16 @@
1924
)
2025
from molpipeline.utils.comparison import compare_recursive
2126
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json
22-
from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange
27+
28+
if TYPE_CHECKING:
29+
from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange
2330

2431
# pylint: disable=duplicate-code # test case molecules are allowed to be duplicated
2532
SMILES_ANTIMONY = "[SbH6+3]"
2633
SMILES_BENZENE = "c1ccccc1"
2734
SMILES_CHLOROBENZENE = "Clc1ccccc1"
2835
SMILES_CL_BR = "NC(Cl)(Br)C(=O)O"
29-
SMILES_METAL_AU = "OC[C@H]1OC(S[Au])[C@H](O)[C@@H](O)[C@@H]1O"
36+
SMILES_METAL_AU = "OC[C@H]1OC([S][Au])[C@H](O)[C@@H](O)[C@@H]1O"
3037

3138
SMILES_LIST = [
3239
SMILES_ANTIMONY,
@@ -80,7 +87,7 @@ def test_element_filter(self) -> None:
8087
6: 6,
8188
1: (5, 6),
8289
17: (0, 1),
83-
}
90+
},
8491
},
8592
"result": [SMILES_BENZENE, SMILES_CHLOROBENZENE],
8693
},
@@ -99,14 +106,15 @@ def test_json_roundtrip(self) -> None:
99106
-----
100107
It is important to save the ElementFilter as a JSON file and then load it back.
101108
This is because json.dumps() sets the keys of the dictionary to strings.
109+
102110
"""
103111
element_filter = ElementFilter()
104112
json_object = recursive_to_json(element_filter)
105113
with tempfile.TemporaryDirectory() as temp_folder:
106114
temp_file_path = Path(temp_folder) / "test.json"
107-
with open(temp_file_path, "w", encoding="UTF-8") as out_file:
115+
with temp_file_path.open("w", encoding="UTF-8") as out_file:
108116
json.dump(json_object, out_file)
109-
with open(temp_file_path, encoding="UTF-8") as in_file:
117+
with temp_file_path.open(encoding="UTF-8") as in_file:
110118
loaded_json_object = json.load(in_file)
111119
recreated_element_filter = recursive_from_json(loaded_json_object)
112120

@@ -117,7 +125,8 @@ def test_json_roundtrip(self) -> None:
117125
with self.subTest(param_name=param_name):
118126
self.assertTrue(
119127
compare_recursive(original_value, recreated_params[param_name]),
120-
f"Original: {original_value}, Recreated: {recreated_params[param_name]}",
128+
f"Original: {original_value}, "
129+
f"Recreated: {recreated_params[param_name]}",
121130
)
122131

123132

@@ -132,6 +141,7 @@ def _create_pipeline() -> Pipeline:
132141
-------
133142
Pipeline
134143
Pipeline with a complex filter.
144+
135145
"""
136146
element_filter_1 = ElementFilter({6: 6, 1: 6})
137147
element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1})
@@ -140,18 +150,17 @@ def _create_pipeline() -> Pipeline:
140150
(
141151
("element_filter_1", element_filter_1),
142152
("element_filter_2", element_filter_2),
143-
)
153+
),
144154
)
145155

146-
pipeline = Pipeline(
156+
return Pipeline(
147157
[
148158
("Smiles2Mol", SmilesToMol()),
149159
("MultiElementFilter", multi_element_filter),
150160
("Mol2Smiles", MolToSmiles()),
151161
("ErrorFilter", ErrorFilter()),
152162
],
153163
)
154-
return pipeline
155164

156165
def test_complex_filter(self) -> None:
157166
"""Test if molecules are filtered correctly by allowed chemical elements."""
@@ -169,7 +178,7 @@ def test_complex_filter(self) -> None:
169178
{
170179
"params": {
171180
"MultiElementFilter__mode": "any",
172-
"MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False,
181+
"MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False, # noqa: E501
173182
},
174183
"result": [SMILES_CHLOROBENZENE],
175184
},
@@ -198,7 +207,7 @@ def test_complex_filter_non_unique_names(self) -> None:
198207

199208
with self.assertRaises(ValueError):
200209
ComplexFilter(
201-
(("filter_1", element_filter_1), ("filter_1", element_filter_2))
210+
(("filter_1", element_filter_1), ("filter_1", element_filter_2)),
202211
)
203212

204213

@@ -285,7 +294,11 @@ def test_smarts_smiles_filter_wrong_pattern(self) -> None:
285294
SmilesFilter(smiles_pats)
286295

287296
def test_smarts_filter_parallel(self) -> None:
288-
"""Test if molecules are filtered correctly by allowed SMARTS patterns in parallel."""
297+
"""Test if molecules are filtered correctly.
298+
299+
This test runs the SmartsFilter in parallel.
300+
301+
"""
289302
smarts_pats: dict[str, IntOrIntCountRange] = {
290303
"c": (4, None),
291304
"Cl": 1,
@@ -352,31 +365,31 @@ def test_descriptor_filter(self) -> None:
352365
},
353366
{
354367
"params": {
355-
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)}
368+
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)},
356369
},
357370
"result": [SMILES_CL_BR],
358371
},
359372
{
360373
"params": {
361-
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)}
374+
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)},
362375
},
363376
"result": [],
364377
},
365378
{
366379
"params": {
367-
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)}
380+
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)},
368381
},
369382
"result": [SMILES_CL_BR],
370383
},
371384
{
372385
"params": {
373-
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)}
386+
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)},
374387
},
375388
"result": [SMILES_CL_BR],
376389
},
377390
{
378391
"params": {
379-
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)}
392+
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)},
380393
},
381394
"result": [],
382395
},
@@ -409,7 +422,7 @@ def test_invalidate_mixtures(self) -> None:
409422
("mol2smi", mol2smi),
410423
("error_filter", error_filter),
411424
("error_replacer", error_replacer),
412-
]
425+
],
413426
)
414427
mols_processed = pipeline.fit_transform(mol_list)
415428
self.assertEqual(expected_invalidated_mol_list, mols_processed)
@@ -424,7 +437,7 @@ def test_inorganic_filter(self) -> None:
424437
inorganics_filter = InorganicsFilter()
425438
mol2smiles = MolToSmiles()
426439
error_filter = ErrorFilter.from_element_list(
427-
[smiles2mol, inorganics_filter, mol2smiles]
440+
[smiles2mol, inorganics_filter, mol2smiles],
428441
)
429442
pipeline = Pipeline(
430443
[

0 commit comments

Comments
 (0)