Skip to content

Commit 2563554

Browse files
Pranav ChoudharyPranav Choudhary
authored andcommitted
Add unit tests for benchmarking return_raw parameter
1 parent 7cbf6ce commit 2563554

1 file changed

Lines changed: 53 additions & 0 deletions

File tree

pyaptamer/benchmarking/tests/test_benchmarking.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,56 @@ def test_benchmarking_with_predefined_split_regression(aptamer_seq, protein_seq)
6868
assert "train" in summary.columns
6969
assert "test" in summary.columns
7070
assert (reg.__class__.__name__, "mean_squared_error") in summary.index
71+
72+
73+
@pytest.mark.parametrize("aptamer_seq, protein_seq", params)
74+
def test_benchmarking_return_raw(aptamer_seq, protein_seq):
75+
"""
76+
Test that Benchmarking.run(return_raw=True) returns both the summary and raw DataFrame.
77+
"""
78+
X_raw = [(aptamer_seq, protein_seq) for _ in range(20)]
79+
y = np.array([0] * 10 + [1] * 10, dtype=np.float32)
80+
81+
clf = AptaNetPipeline()
82+
test_fold = np.ones(len(y), dtype=int) * -1
83+
test_fold[-2:] = 0
84+
cv = PredefinedSplit(test_fold)
85+
86+
bench = Benchmarking(
87+
estimators=[clf],
88+
metrics=[accuracy_score],
89+
X=X_raw,
90+
y=y,
91+
cv=cv,
92+
)
93+
summary, raw = bench.run(return_raw=True)
94+
95+
# Check summary
96+
assert "train" in summary.columns
97+
assert "test" in summary.columns
98+
99+
# Check raw
100+
assert "train" in raw.columns
101+
assert "test" in raw.columns
102+
assert raw.index.names == ["estimator", "metric", "fold"]
103+
assert len(raw) > 0
104+
105+
106+
def test_benchmarking_labels():
107+
"""
108+
Test that Benchmarking validates the length of labels matches estimators.
109+
"""
110+
clf1 = AptaNetPipeline()
111+
clf2 = AptaNetPipeline()
112+
113+
# Passing 2 estimators but only 1 label should raise ValueError
114+
with pytest.raises(ValueError, match="Length of labels must match length of estimators"):
115+
bench = Benchmarking(
116+
estimators=[clf1, clf2],
117+
metrics=[accuracy_score],
118+
X=[("A", "B")],
119+
y=np.array([1]),
120+
cv=None,
121+
labels=["OnlyOneLabel"]
122+
)
123+
bench.run()

0 commit comments

Comments
 (0)