We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 40d0f79 commit 7710b5fCopy full SHA for 7710b5f
sklearn/metrics/_ranking.py
@@ -382,7 +382,8 @@ def det_curve(
382
# i.e. tps = fps = 0
383
tps = xp.concat((xp.asarray([0.0], device=device), tps))
384
fps = xp.concat((xp.asarray([0.0], device=device), fps))
385
- thresholds = xp.concat((xp.asarray([np.inf], device=device), thresholds))
+ thresholds = xp.astype(thresholds, _max_precision_float_dtype(xp, device))
386
+ thresholds = xp.concat((xp.asarray([xp.inf], device=device), thresholds))
387
388
if drop_intermediate and len(fps) > 2:
389
# Drop thresholds where true positives (tp) do not change from the
0 commit comments