Skip to content

Commit d7f7b86

Browse files
committed
fix(ensemble, metrics): compute chunk sizes and refactor r2_score with nested da.where to prevent broadcasting errors
1 parent 69d168b commit d7f7b86

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

Diff for: dask_ml/ensemble/_blockwise.py

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def _predict(self, X):
5454
dtype = "float64"
5555

5656
if isinstance(X, da.Array):
57+
X = X.compute_chunk_sizes() # Resolve unknown (NaN) chunk sizes so X.chunks[0] holds concrete integers for the output chunks below.
5758
chunks = (X.chunks[0], len(self.estimators_))
5859
combined = X.map_blocks(
5960
_predict_stack,
@@ -174,6 +175,7 @@ def _predict_proba(self, X):
174175

175176
def _collect_probas(self, X):
176177
if isinstance(X, da.Array):
178+
X = X.compute_chunk_sizes() # Resolve unknown (NaN) chunk sizes so X.chunks[0] holds concrete integers for the output chunks below.
177179
chunks = (len(self.estimators_), X.chunks[0], len(self.classes_))
178180
meta = np.array([], dtype="float64")
179181
# (n_estimators, len(X), n_classes)

Diff for: dask_ml/metrics/regression.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -155,25 +155,48 @@ def r2_score(
155155
multioutput: Optional[str] = "uniform_average",
156156
compute: bool = True,
157157
) -> ArrayLike:
158+
"""
159+
Compute the R² score for regression.
160+
161+
This function calculates the coefficient of determination using residual
162+
and total sums of squares. It employs dask.array.where to gracefully handle
163+
unknown dimensions without in-place assignment.
164+
165+
Parameters
166+
----------
167+
y_true : ArrayLike
168+
True target values.
169+
y_pred : ArrayLike
170+
Predicted target values.
171+
sample_weight : Optional[ArrayLike], default=None
172+
Weights for samples.
173+
multioutput : Optional[str], default="uniform_average"
174+
Method to aggregate multiple outputs.
175+
compute : bool, default=True
176+
If True, return the computed result; else, return a Dask array.
177+
178+
Returns
179+
-------
180+
result : ArrayLike
181+
The R² score (scalar/NumPy array if computed, or a Dask array otherwise).
182+
"""
158183
_check_sample_weight(sample_weight)
159184
_, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, multioutput)
160185
weight = 1.0
161186

187+
# Compute residual and total sums of squares.
162188
numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8")
163189
denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype="f8")
164190

165-
nonzero_denominator = denominator != 0
166-
nonzero_numerator = numerator != 0
167-
valid_score = nonzero_denominator & nonzero_numerator
168-
output_chunks = getattr(y_true, "chunks", [None, None])[1]
169-
output_scores = da.ones([y_true.shape[1]], chunks=output_chunks)
170-
with np.errstate(all="ignore"):
171-
output_scores[valid_score] = 1 - (
172-
numerator[valid_score] / denominator[valid_score]
173-
)
174-
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0
175-
176-
result = output_scores.mean(axis=0)
191+
# Determine R² per output: 1.0 when the residual sum of squares is zero (perfect fit),
192+
# 1 - numerator/denominator when the denominator is non-zero, and 0.0 when only the denominator is zero.
193+
score = da.where(
194+
numerator == 0,
195+
1.0,
196+
da.where(denominator != 0, 1 - numerator / denominator, 0.0)
197+
)
198+
199+
result = score.mean(axis=0)
177200
if compute:
178201
result = result.compute()
179202
return result

0 commit comments

Comments
 (0)