Skip to content

Commit 7710b5f

Browse files
committed
handle mismatched dtypes with xp.inf
1 parent 40d0f79 commit 7710b5f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

sklearn/metrics/_ranking.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,8 @@ def det_curve(
382382
# i.e. tps = fps = 0
383383
tps = xp.concat((xp.asarray([0.0], device=device), tps))
384384
fps = xp.concat((xp.asarray([0.0], device=device), fps))
385-
thresholds = xp.concat((xp.asarray([np.inf], device=device), thresholds))
385+
thresholds = xp.astype(thresholds, _max_precision_float_dtype(xp, device))
386+
thresholds = xp.concat((xp.asarray([xp.inf], device=device), thresholds))
386387

387388
if drop_intermediate and len(fps) > 2:
388389
# Drop thresholds where true positives (tp) do not change from the

0 commit comments

Comments
 (0)