|
| 1 | +import os |
| 2 | + |
| 3 | +from chainer.dataset import download |
| 4 | + |
| 5 | +from chainercv.chainer_experimental.datasets.sliceable import GetterDataset |
| 6 | +from chainercv.datasets.imagenet.imagenet_utils import imagenet_loc_synset_ids |
| 7 | +from chainercv.datasets.voc.voc_utils import parse_voc_bbox_annotation |
| 8 | +from chainercv.utils import read_image |
| 9 | + |
| 10 | + |
class ImagenetLocBboxDataset(GetterDataset):

    """ILSVRC2012 ImageNet localization dataset.

    The data is distributed on `the official Kaggle page`_.

    .. _`the official Kaggle page`: https://www.kaggle.com/c/
        imagenet-object-localization-challenge

    Please refer to the readme of ILSVRC2012 dev kit for comprehensive
    documentation. Note that the detection part of ILSVRC has not changed
    since 2012.

    Every image in the training and validation sets has a single
    image-level label specifying the presence of one object category.

    Args:
        data_dir (string): Path to the root of the training data. This
            directory is expected to contain the :obj:`ILSVRC`
            sub-directory. If this is :obj:`auto`,
            :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/imagenet` is used.
        split ({'train', 'val'}): Selects a split of the dataset.

    Raises:
        ValueError: If :obj:`split` is neither :obj:`'train'` nor
            :obj:`'val'`.

    This dataset returns the following data.

    .. csv-table::
        :header: name, shape, dtype, format

        :obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \
        "RGB, :math:`[0, 255]`"
        :obj:`bbox`, ":math:`(R, 4)`", :obj:`float32`, \
        ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`"
        :obj:`label`, ":math:`(R,)`", :obj:`int32`, \
        ":math:`[0, \\#fg\\_class - 1]`"

    """

    # Image-set list file (inside ``ImageSets/CLS-LOC``) for each split.
    _IMAGESET_FILES = {
        'train': 'train_loc.txt',
        'val': 'val.txt',
    }

    def __init__(self, data_dir='auto', split='train'):
        super(ImagenetLocBboxDataset, self).__init__()

        # Fail fast with a clear message; an unknown split would otherwise
        # surface later as an unbound-variable error when opening the list.
        if split not in self._IMAGESET_FILES:
            raise ValueError(
                'split must be one of {}, but got {!r}'.format(
                    sorted(self._IMAGESET_FILES), split))

        if data_dir == 'auto':
            data_dir = download.get_dataset_directory(
                'pfnet/chainercv/imagenet')
        self.base_dir = os.path.join(data_dir, 'ILSVRC')
        imageset_path = os.path.join(
            self.base_dir, 'ImageSets/CLS-LOC', self._IMAGESET_FILES[split])

        # Each non-empty line starts with the example id; any further
        # whitespace-separated fields on the line are ignored.
        with open(imageset_path) as f:
            self.ids = [line.split()[0] for line in f if line.strip()]
        self.split = split

        self.add_getter('img', self._get_image)
        self.add_getter(('bbox', 'label'), self._get_inst_anno)

    def __len__(self):
        """Return the number of example ids read from the image-set file."""
        return len(self.ids)

    def _get_image(self, i):
        """Read the image of the :obj:`i`-th example as a color image.

        The file is looked up at
        :obj:`<base_dir>/Data/CLS-LOC/<split>/<id>.JPEG`.
        """
        img_path = os.path.join(
            self.base_dir, 'Data/CLS-LOC', self.split,
            self.ids[i] + '.JPEG')
        return read_image(img_path, color=True)

    def _get_inst_anno(self, i):
        """Return :obj:`(bbox, label)` parsed from the annotation XML.

        The file is looked up at
        :obj:`<base_dir>/Annotations/CLS-LOC/<split>/<id>.xml`.
        """
        anno_path = os.path.join(
            self.base_dir, 'Annotations/CLS-LOC', self.split,
            self.ids[i] + '.xml')
        # The parser returns three values; only the first two are exposed by
        # this dataset. NOTE(review): judging by the keyword name, objects
        # whose name is not in ``imagenet_loc_synset_ids`` are dropped --
        # confirm against ``parse_voc_bbox_annotation``.
        bbox, label, _ = parse_voc_bbox_annotation(
            anno_path, imagenet_loc_synset_ids,
            skip_names_not_in_label_names=True)
        return bbox, label
0 commit comments