1- """Test MolFilter, which invalidate molecules based on criteria defined in the respective filter."""
1+ """Unittest for MolFilter functionality.
2+
3+ MolFilter flags molecules as invalid based on the criteria defined in the filter.
4+
5+ """
26
37import json
48import tempfile
59import unittest
610from pathlib import Path
11+ from typing import TYPE_CHECKING
712
813from molpipeline import ErrorFilter , FilterReinserter , Pipeline
914from molpipeline .any2mol import SmilesToMol
1924)
2025from molpipeline .utils .comparison import compare_recursive
2126from molpipeline .utils .json_operations import recursive_from_json , recursive_to_json
22- from molpipeline .utils .molpipeline_types import FloatCountRange , IntOrIntCountRange
27+
28+ if TYPE_CHECKING :
29+ from molpipeline .utils .molpipeline_types import FloatCountRange , IntOrIntCountRange
2330
2431# pylint: disable=duplicate-code # test case molecules are allowed to be duplicated
2532SMILES_ANTIMONY = "[SbH6+3]"
@@ -80,15 +87,15 @@ def test_element_filter(self) -> None:
8087 6 : 6 ,
8188 1 : (5 , 6 ),
8289 17 : (0 , 1 ),
83- }
90+ },
8491 },
8592 "result" : [SMILES_BENZENE , SMILES_CHLOROBENZENE ],
8693 },
8794 {"params" : {"ElementFilter__add_hydrogens" : False }, "result" : []},
8895 ]
8996
9097 for test_params in test_params_list_with_results :
91- pipeline .set_params (** test_params ["params" ])
98+ pipeline .set_params (** test_params ["params" ]) # type: ignore
9299 filtered_smiles = pipeline .fit_transform (SMILES_LIST )
93100 self .assertEqual (filtered_smiles , test_params ["result" ])
94101
@@ -99,14 +106,15 @@ def test_json_roundtrip(self) -> None:
99106 -----
100107 It is important to save the ElementFilter as a JSON file and then load it back.
101108 This is because json.dumps() sets the keys of the dictionary to strings.
109+
102110 """
103111 element_filter = ElementFilter ()
104112 json_object = recursive_to_json (element_filter )
105113 with tempfile .TemporaryDirectory () as temp_folder :
106114 temp_file_path = Path (temp_folder ) / "test.json"
107- with open (temp_file_path , "w" , encoding = "UTF-8" ) as out_file :
115+ with temp_file_path . open ("w" , encoding = "UTF-8" ) as out_file :
108116 json .dump (json_object , out_file )
109- with open (temp_file_path , encoding = "UTF-8" ) as in_file :
117+ with temp_file_path . open (encoding = "UTF-8" ) as in_file :
110118 loaded_json_object = json .load (in_file )
111119 recreated_element_filter = recursive_from_json (loaded_json_object )
112120
@@ -117,7 +125,8 @@ def test_json_roundtrip(self) -> None:
117125 with self .subTest (param_name = param_name ):
118126 self .assertTrue (
119127 compare_recursive (original_value , recreated_params [param_name ]),
120- f"Original: { original_value } , Recreated: { recreated_params [param_name ]} " ,
128+ f"Original: { original_value } , "
129+ f"Recreated: { recreated_params [param_name ]} " ,
121130 )
122131
123132
@@ -132,6 +141,7 @@ def _create_pipeline() -> Pipeline:
132141 -------
133142 Pipeline
134143 Pipeline with a complex filter.
144+
135145 """
136146 element_filter_1 = ElementFilter ({6 : 6 , 1 : 6 })
137147 element_filter_2 = ElementFilter ({6 : 6 , 1 : 5 , 17 : 1 })
@@ -140,18 +150,17 @@ def _create_pipeline() -> Pipeline:
140150 (
141151 ("element_filter_1" , element_filter_1 ),
142152 ("element_filter_2" , element_filter_2 ),
143- )
153+ ),
144154 )
145155
146- pipeline = Pipeline (
156+ return Pipeline (
147157 [
148158 ("Smiles2Mol" , SmilesToMol ()),
149159 ("MultiElementFilter" , multi_element_filter ),
150160 ("Mol2Smiles" , MolToSmiles ()),
151161 ("ErrorFilter" , ErrorFilter ()),
152162 ],
153163 )
154- return pipeline
155164
156165 def test_complex_filter (self ) -> None :
157166 """Test if molecules are filtered correctly by allowed chemical elements."""
@@ -169,14 +178,14 @@ def test_complex_filter(self) -> None:
169178 {
170179 "params" : {
171180 "MultiElementFilter__mode" : "any" ,
172- "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens" : False ,
181+ "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens" : False , # noqa: E501
173182 },
174183 "result" : [SMILES_CHLOROBENZENE ],
175184 },
176185 ]
177186
178187 for test_params in test_params_list_with_results :
179- pipeline .set_params (** test_params ["params" ])
188+ pipeline .set_params (** test_params ["params" ]) # type: ignore
180189 filtered_smiles = pipeline .fit_transform (SMILES_LIST )
181190 self .assertEqual (filtered_smiles , test_params ["result" ])
182191
@@ -198,7 +207,7 @@ def test_complex_filter_non_unique_names(self) -> None:
198207
199208 with self .assertRaises (ValueError ):
200209 ComplexFilter (
201- (("filter_1" , element_filter_1 ), ("filter_1" , element_filter_2 ))
210+ (("filter_1" , element_filter_1 ), ("filter_1" , element_filter_2 )),
202211 )
203212
204213
@@ -264,7 +273,7 @@ def test_smarts_smiles_filter(self) -> None:
264273 ]
265274
266275 for test_params in test_params_list_with_results :
267- pipeline .set_params (** test_params ["params" ])
276+ pipeline .set_params (** test_params ["params" ]) # type: ignore
268277 filtered_smiles = pipeline .fit_transform (SMILES_LIST )
269278 self .assertEqual (filtered_smiles , test_params ["result" ])
270279
@@ -285,7 +294,11 @@ def test_smarts_smiles_filter_wrong_pattern(self) -> None:
285294 SmilesFilter (smiles_pats )
286295
287296 def test_smarts_filter_parallel (self ) -> None :
288- """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel."""
297+ """Test if molecules are filtered correctly.
298+
299+ This test runs the SmartsFilter in parallel.
300+
301+ """
289302 smarts_pats : dict [str , IntOrIntCountRange ] = {
290303 "c" : (4 , None ),
291304 "Cl" : 1 ,
@@ -352,38 +365,38 @@ def test_descriptor_filter(self) -> None:
352365 },
353366 {
354367 "params" : {
355- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1.99 , 4 )}
368+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1.99 , 4 )},
356369 },
357370 "result" : [SMILES_CL_BR ],
358371 },
359372 {
360373 "params" : {
361- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (2.01 , 4 )}
374+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (2.01 , 4 )},
362375 },
363376 "result" : [],
364377 },
365378 {
366379 "params" : {
367- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 2.00 )}
380+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 2.00 )},
368381 },
369382 "result" : [SMILES_CL_BR ],
370383 },
371384 {
372385 "params" : {
373- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 2.01 )}
386+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 2.01 )},
374387 },
375388 "result" : [SMILES_CL_BR ],
376389 },
377390 {
378391 "params" : {
379- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 1.99 )}
392+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 1.99 )},
380393 },
381394 "result" : [],
382395 },
383396 ]
384397
385398 for test_params in test_params_list_with_results :
386- pipeline .set_params (** test_params ["params" ])
399+ pipeline .set_params (** test_params ["params" ]) # type: ignore
387400 filtered_smiles = pipeline .fit_transform (SMILES_LIST )
388401 self .assertEqual (filtered_smiles , test_params ["result" ])
389402
@@ -409,7 +422,7 @@ def test_invalidate_mixtures(self) -> None:
409422 ("mol2smi" , mol2smi ),
410423 ("error_filter" , error_filter ),
411424 ("error_replacer" , error_replacer ),
412- ]
425+ ],
413426 )
414427 mols_processed = pipeline .fit_transform (mol_list )
415428 self .assertEqual (expected_invalidated_mol_list , mols_processed )
@@ -424,7 +437,7 @@ def test_inorganic_filter(self) -> None:
424437 inorganics_filter = InorganicsFilter ()
425438 mol2smiles = MolToSmiles ()
426439 error_filter = ErrorFilter .from_element_list (
427- [smiles2mol , inorganics_filter , mol2smiles ]
440+ [smiles2mol , inorganics_filter , mol2smiles ],
428441 )
429442 pipeline = Pipeline (
430443 [
0 commit comments