
Commit 1d0c5e3

Partially complete SERDataset data loading

1 parent d3e16ad commit 1d0c5e3

File tree: 2 files changed (+41 −25)

mmocr/datasets/ser_dataset.py (+31 −24)
@@ -46,39 +46,46 @@ def load_data_list(self) -> List[dict]:
         data_list = super().load_data_list()

         # split text to several slices because of over-length
-        input_ids, bboxes, labels = [], [], []
-        segment_ids, position_ids = [], []
-        image_path = []
+        split_text_data_list = []
         for i in range(len(data_list)):
             start = 0
             cur_iter = 0
             while start < len(data_list[i]['input_ids']):
                 end = min(start + 510, len(data_list[i]['input_ids']))
-
-                input_ids.append([self.tokenizer.cls_token_id] +
-                                 data_list[i]['input_ids'][start:end] +
-                                 [self.tokenizer.sep_token_id])
-                bboxes.append([[0, 0, 0, 0]] +
-                              data_list[i]['bboxes'][start:end] +
-                              [[1000, 1000, 1000, 1000]])
-                labels.append([-100] + data_list[i]['labels'][start:end] +
-                              [-100])
-
-                cur_segment_ids = self.get_segment_ids(bboxes[-1])
-                cur_position_ids = self.get_position_ids(cur_segment_ids)
-                segment_ids.append(cur_segment_ids)
-                position_ids.append(cur_position_ids)
-                image_path.append(
-                    os.path.join(self.data_root, data_list[i]['img_path']))
+                # get input_ids
+                input_ids = [self.tokenizer.cls_token_id] + \
+                    data_list[i]['input_ids'][start:end] + \
+                    [self.tokenizer.sep_token_id]
+                # get bboxes
+                bboxes = [[0, 0, 0, 0]] + \
+                    data_list[i]['bboxes'][start:end] + \
+                    [[1000, 1000, 1000, 1000]]
+                # get labels
+                labels = [-100] + data_list[i]['labels'][start:end] + [-100]
+                # get segment_ids
+                segment_ids = self.get_segment_ids(bboxes)
+                # get position_ids
+                position_ids = self.get_position_ids(segment_ids)
+                # get img_path
+                img_path = os.path.join(self.data_root,
+                                        data_list[i]['img_path'])
+                # get attention_mask
+                attention_mask = [1] * len(input_ids)
+
+                data_info = {}
+                data_info['input_ids'] = input_ids
+                data_info['bboxes'] = bboxes
+                data_info['labels'] = labels
+                data_info['segment_ids'] = segment_ids
+                data_info['position_ids'] = position_ids
+                data_info['img_path'] = img_path
+                data_info['attention_mask'] = attention_mask
+                split_text_data_list.append(data_info)

                 start = end
                 cur_iter += 1

-        assert len(input_ids) == len(bboxes) == len(labels) == len(
-            segment_ids) == len(position_ids)
-        assert len(segment_ids) == len(image_path)
-
-        return data_list
+        return split_text_data_list

     def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
         instances = raw_data_info['instances']
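The rewritten loop is a sliding-window split: each page's token sequence is cut into windows of at most 510 tokens, so that after the CLS and SEP tokens are added every slice fits the 512-token limit of BERT-style encoders such as LayoutLMv3. A minimal standalone sketch of that windowing; the split_sequence helper and the token-ID values are illustrative, not part of this commit:

# Sketch of the 510-token sliding-window split (illustrative IDs).
CLS_ID, SEP_ID = 101, 102  # stand-ins for tokenizer.cls_token_id / sep_token_id
MAX_BODY = 510             # 512 minus the two special tokens

def split_sequence(input_ids, bboxes, labels):
    """Yield (input_ids, bboxes, labels) slices of at most 512 tokens."""
    start = 0
    while start < len(input_ids):
        end = min(start + MAX_BODY, len(input_ids))
        yield (
            [CLS_ID] + input_ids[start:end] + [SEP_ID],
            # CLS gets an all-zero box, SEP the full-page box (0-1000 scale)
            [[0, 0, 0, 0]] + bboxes[start:end] + [[1000, 1000, 1000, 1000]],
            # -100 labels make the loss ignore the special tokens
            [-100] + labels[start:end] + [-100],
        )
        start = end

tokens = list(range(1200))
for ids, boxes, labs in split_sequence(tokens, [[0, 0, 1, 1]] * 1200, [0] * 1200):
    assert len(ids) <= 512 and len(ids) == len(boxes) == len(labs)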

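The get_segment_ids and get_position_ids helpers are not shown in this diff. A plausible implementation, mirroring the reference LayoutLMv3 fine-tuning code (an assumption, not part of this commit), gives consecutive tokens that share a bounding box one segment id and restarts position ids inside each segment:

# Assumed helpers (not in this diff): segment ids group tokens by bbox,
# position ids count upward within each segment, starting at 2 as in LayoutLM.
def get_segment_ids(bboxes):
    segment_ids = []
    for i in range(len(bboxes)):
        if i == 0:
            segment_ids.append(0)
        elif bboxes[i] == bboxes[i - 1]:
            segment_ids.append(segment_ids[-1])  # same box -> same segment
        else:
            segment_ids.append(segment_ids[-1] + 1)
    return segment_ids

def get_position_ids(segment_ids):
    position_ids = []
    for i in range(len(segment_ids)):
        if i == 0 or segment_ids[i] != segment_ids[i - 1]:
            position_ids.append(2)  # a new segment restarts at 2
        else:
            position_ids.append(position_ids[-1] + 1)
    return position_ids

boxes = [[0, 0, 0, 0], [1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4]]
assert get_segment_ids(boxes) == [0, 1, 1, 2]
assert get_position_ids([0, 1, 1, 2]) == [2, 2, 3, 2]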
projects/LayoutLMv3/test.py (+10 −1)
@@ -1,14 +1,23 @@
 from mmengine.config import Config
+from mmengine.registry import init_default_scope

 from mmocr.registry import DATASETS

 if __name__ == '__main__':
     cfg_path = '/Users/wangnu/Documents/GitHub/mmocr/projects/' \
                'LayoutLMv3/configs/layoutlmv3_xfund_zh.py'
     cfg = Config.fromfile(cfg_path)
+    init_default_scope(cfg.get('default_scope', 'mmocr'))

     dataset_cfg = cfg.train_dataset
     dataset_cfg['tokenizer'] = \
         '/Users/wangnu/Documents/GitHub/mmocr/data/layoutlmv3-base-chinese'
+
+    train_pipeline = [
+        dict(type='LoadImageFromFile', color_type='color'),
+        dict(type='Resize', scale=(224, 224))
+    ]
+    dataset_cfg['pipeline'] = train_pipeline
     ds = DATASETS.build(dataset_cfg)
-    print(ds[0])
+    data = ds[0]
+    print(data)
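The test script leans on MMEngine's registry mechanism: init_default_scope selects the registry namespace against which DATASETS.build resolves the type strings in the config. A toy sketch of that build-from-config pattern; the TOYS registry and ToyDataset class are illustrative stand-ins, not part of mmocr:

# Minimal sketch of building an object from a config dict via a registry.
from mmengine.registry import Registry

TOYS = Registry('toy')  # stand-in for mmocr's DATASETS registry

@TOYS.register_module()
class ToyDataset:
    def __init__(self, ann_file, pipeline=None):
        self.ann_file = ann_file
        self.pipeline = pipeline or []

# Registry.build looks up 'type' and instantiates the registered class
# with the remaining keys as constructor arguments.
cfg = dict(type='ToyDataset', ann_file='train.json',
           pipeline=[dict(type='LoadImageFromFile')])
ds = TOYS.build(cfg)
print(type(ds).__name__)  # ToyDataset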
