Skip to content

Commit fbcd403

Browse files
committed
refactor to parametrized pytest tests
1 parent f3e3d50 commit fbcd403

File tree

1 file changed

+159
-142
lines changed

1 file changed

+159
-142
lines changed
Lines changed: 159 additions & 142 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from dataclasses import replace
2-
from typing import Any, Callable, Dict, List, Optional, Tuple
2+
from typing import Any, Dict, List, Optional, Tuple
33
import numpy as np
44
import pytest
55
import scipy.sparse as sp
@@ -28,7 +28,6 @@
2828
# ----------------------------
2929

3030
def _rdkit_generators() -> List[Tuple[str, Any]]:
31-
# rdFingerprintGenerator-style objects
3231
return [
3332
("rdkit_morgan_2048_r2", rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)),
3433
("rdkit_rdkitfp_2048", rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=2048)),
@@ -51,43 +50,29 @@ def _skfp_generators() -> Dict[str, Any]:
5150

5251

5352
def _supports_count_param(params: Dict[str, Any]) -> Optional[str]:
54-
"""
55-
Heuristic: discover a "count" flag parameter if present.
56-
(scikit-fingerprints uses different spellings across versions / classes)
57-
"""
58-
for key in ("count", "counts", "use_counts", "useCounts", "use_count", "useCountsSimulation"):
53+
for key in ("count", "counts", "use_counts", "useCounts", "use_count"):
5954
if key in params:
6055
return key
6156
return None
6257

6358

6459
def _build_skfp_transformers() -> List[Tuple[str, Any, bool]]:
65-
"""
66-
Returns [(name, transformer_instance_binary, supports_count_variant)].
67-
"""
6860
mod = _skfp_generators()
69-
if mod is None:
70-
return []
71-
7261
out: List[Tuple[str, Any, bool]] = []
62+
7363
for cls_name, cls in mod.items():
7464
try:
7565
base = cls() # type: ignore[call-arg]
7666
params = base.get_params(deep=False)
77-
count_key = _supports_count_param(params)
78-
supports_count = count_key is not None
67+
supports_count = _supports_count_param(params) is not None
7968
out.append((f"skfp_{cls_name}", base, supports_count))
8069
except Exception:
81-
# If a particular transformer can't be constructed with defaults, skip it.
8270
continue
8371

8472
return out
8573

8674

8775
def _skfp_make_variant(fp: Any, *, count: bool) -> Any:
88-
"""
89-
Clone transformer with `count` enabled if it supports it; otherwise return original.
90-
"""
9176
params = fp.get_params(deep=False)
9277
count_key = _supports_count_param(params)
9378
if count_key is None:
@@ -97,6 +82,14 @@ def _skfp_make_variant(fp: Any, *, count: bool) -> Any:
9782
return fp.__class__(**params)
9883

9984

85+
def _supports_unfolded_variant(fp: Any) -> bool:
86+
try:
87+
params = fp.get_params(deep=False)
88+
except Exception:
89+
return False
90+
return "variant" in params
91+
92+
10093
# ----------------------------
10194
# Assertions / shape utilities
10295
# ----------------------------
@@ -110,15 +103,12 @@ def _assert_dense_matrix(X: np.ndarray, n_rows: int) -> None:
110103

111104

112105
def _assert_binary_dense(X: np.ndarray) -> None:
113-
# allow all-zeros rows for tiny molecules, but values must be 0/1
114106
u = np.unique(X)
115107
assert set(u.tolist()).issubset({0.0, 1.0})
116108

117109

118110
def _assert_count_dense(X: np.ndarray) -> None:
119111
assert (X >= 0).all()
120-
# at least something non-zero overall
121-
assert float(X.sum()) >= 0.0
122112

123113

124114
def _assert_csr_matrix(X: sp.csr_matrix, n_rows: int) -> None:
@@ -138,157 +128,184 @@ def _assert_count_csr(X: sp.csr_matrix) -> None:
138128
assert (X.data >= 0).all()
139129

140130

141-
def _supports_unfolded_variant(fp: Any) -> bool:
142-
# Your code requires a `variant` param to do folded=False for sklearn/scikit-fingerprints
143-
try:
144-
params = fp.get_params(deep=False)
145-
except Exception:
146-
return False
147-
return "variant" in params
148-
149131
# ============================================================
150-
# CASE 1
151-
# - dense fingerprint (binary and where feasible count)
132+
# Parametrized case builders
152133
# ============================================================
153134

154-
def test_dense_fingerprints_binary_and_count_many_generators() -> None:
155-
gens_rdkit = _rdkit_generators()
156-
gens_skfp = _build_skfp_transformers()
157-
158-
if not gens_rdkit and not gens_skfp:
159-
pytest.skip("No fingerprint generators available (RDKit/scikit-fingerprints not importable).")
135+
def _case1_dense_cases():
136+
"""
137+
CASE 1:
138+
- dense fingerprint (binary and where feasible count)
139+
Each run = separate pytest param case.
140+
"""
141+
cases: List[pytest.ParamSpecArg] = []
160142

161-
# RDKit: always supports both binary + count
162-
for name, gen in gens_rdkit:
163-
# binary
164-
cfg = FingerprintConfig(count=False, folded=True, return_csr=False, invalid_policy="keep")
165-
X = compute_fingerprints(SMILES, gen, cfg, show_progress=False, n_jobs=1)
166-
_assert_dense_matrix(X, n_rows=len(SMILES))
167-
_assert_binary_dense(X)
143+
# RDKit: binary + count
144+
for name, gen in _rdkit_generators():
145+
cfg_bin = FingerprintConfig(count=False, folded=True, return_csr=False, invalid_policy="keep")
146+
cases.append(pytest.param(name, gen, cfg_bin, "dense-binary", id=f"{name}__dense__binary"))
168147

169-
# count
170-
cfg = replace(cfg, count=True)
171-
X = compute_fingerprints(SMILES, gen, cfg, show_progress=False, n_jobs=1)
172-
_assert_dense_matrix(X, n_rows=len(SMILES))
173-
_assert_count_dense(X)
148+
cfg_cnt = replace(cfg_bin, count=True)
149+
cases.append(pytest.param(name, gen, cfg_cnt, "dense-count", id=f"{name}__dense__count"))
174150

175-
# scikit-fingerprints: binary always; count only if transformer supports it
176-
for name, fp, supports_count in gens_skfp:
177-
# binary
178-
cfg = FingerprintConfig(count=False, folded=True, return_csr=False, invalid_policy="keep")
179-
X = compute_fingerprints(SMILES, fp, cfg, show_progress=False, n_jobs=1)
180-
_assert_dense_matrix(X, n_rows=len(SMILES))
181-
_assert_binary_dense(X)
151+
# skfp: binary always; count only if supported
152+
for name, fp, supports_count in _build_skfp_transformers():
153+
cfg_bin = FingerprintConfig(count=False, folded=True, return_csr=False, invalid_policy="keep")
154+
cases.append(pytest.param(name, fp, cfg_bin, "dense-binary", id=f"{name}__dense__binary"))
182155

183-
# count (where feasible)
184156
if supports_count:
185157
fp_count = _skfp_make_variant(fp, count=True)
186-
cfg = replace(cfg, count=True)
187-
X = compute_fingerprints(SMILES, fp_count, cfg, show_progress=False, n_jobs=1)
188-
_assert_dense_matrix(X, n_rows=len(SMILES))
189-
_assert_count_dense(X)
158+
cfg_cnt = replace(cfg_bin, count=True)
159+
cases.append(pytest.param(name, fp_count, cfg_cnt, "dense-count", id=f"{name}__dense__count"))
190160

161+
return cases
191162

192-
# ============================================================
193-
# CASE 2
194-
# - if fitting the fingerprint: return_csr=True (binary and where feasible count)
195-
# (We still include RDKit here because return_csr=True is supported as well.)
196-
# ============================================================
197163

198-
def test_return_csr_folded_binary_and_count_many_generators() -> None:
199-
gens_rdkit = _rdkit_generators()
200-
gens_skfp = _build_skfp_transformers()
164+
def _case2_csr_cases():
165+
"""
166+
CASE 2:
167+
- return_csr=True (binary and where feasible count)
168+
Each run = separate pytest param case.
169+
"""
170+
cases: List[pytest.ParamSpecArg] = []
201171

202-
if not gens_rdkit and not gens_skfp:
203-
pytest.skip("No fingerprint generators available (RDKit/scikit-fingerprints not importable).")
172+
# RDKit: binary + count
173+
for name, gen in _rdkit_generators():
174+
cfg_bin = FingerprintConfig(count=False, folded=True, return_csr=True, invalid_policy="keep")
175+
cases.append(pytest.param(name, gen, cfg_bin, "csr-binary", id=f"{name}__csr__binary"))
204176

205-
# RDKit
206-
for name, gen in gens_rdkit:
207-
# binary -> CSR
208-
cfg = FingerprintConfig(count=False, folded=True, return_csr=True, invalid_policy="keep")
209-
X = compute_fingerprints(SMILES, gen, cfg, show_progress=False, n_jobs=1)
210-
_assert_csr_matrix(X, n_rows=len(SMILES))
211-
_assert_binary_csr(X)
177+
cfg_cnt = replace(cfg_bin, count=True)
178+
cases.append(pytest.param(name, gen, cfg_cnt, "csr-count", id=f"{name}__csr__count"))
212179

213-
# count -> CSR
214-
cfg = replace(cfg, count=True)
215-
X = compute_fingerprints(SMILES, gen, cfg, show_progress=False, n_jobs=1)
216-
_assert_csr_matrix(X, n_rows=len(SMILES))
217-
_assert_count_csr(X)
180+
# skfp: binary always; count only if supported
181+
for name, fp, supports_count in _build_skfp_transformers():
182+
cfg_bin = FingerprintConfig(count=False, folded=True, return_csr=True, invalid_policy="keep")
183+
cases.append(pytest.param(name, fp, cfg_bin, "csr-binary", id=f"{name}__csr__binary"))
218184

219-
# scikit-fingerprints (fit/transform backend)
220-
for name, fp, supports_count in gens_skfp:
221-
# binary -> CSR
222-
cfg = FingerprintConfig(count=False, folded=True, return_csr=True, invalid_policy="keep")
223-
X = compute_fingerprints(SMILES, fp, cfg, show_progress=False, n_jobs=1)
224-
_assert_csr_matrix(X, n_rows=len(SMILES))
225-
_assert_binary_csr(X)
226-
227-
# count -> CSR (where feasible)
228185
if supports_count:
229186
fp_count = _skfp_make_variant(fp, count=True)
230-
cfg = replace(cfg, count=True)
231-
X = compute_fingerprints(SMILES, fp_count, cfg, show_progress=False, n_jobs=1)
232-
_assert_csr_matrix(X, n_rows=len(SMILES))
233-
_assert_count_csr(X)
187+
cfg_cnt = replace(cfg_bin, count=True)
188+
cases.append(pytest.param(name, fp_count, cfg_cnt, "csr-count", id=f"{name}__csr__count"))
234189

190+
return cases
235191

236-
# ============================================================
237-
# CASE 3
238-
# - if fitting the fingerprint:
239-
# return_csr=True with folded=True and folded=False
240-
#
241-
# (folded=False, return_csr=True) is INVALID by design
242-
# and must raise (validated in _validate_config). We test both behaviors.
243-
# We also compute unfolded output (folded=False) with return_csr=False to
244-
# still "run them all on a few simple smiles".
245-
# ============================================================
246192

247-
def test_fit_backend_return_csr_true_folded_true_and_false() -> None:
248-
gens_skfp = _build_skfp_transformers()
249-
if not gens_skfp:
250-
pytest.skip("scikit-fingerprints/sklearn-style transformers not importable in this env.")
193+
def _case3_fit_backend_cases():
194+
"""
195+
CASE 3 (skfp only):
196+
- "return_csr=True with folded=True and folded=False"
197+
We split into separate runs:
198+
A) folded=True, return_csr=True works
199+
B) folded=False, return_csr=True raises ValueError
200+
C) folded=False, return_csr=False:
201+
- if supports variant: works (unfolded)
202+
- else: raises NotImplementedError
203+
"""
204+
cases: List[pytest.ParamSpecArg] = []
251205

252-
for name, fp, supports_count in gens_skfp:
253-
# 1) folded=True + return_csr=True should work
206+
for name, fp, supports_count in _build_skfp_transformers():
207+
# A)
254208
cfg_ok = FingerprintConfig(count=False, folded=True, return_csr=True, invalid_policy="keep")
255-
X = compute_fingerprints(SMILES, fp, cfg_ok, show_progress=False, n_jobs=1)
256-
_assert_csr_matrix(X, n_rows=len(SMILES))
209+
cases.append(pytest.param(name, fp, cfg_ok, "folded_true_csr_ok", supports_count,
210+
id=f"{name}__case3__folded_true__csr_ok"))
257211

258-
# 2) folded=False + return_csr=True must raise (per _validate_config)
212+
# B)
259213
cfg_bad = replace(cfg_ok, folded=False, return_csr=True)
260-
with pytest.raises(ValueError, match="return_csr is only valid when folded=True"):
261-
compute_fingerprints(SMILES, fp, cfg_bad, show_progress=False, n_jobs=1)
214+
cases.append(pytest.param(name, fp, cfg_bad, "folded_false_csr_raises_valueerror", supports_count,
215+
id=f"{name}__case3__folded_false__csr_raises_valueerror"))
262216

263-
# 3) folded=False + return_csr=False:
264-
# - run if transformer supports `variant`
265-
# - otherwise, assert the documented NotImplementedError
217+
# C)
266218
cfg_unfolded = FingerprintConfig(count=False, folded=False, return_csr=False, invalid_policy="keep")
219+
behavior = "unfolded_ok" if _supports_unfolded_variant(fp) else "unfolded_raises_notimplemented"
220+
cases.append(pytest.param(name, fp, cfg_unfolded, behavior, supports_count,
221+
id=f"{name}__case3__folded_false__unfolded_{behavior}"))
222+
223+
# Unfolded count output is added as its own case, only when the transformer supports both a count flag and `variant`:
224+
if supports_count and _supports_unfolded_variant(fp):
225+
fp_count = _skfp_make_variant(fp, count=True)
226+
cfg_unfolded_count = replace(cfg_unfolded, count=True)
227+
cases.append(pytest.param(name, fp_count, cfg_unfolded_count, "unfolded_count_ok", True,
228+
id=f"{name}__case3__folded_false__unfolded_count_ok"))
229+
230+
return cases
231+
232+
233+
# ============================================================
234+
# CASE 1: dense (parametrized per run)
235+
# ============================================================
236+
237+
@pytest.mark.parametrize("name, fpgen, cfg, mode", _case1_dense_cases())
238+
def test_case1_dense_per_run(name: str, fpgen: Any, cfg: FingerprintConfig, mode: str) -> None:
239+
X = compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)
240+
241+
_assert_dense_matrix(X, n_rows=len(SMILES))
242+
if mode.endswith("binary"):
243+
_assert_binary_dense(X)
244+
else:
245+
_assert_count_dense(X)
246+
267247

268-
if _supports_unfolded_variant(fp):
269-
out = compute_fingerprints(SMILES, fp, cfg_unfolded, show_progress=False, n_jobs=1)
270-
assert isinstance(out, list)
271-
assert len(out) == len(SMILES)
248+
# ============================================================
249+
# CASE 2: csr (parametrized per run)
250+
# ============================================================
251+
252+
@pytest.mark.parametrize("name, fpgen, cfg, mode", _case2_csr_cases())
253+
def test_case2_csr_per_run(name: str, fpgen: Any, cfg: FingerprintConfig, mode: str) -> None:
254+
X = compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)
255+
256+
_assert_csr_matrix(X, n_rows=len(SMILES))
257+
if mode.endswith("binary"):
258+
_assert_binary_csr(X)
259+
else:
260+
_assert_count_csr(X)
261+
262+
263+
# ============================================================
264+
# CASE 3: fitting backend + folded True/False (parametrized)
265+
# ============================================================
266+
267+
@pytest.mark.parametrize("name, fpgen, cfg, behavior, supports_count", _case3_fit_backend_cases())
268+
def test_case3_fit_backend_per_run(
269+
name: str,
270+
fpgen: Any,
271+
cfg: FingerprintConfig,
272+
behavior: str,
273+
supports_count: bool,
274+
) -> None:
275+
if behavior == "folded_true_csr_ok":
276+
X = compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)
277+
_assert_csr_matrix(X, n_rows=len(SMILES))
278+
return
279+
280+
if behavior == "folded_false_csr_raises_valueerror":
281+
with pytest.raises(ValueError, match="return_csr is only valid when folded=True"):
282+
compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)
283+
return
284+
285+
if behavior in ("unfolded_ok", "unfolded_count_ok"):
286+
out = compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)
287+
assert isinstance(out, list)
288+
assert len(out) == len(SMILES)
289+
290+
if cfg.count:
291+
for keys, vals in out:
292+
assert isinstance(keys, np.ndarray) and keys.dtype == np.int64
293+
assert isinstance(vals, np.ndarray) and vals.dtype == np.float32
294+
assert keys.shape == vals.shape
295+
assert (keys >= 0).all()
296+
assert (vals >= 0).all()
297+
assert np.all(keys[:-1] <= keys[1:]) if keys.size > 1 else True
298+
else:
272299
for keys in out:
273300
assert isinstance(keys, np.ndarray)
274301
assert keys.dtype == np.int64
275302
assert (keys >= 0).all()
276303
assert np.all(keys[:-1] <= keys[1:]) if keys.size > 1 else True
304+
return
277305

278-
# count unfolded (where feasible AND where variant exists)
279-
if supports_count:
280-
fp_count = _skfp_make_variant(fp, count=True)
281-
cfg_unfolded_count = replace(cfg_unfolded, count=True)
282-
out = compute_fingerprints(SMILES, fp_count, cfg_unfolded_count, show_progress=False, n_jobs=1)
283-
assert isinstance(out, list)
284-
assert len(out) == len(SMILES)
285-
for keys, vals in out:
286-
assert isinstance(keys, np.ndarray) and keys.dtype == np.int64
287-
assert isinstance(vals, np.ndarray) and vals.dtype == np.float32
288-
assert keys.shape == vals.shape
289-
assert (keys >= 0).all()
290-
assert (vals >= 0).all()
291-
assert np.all(keys[:-1] <= keys[1:]) if keys.size > 1 else True
292-
else:
293-
with pytest.raises(NotImplementedError, match="Requested folded=False"):
294-
compute_fingerprints(SMILES, fp, cfg_unfolded, show_progress=False, n_jobs=1)
306+
if behavior == "unfolded_raises_notimplemented":
307+
with pytest.raises(NotImplementedError, match="Requested folded=False"):
308+
compute_fingerprints(SMILES, fpgen, cfg, show_progress=False, n_jobs=1)
309+
return
310+
311+
raise AssertionError(f"Unknown behavior: {behavior!r}")

0 commit comments

Comments
 (0)