@@ -316,11 +316,11 @@ def test_not_implemented_error_for_non_estimator_report(
316316
317317
318318def test_threshold_display_creation (
319- pyplot , logistic_binary_classification_with_train_test
319+ pyplot , forest_binary_classification_with_train_test
320320):
321321 """Check that we can create a confusion matrix display with threshold support."""
322322 estimator , X_train , X_test , y_train , y_test = (
323- logistic_binary_classification_with_train_test
323+ forest_binary_classification_with_train_test
324324 )
325325 report = EstimatorReport (
326326 estimator ,
@@ -339,12 +339,12 @@ def test_threshold_display_creation(
339339
340340
341341def test_threshold_display_without_threshold (
342- pyplot , logistic_binary_classification_with_train_test
342+ pyplot , forest_binary_classification_with_train_test
343343):
344344 """Check that do_threshold is False when threshold=False and that we raise an error
345345 when frame or plot is called with threshold."""
346346 estimator , X_train , X_test , y_train , y_test = (
347- logistic_binary_classification_with_train_test
347+ forest_binary_classification_with_train_test
348348 )
349349 report = EstimatorReport (
350350 estimator ,
@@ -371,10 +371,10 @@ def test_threshold_display_without_threshold(
371371 display .plot (threshold = 0.5 )
372372
373373
374- def test_plot_with_threshold (pyplot , logistic_binary_classification_with_train_test ):
374+ def test_plot_with_threshold (pyplot , forest_binary_classification_with_train_test ):
375375 """Check that we can plot with a specific threshold."""
376376 estimator , X_train , X_test , y_train , y_test = (
377- logistic_binary_classification_with_train_test
377+ forest_binary_classification_with_train_test
378378 )
379379 report = EstimatorReport (
380380 estimator ,
@@ -390,11 +390,11 @@ def test_plot_with_threshold(pyplot, logistic_binary_classification_with_train_t
390390
391391
392392def test_plot_with_default_threshold (
393- pyplot , logistic_binary_classification_with_train_test
393+ pyplot , forest_binary_classification_with_train_test
394394):
395395 """Check that the default threshold (0.5) is used when not specified."""
396396 estimator , X_train , X_test , y_train , y_test = (
397- logistic_binary_classification_with_train_test
397+ forest_binary_classification_with_train_test
398398 )
399399 report = EstimatorReport (
400400 estimator ,
@@ -415,10 +415,10 @@ def test_plot_with_default_threshold(
415415 )
416416
417417
418- def test_frame_with_threshold (logistic_binary_classification_with_train_test ):
418+ def test_frame_with_threshold (forest_binary_classification_with_train_test ):
419419 """Check that we can get a frame at a specific threshold."""
420420 estimator , X_train , X_test , y_train , y_test = (
421- logistic_binary_classification_with_train_test
421+ forest_binary_classification_with_train_test
422422 )
423423 report = EstimatorReport (
424424 estimator ,
@@ -434,10 +434,10 @@ def test_frame_with_threshold(logistic_binary_classification_with_train_test):
434434 assert frame .shape == (2 , 2 )
435435
436436
437- def test_frame_all_thresholds (logistic_binary_classification_with_train_test ):
437+ def test_frame_all_thresholds (forest_binary_classification_with_train_test ):
438438 """Check that we get all thresholds when threshold=None."""
439439 estimator , X_train , X_test , y_train , y_test = (
440- logistic_binary_classification_with_train_test
440+ forest_binary_classification_with_train_test
441441 )
442442 report = EstimatorReport (
443443 estimator ,
@@ -454,12 +454,10 @@ def test_frame_all_thresholds(logistic_binary_classification_with_train_test):
454454 assert len (frame ) == len (display .thresholds_ )
455455
456456
457- def test_threshold_normalization (
458- pyplot , logistic_binary_classification_with_train_test
459- ):
457+ def test_threshold_normalization (pyplot , forest_binary_classification_with_train_test ):
460458 """Check that normalization works with threshold support."""
461459 estimator , X_train , X_test , y_train , y_test = (
462- logistic_binary_classification_with_train_test
460+ forest_binary_classification_with_train_test
463461 )
464462 report = EstimatorReport (
465463 estimator ,
@@ -483,12 +481,10 @@ def test_threshold_normalization(
483481 assert np .isclose (frame .sum ().sum (), 1.0 )
484482
485483
486- def test_threshold_closest_match (
487- pyplot , logistic_binary_classification_with_train_test
488- ):
484+ def test_threshold_closest_match (pyplot , forest_binary_classification_with_train_test ):
489485 """Check that the closest threshold is selected."""
490486 estimator , X_train , X_test , y_train , y_test = (
491- logistic_binary_classification_with_train_test
487+ forest_binary_classification_with_train_test
492488 )
493489 report = EstimatorReport (
494490 estimator ,
@@ -514,12 +510,12 @@ def test_threshold_closest_match(
514510
515511
516512def test_frame_plot_coincidence_with_threshold (
517- pyplot , logistic_binary_classification_with_train_test
513+ pyplot , forest_binary_classification_with_train_test
518514):
519515 """Check that the values in the frame and plot coincide when threshold is
520516 provided."""
521517 estimator , X_train , X_test , y_train , y_test = (
522- logistic_binary_classification_with_train_test
518+ forest_binary_classification_with_train_test
523519 )
524520 report = EstimatorReport (
525521 estimator ,
@@ -533,3 +529,30 @@ def test_frame_plot_coincidence_with_threshold(
533529 frame_values = frame .values .flatten ()
534530 display .plot (threshold = 0.5 )
535531 assert np .allclose (frame_values , display .ax_ .collections [0 ].get_array ().flatten ())
532+
533+
534+ def test_pos_label (pyplot , forest_binary_classification_with_train_test ):
535+ """Check that the pos_label parameter works correctly."""
536+ estimator , X_train , X_test , y_train , y_test = (
537+ forest_binary_classification_with_train_test
538+ )
539+ labels = np .array (["A" , "B" ], dtype = object )
540+ y_train = labels [y_train ]
541+ y_test = labels [y_test ]
542+ estimator .fit (X_train , y_train )
543+ report = EstimatorReport (
544+ estimator ,
545+ X_train = X_train ,
546+ y_train = y_train ,
547+ X_test = X_test ,
548+ y_test = y_test ,
549+ )
550+ display = report .metrics .confusion_matrix (pos_label = "A" )
551+ display .plot ()
552+ assert display .ax_ .get_xticklabels ()[1 ].get_text () == "A"
553+ assert display .ax_ .get_yticklabels ()[1 ].get_text () == "A"
554+
555+ display = report .metrics .confusion_matrix (pos_label = "B" )
556+ display .plot ()
557+ assert display .ax_ .get_xticklabels ()[1 ].get_text () == "B"
558+ assert display .ax_ .get_yticklabels ()[1 ].get_text () == "B"
0 commit comments