@@ -155,25 +155,48 @@ def r2_score(
155
155
multioutput : Optional [str ] = "uniform_average" ,
156
156
compute : bool = True ,
157
157
) -> ArrayLike :
158
+ """
159
+ Compute the R² score for regression.
160
+
161
+ This function calculates the coefficient of determination using residual
162
+ and total sums of squares. It employs dask.array.where to gracefully handle
163
+ unknown dimensions without in-place assignment.
164
+
165
+ Parameters
166
+ ----------
167
+ y_true : ArrayLike
168
+ True target values.
169
+ y_pred : ArrayLike
170
+ Predicted target values.
171
+ sample_weight : Optional[ArrayLike], default=None
172
+ Weights for samples.
173
+ multioutput : Optional[str], default="uniform_average"
174
+ Method to aggregate multiple outputs.
175
+ compute : bool, default=True
176
+ If True, return the computed result; else, return a Dask array.
177
+
178
+ Returns
179
+ -------
180
+ result : ArrayLike
181
+ The R² score (scalar/NumPy array if computed, or a Dask array otherwise).
182
+ """
158
183
_check_sample_weight (sample_weight )
159
184
_ , y_true , y_pred , _ = _check_reg_targets (y_true , y_pred , multioutput )
160
185
weight = 1.0
161
186
187
+ # Compute residual and total sums of squares.
162
188
numerator = (weight * (y_true - y_pred ) ** 2 ).sum (axis = 0 , dtype = "f8" )
163
189
denominator = (weight * (y_true - y_true .mean (axis = 0 )) ** 2 ).sum (axis = 0 , dtype = "f8" )
164
190
165
- nonzero_denominator = denominator != 0
166
- nonzero_numerator = numerator != 0
167
- valid_score = nonzero_denominator & nonzero_numerator
168
- output_chunks = getattr (y_true , "chunks" , [None , None ])[1 ]
169
- output_scores = da .ones ([y_true .shape [1 ]], chunks = output_chunks )
170
- with np .errstate (all = "ignore" ):
171
- output_scores [valid_score ] = 1 - (
172
- numerator [valid_score ] / denominator [valid_score ]
173
- )
174
- output_scores [nonzero_numerator & ~ nonzero_denominator ] = 0.0
175
-
176
- result = output_scores .mean (axis = 0 )
191
+ # Determine R²: 1.0 for perfect predictions, 1 - numerator/denom when valid,
192
+ # and 0.0 if denominator is zero.
193
+ score = da .where (
194
+ numerator == 0 ,
195
+ 1.0 ,
196
+ da .where (denominator != 0 , 1 - numerator / denominator , 0.0 )
197
+ )
198
+
199
+ result = score .mean (axis = 0 )
177
200
if compute :
178
201
result = result .compute ()
179
202
return result
0 commit comments