@@ -46,39 +46,46 @@ def load_data_list(self) -> List[dict]:
46
46
data_list = super ().load_data_list ()
47
47
48
48
# split text to several slices because of over-length
49
- input_ids , bboxes , labels = [], [], []
50
- segment_ids , position_ids = [], []
51
- image_path = []
49
+ split_text_data_list = []
52
50
for i in range (len (data_list )):
53
51
start = 0
54
52
cur_iter = 0
55
53
while start < len (data_list [i ]['input_ids' ]):
56
54
end = min (start + 510 , len (data_list [i ]['input_ids' ]))
57
-
58
- input_ids .append ([self .tokenizer .cls_token_id ] +
59
- data_list [i ]['input_ids' ][start :end ] +
60
- [self .tokenizer .sep_token_id ])
61
- bboxes .append ([[0 , 0 , 0 , 0 ]] +
62
- data_list [i ]['bboxes' ][start :end ] +
63
- [[1000 , 1000 , 1000 , 1000 ]])
64
- labels .append ([- 100 ] + data_list [i ]['labels' ][start :end ] +
65
- [- 100 ])
66
-
67
- cur_segment_ids = self .get_segment_ids (bboxes [- 1 ])
68
- cur_position_ids = self .get_position_ids (cur_segment_ids )
69
- segment_ids .append (cur_segment_ids )
70
- position_ids .append (cur_position_ids )
71
- image_path .append (
72
- os .path .join (self .data_root , data_list [i ]['img_path' ]))
55
+ # get input_ids
56
+ input_ids = [self .tokenizer .cls_token_id ] + \
57
+ data_list [i ]['input_ids' ][start :end ] + \
58
+ [self .tokenizer .sep_token_id ]
59
+ # get bboxes
60
+ bboxes = [[0 , 0 , 0 , 0 ]] + \
61
+ data_list [i ]['bboxes' ][start :end ] + \
62
+ [[1000 , 1000 , 1000 , 1000 ]]
63
+ # get labels
64
+ labels = [- 100 ] + data_list [i ]['labels' ][start :end ] + [- 100 ]
65
+ # get segment_ids
66
+ segment_ids = self .get_segment_ids (bboxes )
67
+ # get position_ids
68
+ position_ids = self .get_position_ids (segment_ids )
69
+ # get img_path
70
+ img_path = os .path .join (self .data_root ,
71
+ data_list [i ]['img_path' ])
72
+ # get attention_mask
73
+ attention_mask = [1 ] * len (input_ids )
74
+
75
+ data_info = {}
76
+ data_info ['input_ids' ] = input_ids
77
+ data_info ['bboxes' ] = bboxes
78
+ data_info ['labels' ] = labels
79
+ data_info ['segment_ids' ] = segment_ids
80
+ data_info ['position_ids' ] = position_ids
81
+ data_info ['img_path' ] = img_path
82
+ data_info ['attention_mask ' ] = attention_mask
83
+ split_text_data_list .append (data_info )
73
84
74
85
start = end
75
86
cur_iter += 1
76
87
77
- assert len (input_ids ) == len (bboxes ) == len (labels ) == len (
78
- segment_ids ) == len (position_ids )
79
- assert len (segment_ids ) == len (image_path )
80
-
81
- return data_list
88
+ return split_text_data_list
82
89
83
90
def parse_data_info (self , raw_data_info : dict ) -> Union [dict , List [dict ]]:
84
91
instances = raw_data_info ['instances' ]
0 commit comments