@@ -341,7 +341,8 @@ def test_threshold_display_creation(
341341def test_threshold_display_without_threshold (
342342 pyplot , logistic_binary_classification_with_train_test
343343):
344- """Check that do_threshold is False when threshold=False."""
344+ """Check that do_threshold is False when threshold=False and that we raise an error
345+ when frame or plot is called with threshold."""
345346 estimator , X_train , X_test , y_train , y_test = (
346347 logistic_binary_classification_with_train_test
347348 )
@@ -357,6 +358,18 @@ def test_threshold_display_without_threshold(
357358 assert display .do_threshold is False
358359 assert display .thresholds_ is None
359360
361+ display = report .metrics .confusion_matrix (threshold = False )
362+
363+ err_msg = (
364+ "threshold can only be used with binary classification and "
365+ "when `report.metrics.confusion_matrix\\ (threshold=True\\ )` is used."
366+ )
367+ with pytest .raises (ValueError , match = err_msg ):
368+ display .frame (threshold = 0.5 )
369+
370+ with pytest .raises (ValueError , match = err_msg ):
371+ display .plot (threshold = 0.5 )
372+
360373
361374def test_plot_with_threshold (pyplot , logistic_binary_classification_with_train_test ):
362375 """Check that we can plot with a specific threshold."""
@@ -375,9 +388,6 @@ def test_plot_with_threshold(pyplot, logistic_binary_classification_with_train_t
375388 display .plot (threshold = 0.3 )
376389 assert "threshold" in display .ax_ .get_title ().lower ()
377390
378- display .plot (threshold = 0.7 )
379- assert "threshold" in display .ax_ .get_title ().lower ()
380-
381391
382392def test_plot_with_default_threshold (
383393 pyplot , logistic_binary_classification_with_train_test
@@ -394,34 +404,15 @@ def test_plot_with_default_threshold(
394404 y_test = y_test ,
395405 )
396406 display = report .metrics .confusion_matrix (threshold = True )
397- display .plot () # Should use default threshold (0.5)
398-
399- # The title should include the threshold
400- assert "threshold" in display .ax_ .get_title ().lower ()
401-
402-
403- def test_threshold_error_without_threshold_support (
404- pyplot , forest_binary_classification_with_train_test
405- ):
406- """Check that we raise an error when threshold is used without threshold support."""
407- estimator , X_train , X_test , y_train , y_test = (
408- forest_binary_classification_with_train_test
409- )
410- report = EstimatorReport (
411- estimator ,
412- X_train = X_train ,
413- y_train = y_train ,
414- X_test = X_test ,
415- y_test = y_test ,
416- )
417- display = report .metrics .confusion_matrix (threshold = False )
407+ display .plot ()
418408
419- err_msg = (
420- "threshold can only be used with binary classification and "
421- "when `report.metrics.confusion_matrix\\ (threshold=True\\ )` is used."
409+ closest_threshold = display .thresholds_ [
410+ np .argmin (np .abs (display .thresholds_ - 0.5 ))
411+ ]
412+ assert (
413+ display .ax_ .get_title ()
414+ == f"Confusion Matrix (threshold: { closest_threshold :.2f} )"
422415 )
423- with pytest .raises (ValueError , match = err_msg ):
424- display .plot (threshold = 0.5 )
425416
426417
427418def test_frame_with_threshold (logistic_binary_classification_with_train_test ):
@@ -463,30 +454,6 @@ def test_frame_all_thresholds(logistic_binary_classification_with_train_test):
463454 assert len (frame ) == len (display .thresholds_ )
464455
465456
466- def test_frame_threshold_error_without_threshold_support (
467- forest_binary_classification_with_train_test ,
468- ):
469- """Check that we raise an error when threshold is used without threshold support."""
470- estimator , X_train , X_test , y_train , y_test = (
471- forest_binary_classification_with_train_test
472- )
473- report = EstimatorReport (
474- estimator ,
475- X_train = X_train ,
476- y_train = y_train ,
477- X_test = X_test ,
478- y_test = y_test ,
479- )
480- display = report .metrics .confusion_matrix (threshold = False )
481-
482- err_msg = (
483- "threshold can only be used with binary classification "
484- "when `report.metrics.confusion_matrix\\ (threshold=True\\ )` is used."
485- )
486- with pytest .raises (ValueError , match = err_msg ):
487- display .frame (threshold = 0.5 )
488-
489-
490457def test_threshold_normalization (
491458 pyplot , logistic_binary_classification_with_train_test
492459):
@@ -503,26 +470,23 @@ def test_threshold_normalization(
503470 )
504471 display = report .metrics .confusion_matrix (threshold = True )
505472
506- # Test with normalize="true"
507473 display .plot (threshold = 0.5 , normalize = "true" )
508474 frame = display .frame (threshold = 0.5 , normalize = "true" )
509475 assert np .allclose (frame .sum (axis = 1 ), np .ones (2 ))
510476
511- # Test with normalize="pred"
512477 display .plot (threshold = 0.5 , normalize = "pred" )
513478 frame = display .frame (threshold = 0.5 , normalize = "pred" )
514479 assert np .allclose (frame .sum (axis = 0 ), np .ones (2 ))
515480
516- # Test with normalize="all"
517481 display .plot (threshold = 0.5 , normalize = "all" )
518482 frame = display .frame (threshold = 0.5 , normalize = "all" )
519483 assert np .isclose (frame .sum ().sum (), 1.0 )
520484
521485
522- def test_plot_with_multiple_thresholds (
486+ def test_threshold_closest_match (
523487 pyplot , logistic_binary_classification_with_train_test
524488):
525- """Check that we can plot with multiple thresholds ."""
489+ """Check that the closest threshold is selected ."""
526490 estimator , X_train , X_test , y_train , y_test = (
527491 logistic_binary_classification_with_train_test
528492 )
@@ -535,17 +499,25 @@ def test_plot_with_multiple_thresholds(
535499 )
536500 display = report .metrics .confusion_matrix (threshold = True )
537501
538- # Plot with multiple thresholds
539- display .plot (threshold = [0.3 , 0.5 , 0.7 ])
540-
541- # Should have 3 subplots
542- assert len (display .figure_ .axes ) >= 3
502+ # Create a threshold that is not in the list to test the closest match
503+ middle_index = len (display .thresholds_ ) // 2
504+ threshold = (
505+ display .thresholds_ [middle_index ] + display .thresholds_ [middle_index + 1 ]
506+ ) / 2 - 1e-6
507+ closest_threshold = display .thresholds_ [middle_index ]
508+ assert threshold not in display .thresholds_
509+ display .plot (threshold = threshold )
510+ assert (
511+ display .ax_ .get_title ()
512+ == f"Confusion Matrix (threshold: { closest_threshold :.2f} )"
513+ )
543514
544515
545- def test_threshold_closest_match (
516+ def test_frame_plot_coincidence_with_threshold (
546517 pyplot , logistic_binary_classification_with_train_test
547518):
548- """Check that the closest threshold is selected."""
519+ """Check that the values in the frame and plot coincide when threshold is
520+ provided."""
549521 estimator , X_train , X_test , y_train , y_test = (
550522 logistic_binary_classification_with_train_test
551523 )
@@ -557,7 +529,7 @@ def test_threshold_closest_match(
557529 y_test = y_test ,
558530 )
559531 display = report .metrics .confusion_matrix (threshold = True )
560-
561- # Even with a threshold not in the list, it should work
562- display .plot (threshold = 0.12345 )
563- assert display .ax_ is not None
532+ frame = display . frame ( threshold = 0.5 )
533+ frame_values = frame . values . flatten ()
534+ display .plot (threshold = 0.5 )
535+ assert np . allclose ( frame_values , display .ax_ . collections [ 0 ]. get_array (). flatten ())
0 commit comments