Skip to content

Commit 1ed5568

Browse files
committed
fix dtype
1 parent 8cacc89 commit 1ed5568

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

sklearn/metrics/_ranking.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ def det_curve(
380380

381381
# add a threshold at inf where the clf always predicts the negative class
382382
# i.e. tps = fps = 0
383-
tps = xp.concat((xp.asarray([0], device=device), tps))
384-
fps = xp.concat((xp.asarray([0], device=device), fps))
383+
tps = xp.concat((xp.asarray([0.0], device=device), tps))
384+
fps = xp.concat((xp.asarray([0.0], device=device), fps))
385385
thresholds = xp.concat((xp.asarray([np.inf], device=device), thresholds))
386386

387387
if drop_intermediate and len(fps) > 2:

0 commit comments

Comments
 (0)