@@ -16,10 +16,15 @@
 class SERPostprocessor(nn.Module):
     """PostProcessor for SER."""

-    def __init__(self, classes: Union[tuple, list]) -> None:
+    def __init__(self,
+                 classes: Union[tuple, list],
+                 only_label_first_subword: bool = True) -> None:
         super().__init__()
         self.other_label_name = find_other_label_name_of_biolabel(classes)
         self.id2biolabel = self._generate_id2biolabel_map(classes)
+        assert only_label_first_subword is True, \
+            'Only support `only_label_first_subword=True` now.'
+        self.only_label_first_subword = only_label_first_subword
         self.softmax = nn.Softmax(dim=-1)

     def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
@@ -40,62 +45,95 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
     def __call__(self, outputs: torch.Tensor,
                  data_samples: Sequence[SERDataSample]
                  ) -> Sequence[SERDataSample]:
-        # merge several truncation data_sample to one data_sample
         assert all('truncation_word_ids' in d for d in data_samples), \
             'The key `truncation_word_ids` should be specified ' \
             'in PackSERInputs.'
-        truncation_word_ids = []
-        for data_sample in data_samples:
-            truncation_word_ids.append(data_sample.pop('truncation_word_ids'))
-        merged_data_sample = copy.deepcopy(data_samples[0])
-        merged_data_sample.set_metainfo(
-            dict(truncation_word_ids=truncation_word_ids))
-        flattened_word_ids = [
-            word_id for word_ids in truncation_word_ids for word_id in word_ids
+        truncation_word_ids = [
+            data_sample.pop('truncation_word_ids')
+            for data_sample in data_samples
+        ]
+        word_ids = [
+            word_id for word_ids in truncation_word_ids
+            for word_id in word_ids[1:-1]
         ]

+        # merge several truncation data_samples into one data_sample
+        merged_data_sample = copy.deepcopy(data_samples[0])
+
         # convert outputs dim from (truncation_num, max_length, label_num)
         # to (truncation_num * max_length, label_num)
         outputs = outputs.cpu().detach()
-        outputs = torch.reshape(outputs, (-1, outputs.size(-1)))
+        outputs = torch.reshape(outputs[:, 1:-1, :], (-1, outputs.size(-1)))
         # get pred label ids/scores from outputs
         probs = self.softmax(outputs)
         max_value, max_idx = torch.max(probs, -1)
-        pred_label_ids = max_idx.numpy()
-        pred_label_scores = max_value.numpy()
+        pred_label_ids = max_idx.numpy().tolist()
+        pred_label_scores = max_value.numpy().tolist()
+
+        # the inference process does not have `item` in gt_label,
+        # so select valid tokens with word_ids rather than
+        # with gt_label_ids as the official code does.
+        pred_words_biolabels = []
+        word_biolabels = []
+        pre_word_id = None
+        for idx, cur_word_id in enumerate(word_ids):
+            if cur_word_id is not None:
+                if cur_word_id != pre_word_id:
+                    if word_biolabels:
+                        pred_words_biolabels.append(word_biolabels)
+                        word_biolabels = []
+                word_biolabels.append((self.id2biolabel[pred_label_ids[idx]],
+                                       pred_label_scores[idx]))
+            else:
+                pred_words_biolabels.append(word_biolabels)
+                break
+            pre_word_id = cur_word_id
+        # record pred_label
+        if self.only_label_first_subword:
+            pred_label = LabelData()
+            pred_label.item = [
+                pred_word_biolabels[0][0]
+                for pred_word_biolabels in pred_words_biolabels
+            ]
+            pred_label.score = [
+                pred_word_biolabels[0][1]
+                for pred_word_biolabels in pred_words_biolabels
+            ]
+            merged_data_sample.pred_label = pred_label
+        else:
+            raise NotImplementedError(
+                'The `only_label_first_subword=False` is not supported yet.')

         # determine whether it is an inference process
         if 'item' in data_samples[0].gt_label:
             # merge gt label ids from data_samples
             gt_label_ids = [
-                data_sample.gt_label.item for data_sample in data_samples
+                data_sample.gt_label.item[1:-1] for data_sample in data_samples
             ]
             gt_label_ids = torch.cat(
-                gt_label_ids, dim=0).cpu().detach().numpy()
-            gt_biolabels = [
-                self.id2biolabel[g]
-                for (w, g) in zip(flattened_word_ids, gt_label_ids)
-                if w is not None
-            ]
+                gt_label_ids, dim=0).cpu().detach().numpy().tolist()
+            gt_words_biolabels = []
+            word_biolabels = []
+            pre_word_id = None
+            for idx, cur_word_id in enumerate(word_ids):
+                if cur_word_id is not None:
+                    if cur_word_id != pre_word_id:
+                        if word_biolabels:
+                            gt_words_biolabels.append(word_biolabels)
+                            word_biolabels = []
+                    word_biolabels.append(self.id2biolabel[gt_label_ids[idx]])
+                else:
+                    gt_words_biolabels.append(word_biolabels)
+                    break
+                pre_word_id = cur_word_id
             # update merged gt_label
-            merged_data_sample.gt_label.item = gt_biolabels
-
-        # inference process do not have item in gt_label,
-        # so select valid token with flattened_word_ids
-        # rather than with gt_label_ids like official code.
-        pred_biolabels = [
-            self.id2biolabel[p]
-            for (w, p) in zip(flattened_word_ids, pred_label_ids)
-            if w is not None
-        ]
-        pred_biolabel_scores = [
-            s for (w, s) in zip(flattened_word_ids, pred_label_scores)
-            if w is not None
-        ]
-        # record pred_label
-        pred_label = LabelData()
-        pred_label.item = pred_biolabels
-        pred_label.score = pred_biolabel_scores
-        merged_data_sample.pred_label = pred_label
+            if self.only_label_first_subword:
+                merged_data_sample.gt_label.item = [
+                    gt_word_biolabels[0]
+                    for gt_word_biolabels in gt_words_biolabels
+                ]
+            else:
+                raise NotImplementedError(
+                    'The `only_label_first_subword=False` is not supported yet.')

         return [merged_data_sample]
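For readers following the new grouping logic, the sketch below replays it outside the class. It is a minimal, self-contained illustration, not part of mmocr: `id2biolabel`, `word_ids`, and the prediction lists are toy values standing in for what the postprocessor sees after the first and last positions of each truncated sequence (presumably the special tokens) have been sliced off with `[1:-1]`.

```python
# Toy stand-ins for the postprocessor's state; all values are illustrative.
id2biolabel = {0: 'O', 1: 'B-HEADER', 2: 'I-HEADER'}

# word_ids after dropping the first/last special-token positions;
# None marks padding, which ends the scan.
word_ids = [0, 0, 1, 2, 2, None]
pred_label_ids = [1, 2, 0, 1, 2, 0]
pred_label_scores = [0.9, 0.8, 0.7, 0.95, 0.6, 0.5]

# Group subword predictions by word id, exactly as in the PR's loop.
pred_words_biolabels = []
word_biolabels = []
pre_word_id = None
for idx, cur_word_id in enumerate(word_ids):
    if cur_word_id is not None:
        if cur_word_id != pre_word_id:
            if word_biolabels:
                pred_words_biolabels.append(word_biolabels)
                word_biolabels = []
        word_biolabels.append(
            (id2biolabel[pred_label_ids[idx]], pred_label_scores[idx]))
    else:
        pred_words_biolabels.append(word_biolabels)
        break
    pre_word_id = cur_word_id

# only_label_first_subword=True keeps only the first subword per word.
print([w[0][0] for w in pred_words_biolabels])  # ['B-HEADER', 'O', 'B-HEADER']
print([w[0][1] for w in pred_words_biolabels])  # [0.9, 0.7, 0.95]
```

Taking the first subword's label per word matches the `only_label_first_subword=True` behavior asserted in `__init__`; labels and scores of the remaining subwords are collected but then discarded.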