diff --git a/dask_ml/ensemble/_blockwise.py b/dask_ml/ensemble/_blockwise.py
index 359f96346..5684bd5d0 100644
--- a/dask_ml/ensemble/_blockwise.py
+++ b/dask_ml/ensemble/_blockwise.py
@@ -8,6 +8,18 @@
 from ..utils import check_array, is_frame_base
 
 
+def _safe_rechunk(arr, rechunk_dict, error_context=""):
+    """Helper function to safely rechunk arrays with proper error handling."""
+    try:
+        return arr.rechunk(rechunk_dict)
+    except Exception as e:
+        msg = (
+            "Failed to rechunk array"
+            f"{': ' + error_context if error_context else ''}: {e}"
+        )
+        raise ValueError(msg) from e
+
+
 class BlockwiseBase(sklearn.base.BaseEstimator):
     def __init__(self, estimator):
         self.estimator = estimator
@@ -22,6 +34,11 @@ def _check_array(self, X):
 
     def fit(self, X, y, **kwargs):
         X = self._check_array(X)
+        try:
+            self._n_samples = X.shape[0]
+        except Exception:
+            self._n_samples = None
+
         estimatord = dask.delayed(self.estimator)
 
         Xs = X.to_delayed()
@@ -45,6 +62,7 @@ def fit(self, X, y, **kwargs):
         ]
         results = list(dask.compute(*results))
         self.estimators_ = results
+        return self
 
     def _predict(self, X):
         """Collect results from many predict calls"""
@@ -54,6 +72,13 @@ def _predict(self, X):
             dtype = "float64"
 
         if isinstance(X, da.Array):
+            if hasattr(self, "_n_samples") and self._n_samples is not None:
+                desired = len(self.estimators_)
+                if X.numblocks[0] != desired:
+                    block_size = max(1, self._n_samples // desired)
+                    X = _safe_rechunk(
+                        X, {0: block_size}, "to match estimator partitioning"
+                    )
             chunks = (X.chunks[0], len(self.estimators_))
             combined = X.map_blocks(
                 _predict_stack,
@@ -174,6 +199,13 @@ def _predict_proba(self, X):
 
     def _collect_probas(self, X):
         if isinstance(X, da.Array):
+            if hasattr(self, "_n_samples") and self._n_samples is not None:
+                desired = len(self.estimators_)
+                if X.numblocks[0] != desired:
+                    block_size = max(1, self._n_samples // desired)
+                    X = _safe_rechunk(
+                        X, {0: block_size}, "to match estimator partitioning"
+                    )
             chunks = (len(self.estimators_), X.chunks[0], len(self.classes_))
             meta = np.array([], dtype="float64")
             # (n_estimators, len(X), n_classes)
diff --git a/dask_ml/metrics/regression.py b/dask_ml/metrics/regression.py
index 0c1b21b59..ccb7e7bca 100644
--- a/dask_ml/metrics/regression.py
+++ b/dask_ml/metrics/regression.py
@@ -162,18 +162,13 @@ def r2_score(
     numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8")
     denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype="f8")
 
-    nonzero_denominator = denominator != 0
-    nonzero_numerator = numerator != 0
-    valid_score = nonzero_denominator & nonzero_numerator
-    output_chunks = getattr(y_true, "chunks", [None, None])[1]
-    output_scores = da.ones([y_true.shape[1]], chunks=output_chunks)
-    with np.errstate(all="ignore"):
-        output_scores[valid_score] = 1 - (
-            numerator[valid_score] / denominator[valid_score]
-        )
-        output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0
-
-    result = output_scores.mean(axis=0)
+    score = da.where(
+        numerator == 0,
+        1.0,
+        da.where(denominator != 0, 1 - numerator / denominator, 0.0),
+    )
+
+    result = score.mean(axis=0)
     if compute:
         result = result.compute()
     return result
diff --git a/tests/ensemble/test_blockwise.py b/tests/ensemble/test_blockwise.py
index 6a8f63a4b..b8febb50c 100644
--- a/tests/ensemble/test_blockwise.py
+++ b/tests/ensemble/test_blockwise.py
@@ -186,6 +186,27 @@ def test_no_classes_raises(self):
 
 
 class TestBlockwiseVotingRegressor:
+    def test_no_unnecessary_computation_in_fit(self, monkeypatch):
+        X, y = dask_ml.datasets.make_regression(n_features=20, chunks=25)
+        compute_called = False
+        original_compute = X.compute
+
+        def spy_compute(*args, **kwargs):
+            nonlocal compute_called
+            compute_called = True
+            return original_compute(*args, **kwargs)
+
+        monkeypatch.setattr(X, "compute", spy_compute)
+
+        est = dask_ml.ensemble.BlockwiseVotingRegressor(
+            sklearn.linear_model.LinearRegression(),
+        )
+        est.fit(X, y)
+        # Ensure that X.compute() was never invoked during fitting.
+        assert compute_called is False
+        # Verify that _n_samples was set using lazy metadata.
+        assert est._n_samples == X.shape[0]
+
     def test_fit_array(self):
         X, y = dask_ml.datasets.make_regression(n_features=20, chunks=25)
         est = dask_ml.ensemble.BlockwiseVotingRegressor(
@@ -240,3 +261,24 @@ def test_fit_frame(self):
         # TODO: r2_score raising for ndarray
         # score2 = est.score(X3, y3)
         # assert score == score2
+
+    def test_predict_with_different_chunks(self):
+        X, y = dask_ml.datasets.make_regression(n_features=20, chunks=25)
+        est = dask_ml.ensemble.BlockwiseVotingRegressor(
+            sklearn.linear_model.LinearRegression(),
+        )
+        est.fit(X, y)
+
+        X_test, y_test = dask_ml.datasets.make_regression(n_features=20, chunks=20)
+        result = est.predict(X_test)
+        assert result.dtype == np.dtype("float64")
+        assert result.shape == y_test.shape
+        # Prediction is rechunked to have one block per estimator.
+        assert result.numblocks[0] == len(est.estimators_)
+
+        score = est.score(X_test, y_test)
+        assert isinstance(score, float)
+
+        X_test_np, y_test_np = dask.compute(X_test, y_test)
+        result_np = est.predict(X_test_np)
+        da.utils.assert_eq(result, result_np)
diff --git a/tests/metrics/test_regression.py b/tests/metrics/test_regression.py
index af775e168..d15580c7d 100644
--- a/tests/metrics/test_regression.py
+++ b/tests/metrics/test_regression.py
@@ -116,3 +116,19 @@ def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs):
 
     with pytest.raises((NotImplementedError, ValueError), match=error_msg):
         _ = m1(a, b, multioutput=weights)
+
+
+def test_r2_score_with_different_chunk_patterns():
+    """Test r2_score with different chunking configurations."""
+    # Create arrays with compatible but different chunk patterns
+    a = da.random.uniform(size=(100,), chunks=25)  # 4 chunks
+    b = da.random.uniform(size=(100,), chunks=20)  # 5 chunks
+    result = dask_ml.metrics.r2_score(a, b)
+    assert isinstance(result, float)
+    # Create arrays with different chunk patterns
+    a_multi = da.random.uniform(size=(100, 3), chunks=(25, 3))  # 4 chunks
+    b_multi = da.random.uniform(size=(100, 3), chunks=(20, 3))  # 5 chunks
+    result_multi = dask_ml.metrics.r2_score(
+        a_multi, b_multi, multioutput="uniform_average"
+    )
+    assert isinstance(result_multi, float)
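
For reference, the nested da.where introduced in r2_score is meant to cover the three branches of the masked implementation it replaces: a zero numerator (perfect predictions) scores 1.0, a zero denominator with a nonzero numerator (constant y_true) scores 0.0, and everything else uses 1 - numerator / denominator. The following NumPy-only sketch of that case analysis is illustrative and not part of the patch; the helper name r2_cases and the sample values are made up.

    import numpy as np

    def r2_cases(numerator, denominator):
        # Mirrors the nested where() in the patch: perfect predictions score 1.0,
        # a constant y_true (zero denominator, nonzero numerator) scores 0.0,
        # and everything else falls through to 1 - numerator / denominator.
        with np.errstate(divide="ignore", invalid="ignore"):
            return np.where(
                numerator == 0,
                1.0,
                np.where(denominator != 0, 1 - numerator / denominator, 0.0),
            )

    # Perfect fit, ordinary fit, constant y_true -> [1.0, 0.75, 0.0]
    print(r2_cases(np.array([0.0, 2.0, 5.0]), np.array([4.0, 8.0, 0.0])))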