Skip to content

Commit 17d3a02

Browse files
committed
test: r2_score and blockwise voting with different chunk patterns
1 parent 0f6bccf commit 17d3a02

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

Diff for: tests/ensemble/test_blockwise.py

+19
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,25 @@ def test_fit_array(self):
212212
# score2 = est.score(X3, y3)
213213
# assert score == score2
214214

215+
def test_predict_with_different_chunks(self):
216+
"""Test prediction with different chunking patterns."""
217+
# Train with one chunking pattern
218+
X, y = dask_ml.datasets.make_regression(n_features=20, chunks=25)
219+
est = dask_ml.ensemble.BlockwiseVotingRegressor(
220+
sklearn.linear_model.LinearRegression(),
221+
)
222+
est.fit(X, y)
223+
# Predict with a different chunking pattern
224+
X_test, y_test = dask_ml.datasets.make_regression(n_features=20, chunks=20)
225+
226+
result = est.predict(X_test)
227+
assert result.dtype == np.dtype("float64")
228+
assert result.shape == y_test.shape
229+
230+
# Also test scoring
231+
score = est.score(X_test, y_test)
232+
assert isinstance(score, float)
233+
215234
def test_fit_frame(self):
216235
X, y = dask_ml.datasets.make_regression(n_features=20, chunks=25)
217236
X = dd.from_dask_array(X)

Diff for: tests/metrics/test_regression.py

+16
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,19 @@ def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs):
116116

117117
with pytest.raises((NotImplementedError, ValueError), match=error_msg):
118118
_ = m1(a, b, multioutput=weights)
119+
120+
121+
def test_r2_score_with_different_chunk_patterns():
122+
"""Test r2_score with different chunking configurations."""
123+
# Create arrays with compatible but different chunk patterns
124+
a = da.random.uniform(size=(100,), chunks=25) # 4 chunks
125+
b = da.random.uniform(size=(100,), chunks=20) # 5 chunks
126+
result = dask_ml.metrics.r2_score(a, b)
127+
assert isinstance(result, float)
128+
# Create arrays with different chunk patterns
129+
a_multi = da.random.uniform(size=(100, 3), chunks=(25, 3)) # 4 chunks
130+
b_multi = da.random.uniform(size=(100, 3), chunks=(20, 3)) # 5 chunks
131+
result_multi = dask_ml.metrics.r2_score(
132+
a_multi, b_multi, multioutput="uniform_average"
133+
)
134+
assert isinstance(result_multi, float)

0 commit comments

Comments
 (0)