11from dataclasses import replace
2- from typing import Any , Callable , Dict , List , Optional , Tuple
2+ from typing import Any , Dict , List , Optional , Tuple
33import numpy as np
44import pytest
55import scipy .sparse as sp
2828# ----------------------------
2929
3030def _rdkit_generators () -> List [Tuple [str , Any ]]:
31- # rdFingerprintGenerator-style objects
3231 return [
3332 ("rdkit_morgan_2048_r2" , rdFingerprintGenerator .GetMorganGenerator (radius = 2 , fpSize = 2048 )),
3433 ("rdkit_rdkitfp_2048" , rdFingerprintGenerator .GetRDKitFPGenerator (fpSize = 2048 )),
@@ -51,43 +50,29 @@ def _skfp_generators() -> Dict[str, Any]:
5150
5251
def _supports_count_param(params: Dict[str, Any]) -> Optional[str]:
    """Return the name of the count-toggle parameter found in ``params``.

    The known spellings are checked in a fixed priority order and the first
    one that is a key of ``params`` is returned.

    Args:
        params: Parameter mapping of a transformer (as produced by
            ``get_params``).

    Returns:
        The matching key, or ``None`` when no known spelling is present.
    """
    known_spellings = ("count", "counts", "use_counts", "useCounts", "use_count")
    return next((key for key in known_spellings if key in params), None)
6257
6358
def _build_skfp_transformers() -> List[Tuple[str, Any, bool]]:
    """Instantiate every class offered by ``_skfp_generators()`` with defaults.

    Returns:
        A list of ``(name, instance, supports_count)`` tuples where ``name``
        is ``"skfp_<class name>"`` and ``supports_count`` tells whether the
        instance's parameters contain one of the known count-toggle keys.
        Classes that fail to construct or to report their parameters are
        left out.
    """
    out: List[Tuple[str, Any, bool]] = []

    for cls_name, cls in _skfp_generators().items():
        try:
            base = cls()  # type: ignore[call-arg]
            params = base.get_params(deep=False)
        except Exception:
            # Deliberate best-effort: a class that cannot be built with its
            # defaults is skipped instead of aborting test collection.
            continue

        supports_count = _supports_count_param(params) is not None
        out.append((f"skfp_{cls_name}", base, supports_count))

    return out
8573
8674
8775def _skfp_make_variant (fp : Any , * , count : bool ) -> Any :
88- """
89- Clone transformer with `count` enabled if it supports it; otherwise return original.
90- """
9176 params = fp .get_params (deep = False )
9277 count_key = _supports_count_param (params )
9378 if count_key is None :
@@ -97,6 +82,14 @@ def _skfp_make_variant(fp: Any, *, count: bool) -> Any:
9782 return fp .__class__ (** params )
9883
9984
def _supports_unfolded_variant(fp: Any) -> bool:
    """Tell whether ``fp`` exposes a ``variant`` parameter.

    Args:
        fp: Object expected to offer ``get_params(deep=False)``.

    Returns:
        ``True`` when ``"variant"`` is among the reported parameters;
        ``False`` when it is not, or when querying the parameters raises.
    """
    try:
        params = fp.get_params(deep=False)
    except Exception:
        # Objects without a usable get_params() simply do not qualify.
        return False
    return "variant" in params
91+
92+
10093# ----------------------------
10194# Assertions / shape utilities
10295# ----------------------------
@@ -110,15 +103,12 @@ def _assert_dense_matrix(X: np.ndarray, n_rows: int) -> None:
110103
111104
def _assert_binary_dense(X: np.ndarray) -> None:
    """Assert that every entry of ``X`` is 0 or 1.

    An all-zero (or empty) array passes; any other value fails with a message
    listing the distinct values that were found.
    """
    distinct = set(np.unique(X).tolist())
    assert distinct.issubset({0.0, 1.0}), (
        f"expected only 0/1 entries, found values {sorted(distinct)}"
    )
116108
117109
def _assert_count_dense(X: np.ndarray) -> None:
    """Assert that ``X`` holds no negative entries."""
    assert (X >= 0).all(), "count fingerprint matrix contains negative entries"
122112
123113
124114def _assert_csr_matrix (X : sp .csr_matrix , n_rows : int ) -> None :
@@ -138,157 +128,184 @@ def _assert_count_csr(X: sp.csr_matrix) -> None:
138128 assert (X .data >= 0 ).all ()
139129
140130
141- def _supports_unfolded_variant (fp : Any ) -> bool :
142- # Your code requires a `variant` param to do folded=False for sklearn/scikit-fingerprints
143- try :
144- params = fp .get_params (deep = False )
145- except Exception :
146- return False
147- return "variant" in params
148-
149131# ============================================================
150- # CASE 1
151- # - dense fingerprint (binary and where feasible count)
132+ # Parametrized case builders
152133# ============================================================
153134
154- def test_dense_fingerprints_binary_and_count_many_generators () -> None :
155- gens_rdkit = _rdkit_generators ()
156- gens_skfp = _build_skfp_transformers ()
157-
158- if not gens_rdkit and not gens_skfp :
159- pytest .skip ("No fingerprint generators available (RDKit/scikit-fingerprints not importable)." )
def _case1_dense_cases() -> List[Any]:
    """Build the parametrized cases for CASE 1 (dense output).

    Every returned entry is a ``pytest.param(name, fpgen, cfg, mode)`` with
    ``return_csr=False`` and ``folded=True``:

    - RDKit generators: one binary and one count case each.
    - skfp transformers: always a binary case; a count case only when the
      transformer exposes a count parameter.

    Each entry becomes its own pytest run, identified as
    ``<name>__dense__binary`` or ``<name>__dense__count``.
    """
    # pytest.param() objects have no public type name, hence List[Any].
    cases: List[Any] = []

    # RDKit: binary + count
    for name, gen in _rdkit_generators():
        cfg_bin = FingerprintConfig(
            count=False, folded=True, return_csr=False, invalid_policy="keep"
        )
        cases.append(
            pytest.param(name, gen, cfg_bin, "dense-binary", id=f"{name}__dense__binary")
        )

        cfg_cnt = replace(cfg_bin, count=True)
        cases.append(
            pytest.param(name, gen, cfg_cnt, "dense-count", id=f"{name}__dense__count")
        )

    # skfp: binary always; count only if supported
    for name, fp, supports_count in _build_skfp_transformers():
        cfg_bin = FingerprintConfig(
            count=False, folded=True, return_csr=False, invalid_policy="keep"
        )
        cases.append(
            pytest.param(name, fp, cfg_bin, "dense-binary", id=f"{name}__dense__binary")
        )

        if supports_count:
            fp_count = _skfp_make_variant(fp, count=True)
            cfg_cnt = replace(cfg_bin, count=True)
            cases.append(
                pytest.param(name, fp_count, cfg_cnt, "dense-count", id=f"{name}__dense__count")
            )

    return cases
191162
192- # ============================================================
193- # CASE 2
194- # - if fitting the fingerprint: return_csr=True (binary and where feasible count)
195- # (We still include RDKit here because return_csr=True is supported as well.)
196- # ============================================================
197163
198- def test_return_csr_folded_binary_and_count_many_generators () -> None :
199- gens_rdkit = _rdkit_generators ()
200- gens_skfp = _build_skfp_transformers ()
def _case2_csr_cases() -> List[Any]:
    """Build the parametrized cases for CASE 2 (CSR output).

    Every returned entry is a ``pytest.param(name, fpgen, cfg, mode)`` with
    ``return_csr=True`` and ``folded=True``:

    - RDKit generators: one binary and one count case each.
    - skfp transformers: always a binary case; a count case only when the
      transformer exposes a count parameter.

    Each entry becomes its own pytest run, identified as
    ``<name>__csr__binary`` or ``<name>__csr__count``.
    """
    # pytest.param() objects have no public type name, hence List[Any].
    cases: List[Any] = []

    # RDKit: binary + count
    for name, gen in _rdkit_generators():
        cfg_bin = FingerprintConfig(
            count=False, folded=True, return_csr=True, invalid_policy="keep"
        )
        cases.append(
            pytest.param(name, gen, cfg_bin, "csr-binary", id=f"{name}__csr__binary")
        )

        cfg_cnt = replace(cfg_bin, count=True)
        cases.append(
            pytest.param(name, gen, cfg_cnt, "csr-count", id=f"{name}__csr__count")
        )

    # skfp: binary always; count only if supported
    for name, fp, supports_count in _build_skfp_transformers():
        cfg_bin = FingerprintConfig(
            count=False, folded=True, return_csr=True, invalid_policy="keep"
        )
        cases.append(
            pytest.param(name, fp, cfg_bin, "csr-binary", id=f"{name}__csr__binary")
        )

        if supports_count:
            fp_count = _skfp_make_variant(fp, count=True)
            cfg_cnt = replace(cfg_bin, count=True)
            cases.append(
                pytest.param(name, fp_count, cfg_cnt, "csr-count", id=f"{name}__csr__count")
            )

    return cases
235191
236- # ============================================================
237- # CASE 3
238- # - if fitting the fingerprint:
239- # return_csr=True with folded=True and folded=False
240- #
241- # (folded=False, return_csr=True) is INVALID by design
242- # and must raise (validated in _validate_config). We test both behaviors.
243- # We also compute unfolded output (folded=False) with return_csr=False to
244- # still "run them all on a few simple smiles".
245- # ============================================================
246192
247- def test_fit_backend_return_csr_true_folded_true_and_false () -> None :
248- gens_skfp = _build_skfp_transformers ()
249- if not gens_skfp :
250- pytest .skip ("scikit-fingerprints/sklearn-style transformers not importable in this env." )
def _case3_fit_backend_cases() -> List[Any]:
    """Build the parametrized cases for CASE 3 (skfp transformers only).

    Every returned entry is a
    ``pytest.param(name, fpgen, cfg, behavior, supports_count)``.
    For each transformer the following runs are produced:

    A) ``folded=True, return_csr=True``   -> behavior ``folded_true_csr_ok``
    B) ``folded=False, return_csr=True``  -> behavior
       ``folded_false_csr_raises_valueerror``
    C) ``folded=False, return_csr=False`` -> ``unfolded_ok`` when the
       transformer has a ``variant`` parameter, otherwise
       ``unfolded_raises_notimplemented``
    D) unfolded + count (``unfolded_count_ok``), only when the transformer
       supports both a count parameter and ``variant``
    """
    # pytest.param() objects have no public type name, hence List[Any].
    cases: List[Any] = []

    for name, fp, supports_count in _build_skfp_transformers():
        # Queried once; decides both the expectation of C and whether D exists.
        has_variant = _supports_unfolded_variant(fp)

        # A)
        cfg_ok = FingerprintConfig(
            count=False, folded=True, return_csr=True, invalid_policy="keep"
        )
        cases.append(
            pytest.param(
                name, fp, cfg_ok, "folded_true_csr_ok", supports_count,
                id=f"{name}__case3__folded_true__csr_ok",
            )
        )

        # B)
        cfg_bad = replace(cfg_ok, folded=False, return_csr=True)
        cases.append(
            pytest.param(
                name, fp, cfg_bad, "folded_false_csr_raises_valueerror", supports_count,
                id=f"{name}__case3__folded_false__csr_raises_valueerror",
            )
        )

        # C)
        cfg_unfolded = FingerprintConfig(
            count=False, folded=False, return_csr=False, invalid_policy="keep"
        )
        behavior = "unfolded_ok" if has_variant else "unfolded_raises_notimplemented"
        cases.append(
            pytest.param(
                name, fp, cfg_unfolded, behavior, supports_count,
                # behavior already starts with "unfolded_", so it is not
                # prefixed again (avoids ids like "..._unfolded_unfolded_ok").
                id=f"{name}__case3__folded_false__{behavior}",
            )
        )

        # D) unfolded-count, where count is feasible and variant exists
        if supports_count and has_variant:
            fp_count = _skfp_make_variant(fp, count=True)
            cfg_unfolded_count = replace(cfg_unfolded, count=True)
            cases.append(
                pytest.param(
                    name, fp_count, cfg_unfolded_count, "unfolded_count_ok", True,
                    id=f"{name}__case3__folded_false__unfolded_count_ok",
                )
            )

    return cases
231+
232+
233+ # ============================================================
234+ # CASE 1: dense (parametrized per run)
235+ # ============================================================
236+
@pytest.mark.parametrize("name, fpgen, cfg, mode", _case1_dense_cases())
def test_case1_dense_per_run(name: str, fpgen: Any, cfg: FingerprintConfig, mode: str) -> None:
    """CASE 1: one run per (generator, binary/count) with dense output.

    ``mode`` is ``"dense-binary"`` or ``"dense-count"`` and selects which
    value check is applied to the returned matrix.
    """
    X = compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)

    _assert_dense_matrix(X, n_rows=len(SMILES))
    if mode == "dense-binary":
        _assert_binary_dense(X)
    elif mode == "dense-count":
        _assert_count_dense(X)
    else:
        # Fail loudly instead of silently applying the count check to a
        # mistyped or newly added mode.
        raise AssertionError(f"Unknown mode: {mode!r}")
246+
267247
268- if _supports_unfolded_variant (fp ):
269- out = compute_fingerprints (SMILES , fp , cfg_unfolded , show_progress = False , n_jobs = 1 )
270- assert isinstance (out , list )
271- assert len (out ) == len (SMILES )
248+ # ============================================================
249+ # CASE 2: csr (parametrized per run)
250+ # ============================================================
251+
@pytest.mark.parametrize("name, fpgen, cfg, mode", _case2_csr_cases())
def test_case2_csr_per_run(name: str, fpgen: Any, cfg: FingerprintConfig, mode: str) -> None:
    """CASE 2: one run per (generator, binary/count) with CSR output.

    ``mode`` is ``"csr-binary"`` or ``"csr-count"`` and selects which value
    check is applied to the returned matrix.
    """
    X = compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)

    _assert_csr_matrix(X, n_rows=len(SMILES))
    if mode == "csr-binary":
        _assert_binary_csr(X)
    elif mode == "csr-count":
        _assert_count_csr(X)
    else:
        # Fail loudly instead of silently applying the count check to a
        # mistyped or newly added mode.
        raise AssertionError(f"Unknown mode: {mode!r}")
261+
262+
263+ # ============================================================
264+ # CASE 3: fitting backend + folded True/False (parametrized)
265+ # ============================================================
266+
def _assert_unfolded_keys(keys: Any) -> None:
    """Check one unfolded key array: int64 ndarray, non-negative, non-decreasing."""
    assert isinstance(keys, np.ndarray) and keys.dtype == np.int64
    assert (keys >= 0).all()
    if keys.size > 1:
        assert np.all(keys[:-1] <= keys[1:])


@pytest.mark.parametrize("name, fpgen, cfg, behavior, supports_count", _case3_fit_backend_cases())
def test_case3_fit_backend_per_run(
    name: str,
    fpgen: Any,
    cfg: FingerprintConfig,
    behavior: str,
    supports_count: bool,
) -> None:
    """CASE 3: folded/unfolded behavior of the fit-style backend, one run each.

    ``behavior`` names the expectation for this run:

    - ``folded_true_csr_ok``: a CSR matrix with one row per SMILES.
    - ``folded_false_csr_raises_valueerror``: the call raises ``ValueError``.
    - ``unfolded_ok`` / ``unfolded_count_ok``: a list with one entry per
      SMILES; key arrays when ``cfg.count`` is false, ``(keys, vals)`` pairs
      when it is true.
    - ``unfolded_raises_notimplemented``: the call raises
      ``NotImplementedError``.

    ``supports_count`` is part of the parameter set but is not consulted here.
    """
    if behavior == "folded_true_csr_ok":
        X = compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)
        _assert_csr_matrix(X, n_rows=len(SMILES))

    elif behavior == "folded_false_csr_raises_valueerror":
        with pytest.raises(ValueError, match="return_csr is only valid when folded=True"):
            compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)

    elif behavior in ("unfolded_ok", "unfolded_count_ok"):
        out = compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)
        assert isinstance(out, list)
        assert len(out) == len(SMILES)

        if cfg.count:
            for keys, vals in out:
                _assert_unfolded_keys(keys)
                assert isinstance(vals, np.ndarray) and vals.dtype == np.float32
                assert keys.shape == vals.shape
                assert (vals >= 0).all()
        else:
            for keys in out:
                _assert_unfolded_keys(keys)

    elif behavior == "unfolded_raises_notimplemented":
        with pytest.raises(NotImplementedError, match="Requested folded=False"):
            compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)

    else:
        raise AssertionError(f"Unknown behavior: {behavior!r}")
0 commit comments