@@ -70,6 +70,40 @@ def __call__(self, outputs: torch.Tensor,
70
70
pred_label_ids = max_idx .numpy ().tolist ()
71
71
pred_label_scores = max_value .numpy ().tolist ()
72
72
73
+ # The inference process has no `item` in gt_label,
74
+ # so select valid tokens with word_ids rather than
75
+ # with gt_label_ids as the official code does.
76
+ pred_words_biolabels = []
77
+ word_biolabels = []
78
+ pre_word_id = None
79
+ for idx , cur_word_id in enumerate (word_ids ):
80
+ if cur_word_id is not None :
81
+ if cur_word_id != pre_word_id :
82
+ if word_biolabels :
83
+ pred_words_biolabels .append (word_biolabels )
84
+ word_biolabels = []
85
+ word_biolabels .append ((self .id2biolabel [pred_label_ids [idx ]],
86
+ pred_label_scores [idx ]))
87
+ else :
88
+ pred_words_biolabels .append (word_biolabels )
89
+ break
90
+ pre_word_id = cur_word_id
91
+ # record pred_label
92
+ if self .only_label_first_subword :
93
+ pred_label = LabelData ()
94
+ pred_label .item = [
95
+ pred_word_biolabels [0 ][0 ]
96
+ for pred_word_biolabels in pred_words_biolabels
97
+ ]
98
+ pred_label .score = [
99
+ pred_word_biolabels [0 ][1 ]
100
+ for pred_word_biolabels in pred_words_biolabels
101
+ ]
102
+ merged_data_sample .pred_label = pred_label
103
+ else :
104
+ raise NotImplementedError (
105
+ 'The `only_label_first_subword=False` is not support yet.' )
106
+
73
107
# The inference process has no `item` in gt_label,
74
108
# so select valid tokens with word_ids rather than
75
109
# with gt_label_ids as the official code does.
0 commit comments