    assert 'mAP' in metric_results


@pytest.mark.parametrize(
    argnames=('predictions', 'groundtruths', 'num_classes', 'target_mAP'),
    argvalues=[
        (
            # One image: four well-separated rotated boxes (cx, cy, w, h, angle),
            # all predicted with high confidence for class 0.
            [
                dict(
                    bboxes=np.array([
                        [23, 31, 10.0, 20.0, 0.0],
                        [100, 120, 10.0, 20.0, 0.1],
                        [150, 160, 10.0, 20.0, 0.2],
                        [250, 260, 10.0, 20.0, 0.3],
                    ]),
                    scores=np.array([1.0, 0.98, 0.96, 0.95]),
                    labels=np.array([0] * 4),
                )
            ],
            # Ground truth matches the predictions exactly; nothing is ignored.
            [
                dict(
                    bboxes=np.array([
                        [23, 31, 10.0, 20.0, 0.0],
                        [100, 120, 10.0, 20.0, 0.1],
                        [150, 160, 10.0, 20.0, 0.2],
                        [250, 260, 10.0, 20.0, 0.3],
                    ]),
                    labels=np.array([0] * 4),
                    bboxes_ignore=np.empty((0, 5)),
                    labels_ignore=np.empty((0, )),
                )
            ],
            2,
            1.0,
        )
    ])
def test_metric_accurate(predictions, groundtruths, num_classes, target_mAP):
    """Perfect predictions should yield the expected mAP (1.0 here)."""
    dota_map = DOTAMeanAP(num_classes=num_classes)
    metric_results = dota_map(predictions, groundtruths)
    np.testing.assert_almost_equal(metric_results['mAP'], target_mAP)


if __name__ == '__main__':
    # Allow running this test module directly, with verbose, uncaptured output.
    pytest_args = [__file__, '-vv', '--capture=no']
    pytest.main(pytest_args)