-
Notifications
You must be signed in to change notification settings - Fork 634
Expand file tree
/
Copy pathohd_sjtu.py
More file actions
148 lines (123 loc) · 5.34 KB
/
ohd_sjtu.py
File metadata and controls
148 lines (123 loc) · 5.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
from typing import List, Tuple
from mmengine.dataset import BaseDataset
from mmrotate.registry import DATASETS
@DATASETS.register_module()
class OHD_SJTUDataset_S(BaseDataset):
    """OHD-SJTU-S dataset for detection.

    Note: 'ann_file' in OHD_SJTUDataset is different from the BaseDataset.
    In BaseDataset, it is the path of an annotation file. In OHD_SJTUDataset,
    it is the path of a folder containing txt files.

    Args:
        img_shape (tuple[int]): ``(height, width)`` written into every data
            info as ``height`` / ``width``. Images are not opened to read
            their real size, so all images are assumed to share this shape.
            Defaults to (1024, 1024).
        diff_thr (int): Difficulty threshold. Instances whose difficulty is
            strictly greater than this value get ``ignore_flag=1``.
            Defaults to 100.
        **kwargs: Forwarded unchanged to ``BaseDataset.__init__``.
    """

    METAINFO = {
        'classes': ('ship', 'plane'),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(165, 42, 42), (189, 183, 107)]
    }

    def __init__(self,
                 img_shape: Tuple[int, int] = (1024, 1024),
                 diff_thr: int = 100,
                 **kwargs) -> None:
        # Assigned before super().__init__(): mmengine's BaseDataset may load
        # annotations (via load_data_list) during construction unless
        # lazy_init is set, and loading reads both attributes.
        self.img_shape = img_shape
        self.diff_thr = diff_thr
        super().__init__(**kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from the folder of txt files ``self.ann_file``.

        If ``ann_file`` is an empty string, every ``*.png`` image under
        ``data_prefix['img_path']`` is listed with a single placeholder
        instance whose fields are empty (no ground truth).

        Otherwise every ``*.txt`` file in ``ann_file`` is parsed. The image
        for ``<id>.txt`` is ``<id>.png`` under ``data_prefix['img_path']``.
        Each non-blank line describes one object (see ``_parse_ann_line``).

        Files are processed in sorted order, so the dataset order does not
        depend on the filesystem's directory listing order.

        Returns:
            List[dict]: A list of annotation.

        Raises:
            AssertionError: If the dataset metainfo has no ``classes``.
            ValueError: If ``ann_file`` contains no txt file.
            KeyError: If an annotation line names a class that is not in
                ``classes``.
        """
        assert self._metainfo.get('classes', None) is not None, \
            "classes in 'OHD-SJTUDataset' can not be None"
        cls_map = {c: i for i, c in enumerate(self.metainfo['classes'])}

        data_list = []
        if self.ann_file == '':
            img_files = sorted(
                glob.glob(osp.join(self.data_prefix['img_path'], '*.png')))
            for img_path in img_files:
                img_name = osp.split(img_path)[1]
                data_info = self._make_data_info(
                    img_id=img_name[:-4], img_name=img_name, img_path=img_path)
                # Placeholder instance with empty fields: no ground truth is
                # available for these images.
                data_info['instances'] = [
                    dict(bbox=[], bbox_head=[], bbox_label=[], ignore_flag=0)
                ]
                data_list.append(data_info)
            return data_list

        txt_files = sorted(glob.glob(osp.join(self.ann_file, '*.txt')))
        if len(txt_files) == 0:
            raise ValueError('There is no txt file in '
                             f'{self.ann_file}')
        for txt_file in txt_files:
            img_id = osp.split(txt_file)[1][:-4]
            img_name = img_id + '.png'
            data_info = self._make_data_info(
                img_id=img_id,
                img_name=img_name,
                img_path=osp.join(self.data_prefix['img_path'], img_name))
            with open(txt_file) as f:
                # Blank lines (e.g. a trailing empty line) are skipped
                # instead of crashing the float()/int() conversions.
                data_info['instances'] = [
                    self._parse_ann_line(line, cls_map) for line in f
                    if line.strip()
                ]
            data_list.append(data_info)
        return data_list

    def _make_data_info(self, img_id: str, img_name: str,
                        img_path: str) -> dict:
        """Build the per-image fields shared by both loading modes."""
        return dict(
            img_id=img_id,
            file_name=img_name,
            img_path=img_path,
            height=self.img_shape[0],
            width=self.img_shape[1])

    def _parse_ann_line(self, line: str, cls_map: dict) -> dict:
        """Convert one whitespace-separated annotation line to an instance.

        Args:
            line (str): A non-blank line of an annotation txt file.
            cls_map (dict): Mapping from class name to label index.

        Returns:
            dict: Instance with ``bbox`` (the first 10 fields as floats),
            ``bbox_label`` (from the second-to-last field) and
            ``ignore_flag`` (from the integer difficulty in the last field).
        """
        # split() rather than split(' '): tolerates repeated spaces, tabs
        # and trailing whitespace such as ' \n' or '\r\n'.
        fields = line.split()
        # NOTE(review): 10 values, presumably 4 polygon corners (8 values)
        # plus the 2 head-point coordinates -- confirm against the dataset
        # specification.
        instance = {'bbox': [float(v) for v in fields[:10]]}
        instance['bbox_label'] = cls_map[fields[-2]]
        difficulty = int(fields[-1])
        instance['ignore_flag'] = 1 if difficulty > self.diff_thr else 0
        return instance

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode the data list is returned as-is. Otherwise, when
        ``filter_cfg['filter_empty_gt']`` is truthy, images whose
        ``instances`` list is empty are dropped.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list
        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \
            if self.filter_cfg is not None else False
        return [
            data_info for data_info in self.data_list
            if not (filter_empty_gt and len(data_info['instances']) == 0)
        ]

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get OHD-SJTU category ids by index.

        Args:
            idx(int): Index of data.

        Returns:
            List[int]: All categories in the image of specified index.
        """
        instances = self.get_data_info(idx)['instances']
        return [instance['bbox_label'] for instance in instances]
@DATASETS.register_module()
class OHD_SJTUDataset_L(OHD_SJTUDataset_S):
    """OHD-SJTU-L dataset for detection, covering six object categories.

    Only the class names and the visualization palette are overridden.
    Annotation loading, filtering and category lookup are inherited
    unchanged from ``OHD_SJTUDataset_S``.

    Note: as in ``OHD_SJTUDataset_S``, ``ann_file`` is the path of a folder
    containing txt files, not the path of a single annotation file as in
    ``BaseDataset``.
    """

    METAINFO = dict(
        classes=('ship', 'plane', 'small-vehicle', 'large-vehicle', 'harbor',
                 'helicopter'),
        # Color tuples used for visualization, one per entry in ``classes``.
        palette=[
            (165, 42, 42),
            (189, 183, 107),
            (0, 255, 0),
            (255, 0, 0),
            (138, 43, 226),
            (255, 128, 0),
        ])