@@ -372,15 +372,17 @@ def det_curve(
372372 >>> thresholds
373373 array([0.35, 0.4 , 0.8 ])
374374 """
375+
376+ xp , _ , device = get_namespace_and_device (y_true , y_score )
375377 fps , tps , thresholds = _binary_clf_curve (
376378 y_true , y_score , pos_label = pos_label , sample_weight = sample_weight
377379 )
378380
379381 # add a threshold at inf where the clf always predicts the negative class
380382 # i.e. tps = fps = 0
381- tps = np .concatenate (([0 ], tps ))
382- fps = np .concatenate (([0 ], fps ))
383- thresholds = np .concatenate (([np .inf ], thresholds ))
383+ tps = xp .concatenate ((xp . asarray ( [0 ], device = device ) , tps ))
384+ fps = xp .concatenate ((xp . asarray ( [0 ], device = device ) , fps ))
385+ thresholds = xp .concatenate ((xp . asarray ( [np .inf ], device = device ) , thresholds ))
384386
385387 if drop_intermediate and len (fps ) > 2 :
386388 # Drop thresholds where true positives (tp) do not change from the
@@ -389,16 +391,20 @@ def det_curve(
389391 # false positive rate (fpr) changes, producing horizontal line segments
390392 # in the transformed (normal deviate) scale. These intermediate points
391393 # can be dropped to create lighter DET curve plots.
392- optimal_idxs = np .where (
393- np .concatenate (
394- [[True ], np .logical_or (np .diff (tps [:- 1 ]), np .diff (tps [1 :])), [True ]]
394+ optimal_idxs = xp .where (
395+ xp .concatenate (
396+ [
397+ xp .asarray ([True ], device = device ),
398+ xp .logical_or (xp .diff (tps [:- 1 ]), xp .diff (tps [1 :])),
399+ xp .asarray ([True ], device = device ),
400+ ]
395401 )
396402 )[0 ]
397403 fps = fps [optimal_idxs ]
398404 tps = tps [optimal_idxs ]
399405 thresholds = thresholds [optimal_idxs ]
400406
401- if len (np .unique (y_true )) != 2 :
407+ if len (xp .unique (y_true )) != 2 :
402408 raise ValueError (
403409 "Only one class is present in y_true. Detection error "
404410 "tradeoff curve is not defined in that case."
@@ -410,16 +416,20 @@ def det_curve(
410416
411417 # start with false positives zero, which may be at a finite threshold
412418 first_ind = (
413- fps .searchsorted (fps [0 ], side = "right" ) - 1
414- if fps .searchsorted (fps [0 ], side = "right" ) > 0
419+ xp .searchsorted (fps , fps [0 ], side = "right" ) - 1
420+ if xp .searchsorted (fps , fps [0 ], side = "right" ) > 0
415421 else None
416422 )
417423 # stop with false negatives zero
418- last_ind = tps .searchsorted (tps [- 1 ]) + 1
424+ last_ind = xp .searchsorted (tps , tps [- 1 ]) + 1
419425 sl = slice (first_ind , last_ind )
420426
421427 # reverse the output such that list of false positives is decreasing
422- return (fps [sl ][::- 1 ] / n_count , fns [sl ][::- 1 ] / p_count , thresholds [sl ][::- 1 ])
428+ return (
429+ xp .flip (fps [sl ]) / n_count ,
430+ xp .flip (fns [sl ]) / p_count ,
431+ xp .flip (thresholds [sl ]),
432+ )
423433
424434
425435def _binary_roc_auc_score (y_true , y_score , sample_weight = None , max_fpr = None ):
0 commit comments