Skip to content

Commit b6f55f8

Browse files
committed
优化ser_postprocessor
1 parent b04e126 commit b6f55f8

File tree

4 files changed

+99
-77
lines changed

4 files changed

+99
-77
lines changed

projects/LayoutLMv3/configs/ser/layoutlmv3_1k_xfund_zh_1xbs8.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@
108108
pretrained_model_name_or_path=hf_pretrained_model,
109109
num_labels=len(class_name) * 2 - 1),
110110
loss_processor=dict(type='ComputeLossAfterLabelSmooth'),
111-
postprocessor=dict(type='SERPostprocessor', classes=class_name))
111+
postprocessor=dict(
112+
type='SERPostprocessor',
113+
classes=class_name,
114+
only_label_first_subword=only_label_first_subword))
112115
# ====================================================================
113116
# ========================= Evaluation ===============================
114117
val_evaluator = dict(type='SeqevalMetric', prefix=dataset_name)

projects/LayoutLMv3/datasets/transforms/layoutlmv3_transforms.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,12 @@ class ConvertBIOLabelForSER(BaseTransform):
220220

221221
def __init__(self,
222222
classes: Union[tuple, list],
223-
only_label_first_subword: bool = False) -> None:
223+
only_label_first_subword: bool = True) -> None:
224224
super().__init__()
225225
self.other_label_name = find_other_label_name_of_biolabel(classes)
226226
self.biolabel2id = self._generate_biolabel2id_map(classes)
227+
assert only_label_first_subword is True, \
228+
'Only support `only_label_first_subword=True` now.'
227229
self.only_label_first_subword = only_label_first_subword
228230

229231
def _generate_biolabel2id_map(self, classes: Union[tuple, list]) -> Dict:

projects/LayoutLMv3/models/ser_postprocessor.py

+77-39
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@
1616
class SERPostprocessor(nn.Module):
1717
"""PostProcessor for SER."""
1818

19-
def __init__(self, classes: Union[tuple, list]) -> None:
19+
def __init__(self,
20+
classes: Union[tuple, list],
21+
only_label_first_subword: bool = True) -> None:
2022
super().__init__()
2123
self.other_label_name = find_other_label_name_of_biolabel(classes)
2224
self.id2biolabel = self._generate_id2biolabel_map(classes)
25+
assert only_label_first_subword is True, \
26+
'Only support `only_label_first_subword=True` now.'
27+
self.only_label_first_subword = only_label_first_subword
2328
self.softmax = nn.Softmax(dim=-1)
2429

2530
def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
@@ -40,62 +45,95 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
4045
def __call__(self, outputs: torch.Tensor,
4146
data_samples: Sequence[SERDataSample]
4247
) -> Sequence[SERDataSample]:
43-
# merge several truncated data_samples into one data_sample
4448
assert all('truncation_word_ids' in d for d in data_samples), \
4549
'The key `truncation_word_ids` should be specified' \
4650
'in PackSERInputs.'
47-
truncation_word_ids = []
48-
for data_sample in data_samples:
49-
truncation_word_ids.append(data_sample.pop('truncation_word_ids'))
50-
merged_data_sample = copy.deepcopy(data_samples[0])
51-
merged_data_sample.set_metainfo(
52-
dict(truncation_word_ids=truncation_word_ids))
53-
flattened_word_ids = [
54-
word_id for word_ids in truncation_word_ids for word_id in word_ids
51+
truncation_word_ids = [
52+
data_sample.pop('truncation_word_ids')
53+
for data_sample in data_samples
54+
]
55+
word_ids = [
56+
word_id for word_ids in truncation_word_ids
57+
for word_id in word_ids[1:-1]
5558
]
5659

60+
# merge several truncated data_samples into one data_sample
61+
merged_data_sample = copy.deepcopy(data_samples[0])
62+
5763
# convert outputs dim from (truncation_num, max_length, label_num)
5864
# to (truncation_num * max_length, label_num)
5965
outputs = outputs.cpu().detach()
60-
outputs = torch.reshape(outputs, (-1, outputs.size(-1)))
66+
outputs = torch.reshape(outputs[:, 1:-1, :], (-1, outputs.size(-1)))
6167
# get pred label ids/scores from outputs
6268
probs = self.softmax(outputs)
6369
max_value, max_idx = torch.max(probs, -1)
64-
pred_label_ids = max_idx.numpy()
65-
pred_label_scores = max_value.numpy()
70+
pred_label_ids = max_idx.numpy().tolist()
71+
pred_label_scores = max_value.numpy().tolist()
72+
73+
# the inference process does not have item in gt_label,
74+
# so select valid tokens with word_ids rather than
75+
# with gt_label_ids as the official code does.
76+
pred_words_biolabels = []
77+
word_biolabels = []
78+
pre_word_id = None
79+
for idx, cur_word_id in enumerate(word_ids):
80+
if cur_word_id is not None:
81+
if cur_word_id != pre_word_id:
82+
if word_biolabels:
83+
pred_words_biolabels.append(word_biolabels)
84+
word_biolabels = []
85+
word_biolabels.append((self.id2biolabel[pred_label_ids[idx]],
86+
pred_label_scores[idx]))
87+
else:
88+
pred_words_biolabels.append(word_biolabels)
89+
break
90+
pre_word_id = cur_word_id
91+
# record pred_label
92+
if self.only_label_first_subword:
93+
pred_label = LabelData()
94+
pred_label.item = [
95+
pred_word_biolabels[0][0]
96+
for pred_word_biolabels in pred_words_biolabels
97+
]
98+
pred_label.score = [
99+
pred_word_biolabels[0][1]
100+
for pred_word_biolabels in pred_words_biolabels
101+
]
102+
merged_data_sample.pred_label = pred_label
103+
else:
104+
raise NotImplementedError(
105+
'The `only_label_first_subword=False` is not support yet.')
66106

67107
# determine whether it is an inference process
68108
if 'item' in data_samples[0].gt_label:
69109
# merge gt label ids from data_samples
70110
gt_label_ids = [
71-
data_sample.gt_label.item for data_sample in data_samples
111+
data_sample.gt_label.item[1:-1] for data_sample in data_samples
72112
]
73113
gt_label_ids = torch.cat(
74-
gt_label_ids, dim=0).cpu().detach().numpy()
75-
gt_biolabels = [
76-
self.id2biolabel[g]
77-
for (w, g) in zip(flattened_word_ids, gt_label_ids)
78-
if w is not None
79-
]
114+
gt_label_ids, dim=0).cpu().detach().numpy().tolist()
115+
gt_words_biolabels = []
116+
word_biolabels = []
117+
pre_word_id = None
118+
for idx, cur_word_id in enumerate(word_ids):
119+
if cur_word_id is not None:
120+
if cur_word_id != pre_word_id:
121+
if word_biolabels:
122+
gt_words_biolabels.append(word_biolabels)
123+
word_biolabels = []
124+
word_biolabels.append(self.id2biolabel[gt_label_ids[idx]])
125+
else:
126+
gt_words_biolabels.append(word_biolabels)
127+
break
128+
pre_word_id = cur_word_id
80129
# update merged gt_label
81-
merged_data_sample.gt_label.item = gt_biolabels
82-
83-
# the inference process does not have item in gt_label,
84-
# so select valid tokens with flattened_word_ids
85-
# rather than with gt_label_ids as the official code does.
86-
pred_biolabels = [
87-
self.id2biolabel[p]
88-
for (w, p) in zip(flattened_word_ids, pred_label_ids)
89-
if w is not None
90-
]
91-
pred_biolabel_scores = [
92-
s for (w, s) in zip(flattened_word_ids, pred_label_scores)
93-
if w is not None
94-
]
95-
# record pred_label
96-
pred_label = LabelData()
97-
pred_label.item = pred_biolabels
98-
pred_label.score = pred_biolabel_scores
99-
merged_data_sample.pred_label = pred_label
130+
if self.only_label_first_subword:
131+
merged_data_sample.gt_label.item = [
132+
gt_word_biolabels[0]
133+
for gt_word_biolabels in gt_words_biolabels
134+
]
135+
else:
136+
raise NotImplementedError(
137+
'The `only_label_first_subword=False` is not support yet.')
100138

101139
return [merged_data_sample]

projects/LayoutLMv3/visualization/ser_visualizer.py

+15-36
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ def __init__(self,
6565
self.line_width = line_width
6666
self.alpha = alpha
6767

68-
def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
69-
torch.Tensor],
70-
word_ids: Optional[List[int]],
71-
gt_labels: Optional[LabelData],
72-
pred_labels: Optional[LabelData]) -> np.ndarray:
68+
def _draw_instances(self,
69+
image: np.ndarray,
70+
bboxes: Union[np.ndarray, torch.Tensor],
71+
gt_labels: Optional[LabelData] = None,
72+
pred_labels: Optional[LabelData] = None) -> np.ndarray:
7373
"""Draw bboxes and polygons on image.
7474
7575
Args:
@@ -97,33 +97,19 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
9797

9898
if gt_labels is not None:
9999
gt_tokens_biolabel = gt_labels.item
100-
gt_words_label = []
101-
102-
pre_word_id = None
103-
for idx, cur_word_id in enumerate(word_ids):
104-
if cur_word_id is not None:
105-
if cur_word_id != pre_word_id:
106-
gt_words_label_name = gt_tokens_biolabel[idx][2:] \
107-
if gt_tokens_biolabel[idx] != 'O' else 'other'
108-
gt_words_label.append(gt_words_label_name)
109-
pre_word_id = cur_word_id
100+
gt_words_label = [
101+
token_biolabel[2:] if token_biolabel != 'O' else 'other'
102+
for token_biolabel in gt_tokens_biolabel
103+
]
110104
assert len(gt_words_label) == len(bboxes)
105+
111106
if pred_labels is not None:
112107
pred_tokens_biolabel = pred_labels.item
113-
pred_words_label = []
114-
pred_tokens_biolabel_score = pred_labels.score
115-
pred_words_label_score = []
116-
117-
pre_word_id = None
118-
for idx, cur_word_id in enumerate(word_ids):
119-
if cur_word_id is not None:
120-
if cur_word_id != pre_word_id:
121-
pred_words_label_name = pred_tokens_biolabel[idx][2:] \
122-
if pred_tokens_biolabel[idx] != 'O' else 'other'
123-
pred_words_label.append(pred_words_label_name)
124-
pred_words_label_score.append(
125-
pred_tokens_biolabel_score[idx])
126-
pre_word_id = cur_word_id
108+
pred_words_label = [
109+
token_biolabel[2:] if token_biolabel != 'O' else 'other'
110+
for token_biolabel in pred_tokens_biolabel
111+
]
112+
pred_words_label_score = pred_labels.score
127113
assert len(pred_words_label) == len(bboxes)
128114

129115
# draw gt or pred labels
@@ -205,11 +191,6 @@ def add_datasample(self,
205191
cat_images = []
206192
if data_sample is not None:
207193
bboxes = np.array(data_sample.instances.get('boxes', None))
208-
# here need to flatten truncation_word_ids
209-
word_ids = [
210-
word_id for word_ids in data_sample.truncation_word_ids
211-
for word_id in word_ids[1:-1]
212-
]
213194
gt_label = data_sample.gt_label if \
214195
draw_gt and 'gt_label' in data_sample else None
215196
pred_label = data_sample.pred_label if \
@@ -218,15 +199,13 @@ def add_datasample(self,
218199
orig_img_with_bboxes = self._draw_instances(
219200
image=image.copy(),
220201
bboxes=bboxes,
221-
word_ids=None,
222202
gt_labels=None,
223203
pred_labels=None)
224204
cat_images.append(orig_img_with_bboxes)
225205
empty_img = np.full_like(image, 255)
226206
empty_img_with_label = self._draw_instances(
227207
image=empty_img,
228208
bboxes=bboxes,
229-
word_ids=word_ids,
230209
gt_labels=gt_label,
231210
pred_labels=pred_label)
232211
cat_images.append(empty_img_with_label)

0 commit comments

Comments
 (0)