Skip to content

Commit e40c216

Browse files
FEA Add array API support for laplacian_kernel (scikit-learn#32613)
1 parent e04f9cb commit e40c216

File tree

5 files changed

+13
-3
lines changed

5 files changed

+13
-3
lines changed

doc/modules/array_api.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,12 @@ Metrics
177177
- :func:`sklearn.metrics.pairwise.cosine_distances`
178178
- :func:`sklearn.metrics.pairwise.pairwise_distances` (only supports "cosine", "euclidean", "manhattan" and "l2" metrics)
179179
- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`)
180+
- :func:`sklearn.metrics.pairwise.laplacian_kernel`
180181
- :func:`sklearn.metrics.pairwise.linear_kernel`
181182
- :func:`sklearn.metrics.pairwise.manhattan_distances`
182183
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
183184
- :func:`sklearn.metrics.pairwise.paired_euclidean_distances`
184-
- :func:`sklearn.metrics.pairwise.pairwise_kernels` (supports all `sklearn.pairwise.PAIRWISE_KERNEL_FUNCTIONS` except :func:`sklearn.metrics.pairwise.laplacian_kernel`)
185+
- :func:`sklearn.metrics.pairwise.pairwise_kernels`
185186
- :func:`sklearn.metrics.pairwise.polynomial_kernel`
186187
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
187188
- :func:`sklearn.metrics.pairwise.sigmoid_kernel`
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- :func:`sklearn.metrics.pairwise.laplacian_kernel` now supports array API compatible inputs.
2+
By :user:`Zubair Shakoor <zubairshakoorarbisoft>`.

sklearn/metrics/pairwise.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1672,7 +1672,11 @@ def laplacian_kernel(X, Y=None, gamma=None):
16721672
gamma = 1.0 / X.shape[1]
16731673

16741674
K = -gamma * manhattan_distances(X, Y)
1675-
np.exp(K, K) # exponentiate K in-place
1675+
xp, _ = get_namespace(X, Y)
1676+
if _is_numpy_namespace(xp):
1677+
np.exp(K, K) # exponentiate K in-place
1678+
else:
1679+
K = xp.exp(K)
16761680
return K
16771681

16781682

sklearn/metrics/tests/test_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
cosine_distances,
6565
cosine_similarity,
6666
euclidean_distances,
67+
laplacian_kernel,
6768
linear_kernel,
6869
manhattan_distances,
6970
paired_cosine_distances,
@@ -2349,6 +2350,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
23492350
euclidean_distances: [check_array_api_metric_pairwise],
23502351
manhattan_distances: [check_array_api_metric_pairwise],
23512352
linear_kernel: [check_array_api_metric_pairwise],
2353+
laplacian_kernel: [check_array_api_metric_pairwise],
23522354
polynomial_kernel: [check_array_api_metric_pairwise],
23532355
rbf_kernel: [check_array_api_metric_pairwise],
23542356
root_mean_squared_error: [

sklearn/metrics/tests/test_pairwise.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def test_pairwise_parallel(func, metric, kwds, dtype):
401401
(pairwise_distances, "manhattan", {}),
402402
(pairwise_kernels, "polynomial", {"degree": 1}),
403403
(pairwise_kernels, callable_rbf_kernel, {"gamma": 0.1}),
404+
(pairwise_kernels, "laplacian", {"gamma": 0.1}),
404405
],
405406
)
406407
def test_pairwise_parallel_array_api(
@@ -487,7 +488,7 @@ def test_pairwise_kernels(metric, csr_container):
487488
)
488489
@pytest.mark.parametrize(
489490
"metric",
490-
["rbf", "sigmoid", "polynomial", "linear", "chi2", "additive_chi2"],
491+
["rbf", "sigmoid", "polynomial", "linear", "laplacian", "chi2", "additive_chi2"],
491492
)
492493
def test_pairwise_kernels_array_api(metric, array_namespace, device, dtype_name):
493494
# Test array API support in pairwise_kernels.

0 commit comments

Comments
 (0)