Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Commit eee85b7

Browse files
authored
Merge pull request #716 from yuyu2172/cub-fail-test
Fix failing CUB tests
2 parents 4ca0db1 + 6733506 commit eee85b7

File tree

4 files changed

+28
-44
lines changed

4 files changed

+28
-44
lines changed

chainercv/datasets/cub/cub_label_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33

44
from chainercv.datasets.cub.cub_utils import CUBDatasetBase
5+
from chainercv import utils
56

67

78
class CUBLabelDataset(CUBDatasetBase):
@@ -54,6 +55,7 @@ def __init__(self, data_dir='auto', return_bb=False,
5455
d_label in open(image_class_labels_file)]
5556
self._labels = np.array(labels, dtype=np.int32)
5657

58+
self.add_getter('img', self._get_image)
5759
self.add_getter('label', self._get_label)
5860

5961
keys = ('img', 'label')
@@ -63,5 +65,11 @@ def __init__(self, data_dir='auto', return_bb=False,
6365
keys += ('prob_map',)
6466
self.keys = keys
6567

68+
def _get_image(self, i):
69+
img = utils.read_image(
70+
os.path.join(self.data_dir, 'images', self.paths[i]),
71+
color=True)
72+
return img
73+
6674
def _get_label(self, i):
6775
return self._labels[i]

chainercv/datasets/cub/cub_point_dataset.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44

55
from chainercv.datasets.cub.cub_utils import CUBDatasetBase
6+
from chainercv import utils
67

78

89
class CUBPointDataset(CUBDatasetBase):
@@ -66,7 +67,8 @@ def __init__(self, data_dir='auto', return_bb=False,
6667
self._point_dict[id_].append(point)
6768
self._mask_dict[id_].append(mask)
6869

69-
self.add_getter(('point', 'mask'), self._get_annotations)
70+
self.add_getter(('img', 'point', 'mask'),
71+
self._get_img_and_annotations)
7072

7173
keys = ('img', 'point', 'mask')
7274
if return_bb:
@@ -75,7 +77,17 @@ def __init__(self, data_dir='auto', return_bb=False,
7577
keys += ('prob_map',)
7678
self.keys = keys
7779

78-
def _get_annotations(self, i):
80+
def _get_img_and_annotations(self, i):
81+
img = utils.read_image(
82+
os.path.join(self.data_dir, 'images', self.paths[i]),
83+
color=True)
84+
7985
point = np.array(self._point_dict[i], dtype=np.float32)
8086
mask = np.array(self._mask_dict[i], dtype=np.bool)
81-
return point, mask
87+
88+
_, H, W = img.shape
89+
invalid = np.logical_or(
90+
np.logical_or(point[:, 0] > H, point[:, 1] > W),
91+
np.any(point < 0, axis=1))
92+
mask[invalid] = False
93+
return img, point, mask

chainercv/datasets/cub/cub_utils.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -77,52 +77,16 @@ def __init__(self, data_dir='auto', prob_map_dir='auto'):
7777
os.path.join(self.prob_map_dir, os.path.splitext(path)[0] + '.png')
7878
for path in self.paths]
7979

80-
self.add_getter('img', self.get_image)
81-
self.add_getter('bb', self.get_bb)
82-
self.add_getter('prob_map', self.get_prob_map)
80+
self.add_getter('bb', self._get_bb)
81+
self.add_getter('prob_map', self._get_prob_map)
8382

8483
def __len__(self):
8584
return len(self.paths)
8685

87-
def get_image(self, i):
88-
"""Returns the i-th image.
89-
90-
Args:
91-
i (int): The index of the example.
92-
93-
Returns:
94-
An image.
95-
The image is in CHW format and its color channel is ordered in
96-
RGB.
97-
98-
"""
99-
img = utils.read_image(
100-
os.path.join(self.data_dir, 'images', self.paths[i]),
101-
color=True)
102-
return img
103-
104-
def get_bb(self, i):
105-
"""Returns the bounding box of the i-th example.
106-
107-
Args:
108-
i (int): The index of the example.
109-
110-
Returns:
111-
A bounding box.
112-
113-
"""
86+
def _get_bb(self, i):
11487
return self.bbs[i]
11588

116-
def get_prob_map(self, i):
117-
"""Returns the probability map of the i-th example.
118-
119-
Args:
120-
i (int): The index of the example.
121-
122-
Returns:
123-
A probability map.
124-
125-
"""
89+
def _get_prob_map(self, i):
12690
prob_map = utils.read_image(self.prob_map_paths[i],
12791
dtype=np.uint8, color=False)
12892
prob_map = prob_map.astype(np.float32) / 255 # [0, 255] -> [0, 1]

tests/datasets_tests/cub_tests/test_cub_point_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def setUp(self):
2121
return_prob_map=self.return_prob_map)
2222

2323
@attr.slow
24-
def test_camvid_dataset(self):
24+
def test_cub_point_dataset(self):
2525
assert_is_point_dataset(
2626
self.dataset, n_point=15, n_example=10)
2727

0 commit comments

Comments
 (0)