1- """Test MolFilter, which invalidate molecules based on criteria defined in the respective filter."""
1+ """Unittest for MolFilters functionality.
2+
3+ MolFilters flag Molecules as invalid based on the criteria defined in the filter.
4+
5+ """
26
37import json
48import tempfile
59import unittest
610from pathlib import Path
11+ from typing import TYPE_CHECKING
712
813from molpipeline import ErrorFilter , FilterReinserter , Pipeline
914from molpipeline .any2mol import SmilesToMol
1924)
2025from molpipeline .utils .comparison import compare_recursive
2126from molpipeline .utils .json_operations import recursive_from_json , recursive_to_json
22- from molpipeline .utils .molpipeline_types import FloatCountRange , IntOrIntCountRange
27+
28+ if TYPE_CHECKING :
29+ from molpipeline .utils .molpipeline_types import FloatCountRange , IntOrIntCountRange
2330
2431# pylint: disable=duplicate-code # test case molecules are allowed to be duplicated
2532SMILES_ANTIMONY = "[SbH6+3]"
2633SMILES_BENZENE = "c1ccccc1"
2734SMILES_CHLOROBENZENE = "Clc1ccccc1"
2835SMILES_CL_BR = "NC(Cl)(Br)C(=O)O"
29- SMILES_METAL_AU = "OC[C@H]1OC(S [Au])[C@H](O)[C@@H](O)[C@@H]1O"
36+ SMILES_METAL_AU = "OC[C@H]1OC([S] [Au])[C@H](O)[C@@H](O)[C@@H]1O"
3037
3138SMILES_LIST = [
3239 SMILES_ANTIMONY ,
@@ -80,7 +87,7 @@ def test_element_filter(self) -> None:
8087 6 : 6 ,
8188 1 : (5 , 6 ),
8289 17 : (0 , 1 ),
83- }
90+ },
8491 },
8592 "result" : [SMILES_BENZENE , SMILES_CHLOROBENZENE ],
8693 },
@@ -99,14 +106,15 @@ def test_json_roundtrip(self) -> None:
99106 -----
100107 It is important to save the ElementFilter as a JSON file and then load it back.
101108 This is because json.dumps() sets the keys of the dictionary to strings.
109+
102110 """
103111 element_filter = ElementFilter ()
104112 json_object = recursive_to_json (element_filter )
105113 with tempfile .TemporaryDirectory () as temp_folder :
106114 temp_file_path = Path (temp_folder ) / "test.json"
107- with open (temp_file_path , "w" , encoding = "UTF-8" ) as out_file :
115+ with temp_file_path . open ("w" , encoding = "UTF-8" ) as out_file :
108116 json .dump (json_object , out_file )
109- with open (temp_file_path , encoding = "UTF-8" ) as in_file :
117+ with temp_file_path . open (encoding = "UTF-8" ) as in_file :
110118 loaded_json_object = json .load (in_file )
111119 recreated_element_filter = recursive_from_json (loaded_json_object )
112120
@@ -117,7 +125,8 @@ def test_json_roundtrip(self) -> None:
117125 with self .subTest (param_name = param_name ):
118126 self .assertTrue (
119127 compare_recursive (original_value , recreated_params [param_name ]),
120- f"Original: { original_value } , Recreated: { recreated_params [param_name ]} " ,
128+ f"Original: { original_value } , "
129+ f"Recreated: { recreated_params [param_name ]} " ,
121130 )
122131
123132
@@ -132,6 +141,7 @@ def _create_pipeline() -> Pipeline:
132141 -------
133142 Pipeline
134143 Pipeline with a complex filter.
144+
135145 """
136146 element_filter_1 = ElementFilter ({6 : 6 , 1 : 6 })
137147 element_filter_2 = ElementFilter ({6 : 6 , 1 : 5 , 17 : 1 })
@@ -140,18 +150,17 @@ def _create_pipeline() -> Pipeline:
140150 (
141151 ("element_filter_1" , element_filter_1 ),
142152 ("element_filter_2" , element_filter_2 ),
143- )
153+ ),
144154 )
145155
146- pipeline = Pipeline (
156+ return Pipeline (
147157 [
148158 ("Smiles2Mol" , SmilesToMol ()),
149159 ("MultiElementFilter" , multi_element_filter ),
150160 ("Mol2Smiles" , MolToSmiles ()),
151161 ("ErrorFilter" , ErrorFilter ()),
152162 ],
153163 )
154- return pipeline
155164
156165 def test_complex_filter (self ) -> None :
157166 """Test if molecules are filtered correctly by allowed chemical elements."""
@@ -169,7 +178,7 @@ def test_complex_filter(self) -> None:
169178 {
170179 "params" : {
171180 "MultiElementFilter__mode" : "any" ,
172- "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens" : False ,
181+ "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens" : False , # noqa: E501
173182 },
174183 "result" : [SMILES_CHLOROBENZENE ],
175184 },
@@ -198,7 +207,7 @@ def test_complex_filter_non_unique_names(self) -> None:
198207
199208 with self .assertRaises (ValueError ):
200209 ComplexFilter (
201- (("filter_1" , element_filter_1 ), ("filter_1" , element_filter_2 ))
210+ (("filter_1" , element_filter_1 ), ("filter_1" , element_filter_2 )),
202211 )
203212
204213
@@ -285,7 +294,11 @@ def test_smarts_smiles_filter_wrong_pattern(self) -> None:
285294 SmilesFilter (smiles_pats )
286295
287296 def test_smarts_filter_parallel (self ) -> None :
288- """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel."""
297+ """Test if molecules are filtered correctly.
298+
299+ This test runs the SmartsFilter in parallel.
300+
301+ """
289302 smarts_pats : dict [str , IntOrIntCountRange ] = {
290303 "c" : (4 , None ),
291304 "Cl" : 1 ,
@@ -352,31 +365,31 @@ def test_descriptor_filter(self) -> None:
352365 },
353366 {
354367 "params" : {
355- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1.99 , 4 )}
368+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1.99 , 4 )},
356369 },
357370 "result" : [SMILES_CL_BR ],
358371 },
359372 {
360373 "params" : {
361- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (2.01 , 4 )}
374+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (2.01 , 4 )},
362375 },
363376 "result" : [],
364377 },
365378 {
366379 "params" : {
367- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 2.00 )}
380+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 2.00 )},
368381 },
369382 "result" : [SMILES_CL_BR ],
370383 },
371384 {
372385 "params" : {
373- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 2.01 )}
386+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 2.01 )},
374387 },
375388 "result" : [SMILES_CL_BR ],
376389 },
377390 {
378391 "params" : {
379- "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 1.99 )}
392+ "DescriptorsFilter__filter_elements" : {"NumHAcceptors" : (1 , 1.99 )},
380393 },
381394 "result" : [],
382395 },
@@ -409,7 +422,7 @@ def test_invalidate_mixtures(self) -> None:
409422 ("mol2smi" , mol2smi ),
410423 ("error_filter" , error_filter ),
411424 ("error_replacer" , error_replacer ),
412- ]
425+ ],
413426 )
414427 mols_processed = pipeline .fit_transform (mol_list )
415428 self .assertEqual (expected_invalidated_mol_list , mols_processed )
@@ -424,7 +437,7 @@ def test_inorganic_filter(self) -> None:
424437 inorganics_filter = InorganicsFilter ()
425438 mol2smiles = MolToSmiles ()
426439 error_filter = ErrorFilter .from_element_list (
427- [smiles2mol , inorganics_filter , mol2smiles ]
440+ [smiles2mol , inorganics_filter , mol2smiles ],
428441 )
429442 pipeline = Pipeline (
430443 [
0 commit comments