|
9 | 9 | BoundingBox, |
10 | 10 | DataLoader, |
11 | 11 | Detection, |
| 12 | + Filter, |
| 13 | + Metadata, |
12 | 14 | MetricType, |
13 | 15 | ) |
14 | 16 |
|
@@ -780,3 +782,47 @@ def test_filtering_invalid_indices( |
780 | 782 | with pytest.raises(ValueError) as e: |
781 | 783 | evaluator.create_filter(labels=np.array([1000])) |
782 | 784 | assert "cannot exceed total number of labels" in str(e) |
| 785 | + |
| 786 | + |
| 787 | +def test_filter_object(): |
| 788 | + |
| 789 | + mask = np.array([True, False, False]) |
| 790 | + true_mask = np.array([True, True, True]) |
| 791 | + false_mask = ~true_mask |
| 792 | + |
| 793 | + # check that no datums are defined |
| 794 | + with pytest.raises(EmptyFilterError) as e: |
| 795 | + Filter( |
| 796 | + mask_datums=false_mask, |
| 797 | + mask_groundtruths=mask, |
| 798 | + mask_predictions=mask, |
| 799 | + metadata=Metadata(), |
| 800 | + ) |
| 801 | + assert "filter removes all datums" in str(e) |
| 802 | + |
| 803 | + # check that no annotations are defined |
| 804 | + with pytest.raises(EmptyFilterError) as e: |
| 805 | + Filter( |
| 806 | + mask_datums=mask, |
| 807 | + mask_groundtruths=true_mask, |
| 808 | + mask_predictions=true_mask, |
| 809 | + metadata=Metadata(), |
| 810 | + ) |
| 811 | + |
| 812 | + # check that no ground truths are defined |
| 813 | + with pytest.warns(UserWarning): |
| 814 | + Filter( |
| 815 | + mask_datums=mask, |
| 816 | + mask_groundtruths=true_mask, |
| 817 | + mask_predictions=mask, |
| 818 | + metadata=Metadata(), |
| 819 | + ) |
| 820 | + |
| 821 | + # check that no predictions are defined |
| 822 | + with pytest.warns(UserWarning): |
| 823 | + Filter( |
| 824 | + mask_datums=mask, |
| 825 | + mask_groundtruths=mask, |
| 826 | + mask_predictions=true_mask, |
| 827 | + metadata=Metadata(), |
| 828 | + ) |
0 commit comments