4 | 4 | import numpy as np |
5 | 5 | import pytest |
6 | 6 |
| 7 | +from sklearn._config import config_context |
7 | 8 | from sklearn.metrics.cluster import ( |
8 | 9 | adjusted_mutual_info_score, |
9 | 10 | adjusted_rand_score, |
18 | 19 | silhouette_score, |
19 | 20 | v_measure_score, |
20 | 21 | ) |
21 | | -from sklearn.utils._testing import assert_allclose |
| 22 | +from sklearn.utils._array_api import ( |
| 23 | + _atol_for_type, |
| 24 | + _convert_to_numpy, |
| 25 | + _get_namespace_device_dtype_ids, |
| 26 | + yield_namespace_device_dtype_combinations, |
| 27 | +) |
| 28 | +from sklearn.utils._testing import _array_api_for_tests, assert_allclose |
22 | 29 |
23 | 30 | # Dictionaries of metrics |
24 | 31 | # ------------------------ |
@@ -232,3 +239,88 @@ def test_returned_value_consistency(name): |
232 | 239 |
233 | 240 | assert isinstance(score, float) |
234 | 241 | assert not isinstance(score, (np.float64, np.float32)) |
| 242 | + |
| 243 | + |
| 244 | +def check_array_api_metric( |
| 245 | + metric, array_namespace, device, dtype_name, a_np, b_np, **metric_kwargs |
| 246 | +): |
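| | +    # Check that `metric` evaluated on `array_namespace` inputs allocated on |
| | +    # `device` matches the reference result computed on the numpy inputs, |
| | +    # both with array API dispatch disabled (implicit conversion to numpy) |
| | +    # and enabled. |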
| 247 | + xp = _array_api_for_tests(array_namespace, device) |
| 248 | + |
| 249 | + a_xp = xp.asarray(a_np, device=device) |
| 250 | + b_xp = xp.asarray(b_np, device=device) |
| 251 | + |
| 252 | + metric_np = metric(a_np, b_np, **metric_kwargs) |
| 253 | + |
| 254 | + # When array API dispatch is disabled and np.asarray works (for example with |
| 255 | + # PyTorch on a CPU device), calling the metric function with such |
| 256 | + # numpy-compatible inputs should work (albeit by implicitly converting to |
| 257 | + # numpy arrays instead of dispatching to the array library). |
| 258 | + try: |
| 259 | + np.asarray(a_xp) |
| 260 | + np.asarray(b_xp) |
| 261 | + numpy_as_array_works = True |
| 262 | + except (TypeError, RuntimeError, ValueError): |
| 263 | + # PyTorch with CUDA device and CuPy raise TypeError consistently. |
| 264 | + # array-api-strict chose to raise RuntimeError instead. NumPy raises |
| 265 | + # a ValueError if the `__array__` dunder does not return an array. |
| 266 | + # Exception type may need to be updated in the future for other libraries. |
| 267 | + numpy_as_array_works = False |
| 268 | + |
| 269 | + def _check_metric_matches(metric_a, metric_b, convert_a=False): |
| 270 | + if convert_a: |
| 271 | + metric_a = _convert_to_numpy(xp.asarray(metric_a), xp) |
| 272 | + assert_allclose(metric_a, metric_b, atol=_atol_for_type(dtype_name)) |
| 273 | + |
| 274 | + if numpy_as_array_works: |
| 275 | + metric_xp = metric(a_xp, b_xp, **metric_kwargs) |
| 276 | + |
| 277 | + # With array API dispatch disabled, the metric call should fall back to an |
| 278 | + # implicit conversion to numpy and match the pure numpy result: |
| 279 | + _check_metric_matches(metric_xp, metric_np) |
| 280 | + |
| 281 | + metric_xp_mixed_1 = metric(a_np, b_xp, **metric_kwargs) |
| 282 | + _check_metric_matches(metric_xp_mixed_1, metric_np) |
| 283 | + |
| 284 | + metric_xp_mixed_2 = metric(a_xp, b_np, **metric_kwargs) |
| 285 | + _check_metric_matches(metric_xp_mixed_2, metric_np) |
| 286 | + |
| 287 | + with config_context(array_api_dispatch=True): |
| 288 | + metric_xp = metric(a_xp, b_xp, **metric_kwargs) |
| 289 | + _check_metric_matches(metric_xp, metric_np, convert_a=True) |
| 290 | + |
| 291 | + |
| 292 | +def check_array_api_unsupervised_metric(metric, array_namespace, device, dtype_name): |
| 293 | + y_pred = np.array([1, 0, 1, 0, 1, 1, 0]) |
| 294 | + X = np.random.RandomState(0).randint(10, size=(7, 10)) |
| 295 | + |
| 296 | + check_array_api_metric( |
| 297 | + metric, |
| 298 | + array_namespace, |
| 299 | + device, |
| 300 | + dtype_name, |
| 301 | + a_np=X, |
| 302 | + b_np=y_pred, |
| 303 | + ) |
| 304 | + |
| 305 | + |
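| | +# Registry mapping each metric to the checkers that validate its array API |
| | +# support; extending array API coverage to a new metric means adding an |
| | +# entry here. |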
| 306 | +array_api_metric_checkers = { |
| 307 | + calinski_harabasz_score: [ |
| 308 | + check_array_api_unsupervised_metric, |
| 309 | + ] |
| 310 | +} |
| 311 | + |
| 312 | + |
| 313 | +def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers): |
| 314 | + for metric, checkers in metric_checkers.items(): |
| 315 | + for checker in checkers: |
| 316 | + yield metric, checker |
| 317 | + |
| 318 | + |
| 319 | +@pytest.mark.parametrize( |
| 320 | + "array_namespace, device, dtype_name", |
| 321 | + yield_namespace_device_dtype_combinations(), |
| 322 | + ids=_get_namespace_device_dtype_ids, |
| 323 | +) |
| 324 | +@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations()) |
| 325 | +def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func): |
| 326 | + check_func(metric, array_namespace, device, dtype_name) |
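
For context, a minimal usage sketch of the behavior these tests exercise, assuming a scikit-learn build that includes this array API support for calinski_harabasz_score and the optional array-api-strict package:

    import array_api_strict as xp
    import numpy as np

    from sklearn import config_context
    from sklearn.metrics import calinski_harabasz_score

    # Small toy inputs mirroring the shapes used in the tests above.
    X = xp.asarray(np.random.RandomState(0).random_sample((7, 10)))
    labels = xp.asarray([1, 0, 1, 0, 1, 1, 0])

    # With dispatch enabled, the score is computed with the input's array
    # namespace instead of implicitly converting to numpy.
    with config_context(array_api_dispatch=True):
        score = calinski_harabasz_score(X, labels)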