77from chainer import testing
88from chainer .testing import attr
99
10+ from chainercv .links .model .fpn import BboxHead
1011from chainercv .links .model .fpn import FasterRCNN
11- from chainercv .links .model .fpn import Head
12+ from chainercv .links .model .fpn import MaskHead
1213from chainercv .links .model .fpn import RPN
14+ from chainercv .utils import assert_is_bbox
1315from chainercv .utils import assert_is_detection_link
16+ from chainercv .utils import assert_is_instance_segmentation_link
1417
1518
1619def _random_array (xp , shape ):
@@ -31,28 +34,35 @@ def __call__(self, x):
3134
3235class DummyFasterRCNN (FasterRCNN ):
3336
34- def __init__ (self , n_fg_class , min_size , max_size ):
37+ def __init__ (self , n_fg_class , return_values , min_size , max_size ):
3538 extractor = DummyExtractor ()
3639 super (DummyFasterRCNN , self ).__init__ (
3740 extractor = extractor ,
3841 rpn = RPN (extractor .scales ),
39- head = Head (n_fg_class + 1 , extractor .scales ),
42+ bbox_head = BboxHead (n_fg_class + 1 , extractor .scales ),
43+ mask_head = MaskHead (n_fg_class + 1 , extractor .scales ),
44+ return_values = return_values ,
4045 min_size = min_size , max_size = max_size ,
4146 )
4247
4348
4449@testing .parameterize (* testing .product_dict (
50+ [
51+ {'return_values' : 'detection' },
52+ {'return_values' : 'instance_segmentation' },
53+ {'return_values' : 'rpn' }
54+ ],
4555 [
4656 {'n_fg_class' : 1 },
4757 {'n_fg_class' : 5 },
4858 {'n_fg_class' : 20 },
4959 ],
5060 [
51- {
52- 'in_sizes' : [(480 , 640 ), (320 , 320 )],
53- 'min_size' : 800 , 'max_size' : 1333 ,
54- 'expected_shape' : (800 , 1088 ),
55- },
61+ # {
62+ # 'in_sizes': [(480, 640), (320, 320)],
63+ # 'min_size': 800, 'max_size': 1333,
64+ # 'expected_shape': (800, 1088),
65+ # },
5666 {
5767 'in_sizes' : [(200 , 50 ), (400 , 100 )],
5868 'min_size' : 200 , 'max_size' : 320 ,
@@ -63,7 +73,14 @@ def __init__(self, n_fg_class, min_size, max_size):
6373class TestFasterRCNN (unittest .TestCase ):
6474
6575 def setUp (self ):
76+ if self .return_values == 'detection' :
77+ return_values = ['bboxes' , 'labels' , 'scores' ]
78+ elif self .return_values == 'instance_segmentation' :
79+ return_values = ['masks' , 'labels' , 'scores' ]
80+ elif self .return_values == 'rpn' :
81+ return_values = ['rois' ]
6682 self .link = DummyFasterRCNN (n_fg_class = self .n_fg_class ,
83+ return_values = return_values ,
6784 min_size = self .min_size ,
6885 max_size = self .max_size )
6986
@@ -88,29 +105,20 @@ def test_use_preset(self):
88105 def _check_call (self ):
89106 x = _random_array (self .link .xp , (2 , 3 , 32 , 32 ))
90107 with chainer .using_config ('train' , False ):
91- rois , roi_indices , head_locs , head_confs = self .link (x )
108+ hs , rois , roi_indices = self .link (x )
92109
93- self .assertEqual (len (rois ), len (self .link .extractor .scales ))
94- self .assertEqual (len (roi_indices ), len (self .link .extractor .scales ))
110+ self .assertEqual (len (hs ), len (self .link .extractor .scales ))
95111 for l in range (len (self .link .extractor .scales )):
96- self .assertIsInstance (rois [l ], self .link .xp .ndarray )
97- self .assertEqual (rois [l ].shape [1 :], (4 ,))
98-
99- self .assertIsInstance (roi_indices [l ], self .link .xp .ndarray )
100- self .assertEqual (roi_indices [l ].shape [1 :], ())
101-
102- self .assertEqual (rois [l ].shape [0 ], roi_indices [l ].shape [0 ])
112+ self .assertIsInstance (hs [l ], chainer .Variable )
113+ self .assertIsInstance (hs [l ].data , self .link .xp .ndarray )
103114
104- n_roi = sum (
105- len (rois [ l ]) for l in range ( len ( self . link . extractor . scales ) ))
115+ self . assertIsInstance ( rois , self . link . xp . ndarray )
116+ self . assertEqual (rois . shape [ 1 :], ( 4 , ))
106117
107- self .assertIsInstance (head_locs , chainer .Variable )
108- self .assertIsInstance (head_locs .array , self .link .xp .ndarray )
109- self .assertEqual (head_locs .shape , (n_roi , self .n_fg_class + 1 , 4 ))
118+ self .assertIsInstance (roi_indices , self .link .xp .ndarray )
119+ self .assertEqual (roi_indices .shape [1 :], ())
110120
111- self .assertIsInstance (head_confs , chainer .Variable )
112- self .assertIsInstance (head_confs .array , self .link .xp .ndarray )
113- self .assertEqual (head_confs .shape , (n_roi , self .n_fg_class + 1 ))
121+ self .assertEqual (rois .shape [0 ], roi_indices .shape [0 ])
114122
115123 def test_call_cpu (self ):
116124 self ._check_call ()
@@ -126,13 +134,32 @@ def test_call_train_mode(self):
126134 with chainer .using_config ('train' , True ):
127135 self .link (x )
128136
137+ def _check_predict (self ):
138+ if self .return_values == 'detection' :
139+ assert_is_detection_link (self .link , self .n_fg_class )
140+ elif self .return_values == 'instance_segmentation' :
141+ assert_is_instance_segmentation_link (self .link , self .n_fg_class )
142+ elif self .return_values == 'rpn' :
143+ imgs = [
144+ np .random .randint (
145+ 0 , 256 , size = (3 , 480 , 320 )).astype (np .float32 ),
146+ np .random .randint (
147+ 0 , 256 , size = (3 , 480 , 320 )).astype (np .float32 )]
148+ result = self .link .predict (imgs )
149+ assert len (result ) == 1
150+ assert len (result [0 ]) == 1
151+ for i in range (len (result [0 ])):
152+ roi = result [0 ][i ]
153+ assert_is_bbox (roi )
154+
155+ @attr .slow
129156 def test_predict_cpu (self ):
130- assert_is_detection_link ( self .link , self . n_fg_class )
157+ self ._check_predict ( )
131158
132159 @attr .gpu
133160 def test_predict_gpu (self ):
134161 self .link .to_gpu ()
135- assert_is_detection_link ( self .link , self . n_fg_class )
162+ self ._check_predict ( )
136163
137164 def test_prepare (self ):
138165 imgs = [_random_array (np , (3 , s [0 ], s [1 ])) for s in self .in_sizes ]
0 commit comments