Skip to content

Commit 673ec7e

Browse files
authored
Merge pull request #51 from UBC-MDS/feature/more_test
looks good, merged
2 parents 67333fe + 29c3cfe commit 673ec7e

File tree

5 files changed

+143
-28
lines changed

5 files changed

+143
-28
lines changed

src/ridge_remake/ridge_remake.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

tests/unit/test_get_line.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,34 @@ def test_pandas_compatibility():
6262
y_pred = get_reg_line(X, y)
6363
assert isinstance(y_pred, np.ndarray)
6464
assert len(y_pred) == 3
65+
66+
def test_mismatched_dimensions():
67+
"""
68+
Verify that providing X and y with a different number of samples
69+
raises a ValueError during matrix multiplication.
70+
"""
71+
# X has 3 samples, but y only has 2 samples
72+
X = np.array([[1], [2], [3]])
73+
y = np.array([1, 2])
74+
75+
# NumPy's matrix multiplication (@) will raise ValueError due to dimension mismatch
76+
with pytest.raises(ValueError):
77+
get_reg_line(X, y)
78+
79+
def test_multi_target_regression():
80+
"""
81+
Verify that the function can handle multiple targets (y) simultaneously.
82+
If y is (n_samples, n_targets), y_pred should also be (n_samples, n_targets).
83+
"""
84+
X = np.array([[1], [2], [3]])
85+
# Target 1: y = 2x, Target 2: y = 10x
86+
y = np.array([[2, 10],
87+
[4, 20],
88+
[6, 30]])
89+
90+
y_pred = get_reg_line(X, y)
91+
92+
# Check if the output shape matches the multi-target input shape
93+
assert y_pred.shape == (3, 2)
94+
# Check if the predictions are accurate for both targets
95+
np.testing.assert_allclose(y_pred, y, rtol=1e-5)

tests/unit/test_ridge_r2.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,35 @@ def test_r2_constant_target_perfect():
111111
"""Verify that a constant target with perfect prediction returns 1.0."""
112112
y_true = [5, 5, 5]
113113
y_pred = [5, 5, 5]
114-
assert ridge_get_r2(y_true, y_pred) == 1.0
114+
assert ridge_get_r2(y_true, y_pred) == 1.0
115+
116+
def test_r2_negative_score():
117+
"""
118+
Conceptual Case: Worse Than Mean
119+
Verify that the function correctly returns a negative value when the
120+
predictions are worse than simply predicting the mean of y_true.
121+
"""
122+
y_true = np.array([1, 2, 3]) # Mean is 2.0, SS_tot is 2.0
123+
# Predictions that are very far from the true values
124+
y_pred = np.array([10, 20, 30])
125+
126+
result = ridge_get_r2(y_true, y_pred)
127+
128+
# R2 = 1 - (SS_res / SS_tot). If SS_res > SS_tot, R2 must be negative.
129+
assert result < 0.0
130+
assert isinstance(result, float)
131+
132+
def test_r2_2d_column_vectors():
133+
"""
134+
Structural Case: 2D Column Vectors (n_samples, 1)
135+
In many ML workflows, y is often reshaped as a 2D column vector.
136+
Verify that the function handles (n, 1) shapes correctly via numpy broadcasting.
137+
"""
138+
y_true = np.array([[1.0], [2.0], [3.0]])
139+
y_pred = np.array([[1.1], [1.9], [3.1]])
140+
141+
result = ridge_get_r2(y_true, y_pred)
142+
143+
# The result should be a single float value, even for 2D inputs
144+
assert np.isscalar(result)
145+
assert 0.0 < result < 1.0

tests/unit/test_ridge_scatter.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
matplotlib.use("Agg")
1919
import matplotlib.pyplot as plt
20+
import numpy as np
2021
import pytest
2122
from matplotlib.collections import PathCollection
2223

@@ -53,3 +54,38 @@ def test_ridge_scatter():
5354

5455
with pytest.raises(TypeError, match="ax"):
5556
ridge_scatter(None, [1, 2], [3, 4], scatter_kwargs=None)
57+
58+
def test_ridge_scatter_input_types():
59+
"""
60+
Verify that the function handles various array-like inputs (lists, numpy arrays)
61+
and flattens them correctly using .ravel().
62+
"""
63+
fig, ax = plt.subplots()
64+
65+
# Nested lists or 2D arrays should be flattened without error
66+
x_input = [[1], [2], [3]]
67+
y_input = np.array([[10], [20], [30]])
68+
69+
_, points = ridge_scatter(ax, x_input, y_input)
70+
71+
# Resulting offsets should have shape (3, 2) representing (x, y) coordinates
72+
assert points.get_offsets().shape == (3, 2)
73+
74+
plt.close(fig)
75+
76+
def test_ridge_scatter_visual_properties():
77+
"""
78+
Verify that styling arguments in scatter_kwargs are correctly applied
79+
to the resulting PathCollection.
80+
"""
81+
fig, ax = plt.subplots()
82+
83+
# Test specific visual styling: size (s) and transparency (alpha)
84+
style = {"s": 50, "alpha": 0.5, "edgecolors": "red"}
85+
_, points = ridge_scatter(ax, [1], [1], scatter_kwargs=style)
86+
87+
# PathCollection stores sizes as an array of squares of the diameters
88+
assert points.get_sizes()[0] == 50
89+
assert points.get_alpha() == 0.5
90+
91+
plt.close(fig)

tests/unit/test_ridge_scatter_line.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,47 @@ def test_bad_line_kwargs_typeerror():
9999
ridge_scatter_line(ax, [1, 2], [3, 4], line_kwargs=[("lw", 2)])
100100

101101
plt.close(fig)
102+
103+
def test_label_and_styling_application():
104+
"""
105+
Verify that label and line_kwargs are correctly applied to the Matplotlib line.
106+
"""
107+
fig, ax = plt.subplots()
108+
x = [1, 2, 3]
109+
y_line = [2, 4, 6]
110+
custom_kwargs = {"color": "red", "linewidth": 5, "linestyle": "--"}
111+
custom_label = "Regression Model"
112+
113+
_, line = ridge_scatter_line(
114+
ax, x, y_line,
115+
line_kwargs=custom_kwargs,
116+
label=custom_label
117+
)
118+
119+
# Check if the label was applied
120+
assert line.get_label() == custom_label
121+
# Check if kwargs were passed correctly
122+
assert line.get_color() == "red"
123+
assert line.get_linewidth() == 5
124+
assert line.get_linestyle() == "--"
125+
126+
plt.close(fig)
127+
128+
def test_return_values_consistency():
129+
"""
130+
Verify that the function returns the original Axes object and the created Line2D artist.
131+
"""
132+
fig, ax = plt.subplots()
133+
x = [0, 1]
134+
y_line = [0, 1]
135+
136+
returned_ax, returned_line = ridge_scatter_line(ax, x, y_line)
137+
138+
# The returned ax should be the same object passed in
139+
assert returned_ax is ax
140+
# The returned line should be a Matplotlib Line2D object
141+
assert isinstance(returned_line, matplotlib.lines.Line2D)
142+
# The returned line should be the same one present in the axes
143+
assert returned_line in ax.lines
144+
145+
plt.close(fig)

0 commit comments

Comments
 (0)