Skip to content

Commit e7627a5

Browse files
authored
Fixes: deepjavalibrary#3795, adds aggregation_strategy to token classification task (deepjavalibrary#3798)
1 parent f1648bb commit e7627a5

File tree

2 files changed

+334
-50
lines changed

2 files changed

+334
-50
lines changed

extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java

Lines changed: 257 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,24 @@
1717
import ai.djl.huggingface.tokenizers.jni.CharSpan;
1818
import ai.djl.modality.nlp.translator.NamedEntity;
1919
import ai.djl.ndarray.NDArray;
20+
import ai.djl.ndarray.NDArrays;
2021
import ai.djl.ndarray.NDList;
2122
import ai.djl.ndarray.NDManager;
2223
import ai.djl.translate.ArgumentsUtil;
2324
import ai.djl.translate.Batchifier;
2425
import ai.djl.translate.Translator;
2526
import ai.djl.translate.TranslatorContext;
2627
import ai.djl.util.JsonUtils;
28+
import ai.djl.util.Pair;
2729

2830
import java.io.IOException;
2931
import java.io.Reader;
3032
import java.nio.file.Files;
3133
import java.nio.file.Path;
3234
import java.util.ArrayList;
35+
import java.util.Arrays;
36+
import java.util.Collections;
37+
import java.util.Comparator;
3338
import java.util.List;
3439
import java.util.Map;
3540

@@ -40,20 +45,17 @@ public class TokenClassificationTranslator implements Translator<String, NamedEn
4045
private boolean includeTokenTypes;
4146
private boolean int32;
4247
private boolean softmax;
48+
private String aggregationStrategy;
4349
private Batchifier batchifier;
4450
private PretrainedConfig config;
4551

46-
TokenClassificationTranslator(
47-
HuggingFaceTokenizer tokenizer,
48-
boolean includeTokenTypes,
49-
boolean int32,
50-
boolean softmax,
51-
Batchifier batchifier) {
52-
this.tokenizer = tokenizer;
53-
this.includeTokenTypes = includeTokenTypes;
54-
this.int32 = int32;
55-
this.softmax = softmax;
56-
this.batchifier = batchifier;
52+
TokenClassificationTranslator(Builder builder) {
53+
this.tokenizer = builder.tokenizer;
54+
this.includeTokenTypes = builder.includeTokenTypes;
55+
this.int32 = builder.int32;
56+
this.softmax = builder.softmax;
57+
this.aggregationStrategy = builder.aggregationStrategy;
58+
this.batchifier = builder.batchifier;
5759
}
5860

5961
/** {@inheritDoc} */
@@ -77,6 +79,7 @@ public void prepare(TranslatorContext ctx) throws IOException {
7779
public NDList processInput(TranslatorContext ctx, String input) {
7880
Encoding encoding = tokenizer.encode(input);
7981
ctx.setAttachment("encoding", encoding);
82+
ctx.setAttachment("sentence", input);
8083
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes, int32);
8184
}
8285

@@ -86,6 +89,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List<String> inputs) {
8689
NDManager manager = ctx.getNDManager();
8790
Encoding[] encodings = tokenizer.batchEncode(inputs);
8891
ctx.setAttachment("encodings", encodings);
92+
ctx.setAttachment("sentences", inputs);
8993
NDList[] batch = new NDList[encodings.length];
9094
for (int i = 0; i < encodings.length; ++i) {
9195
batch[i] = encodings[i].toNDList(manager, includeTokenTypes, int32);
@@ -97,17 +101,20 @@ public NDList batchProcessInput(TranslatorContext ctx, List<String> inputs) {
97101
@Override
98102
public NamedEntity[] processOutput(TranslatorContext ctx, NDList list) {
99103
Encoding encoding = (Encoding) ctx.getAttachment("encoding");
100-
return toNamedEntities(encoding, list);
104+
String sentence = (String) ctx.getAttachment("sentence");
105+
return toNamedEntities(encoding, list, sentence);
101106
}
102107

103108
/** {@inheritDoc} */
104109
@Override
110+
@SuppressWarnings("unchecked")
105111
public List<NamedEntity[]> batchProcessOutput(TranslatorContext ctx, NDList list) {
106112
NDList[] batch = batchifier.unbatchify(list);
107113
Encoding[] encodings = (Encoding[]) ctx.getAttachment("encodings");
114+
List<String> sentences = (List<String>) ctx.getAttachment("sentences");
108115
List<NamedEntity[]> ret = new ArrayList<>(batch.length);
109116
for (int i = 0; i < batch.length; ++i) {
110-
ret.add(toNamedEntities(encodings[i], batch[i]));
117+
ret.add(toNamedEntities(encodings[i], batch[i], sentences.get(i)));
111118
}
112119
return ret;
113120
}
@@ -136,46 +143,175 @@ public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arg
136143
return builder;
137144
}
138145

139-
private NamedEntity[] toNamedEntities(Encoding encoding, NDList list) {
146+
private NamedEntity[] toNamedEntities(Encoding encoding, NDList list, String sentence) {
140147
long[] inputIds = encoding.getIds();
141148
CharSpan[] offsetMapping = encoding.getCharTokenSpans();
142149
long[] specialTokenMasks = encoding.getSpecialTokenMask();
150+
String[] words = encoding.getTokens();
151+
long[] tokenIds = encoding.getIds();
143152
NDArray probabilities = list.get(0);
144153
if (softmax) {
145154
probabilities = probabilities.softmax(1);
146155
}
147156

148-
List<NamedEntity> entities = new ArrayList<>();
149-
157+
List<NamedEntityEx> entities = new ArrayList<>();
150158
for (int i = 0; i < inputIds.length; ++i) {
151159
if (specialTokenMasks[i] != 0) {
152160
continue;
153161
}
154162

155-
int entityIdx = (int) probabilities.get(i).argMax().getLong();
156-
String entity = config.id2label.get(String.valueOf(entityIdx));
163+
NDArray prob = probabilities.get(i);
164+
int start = offsetMapping[i].getStart();
165+
int end = offsetMapping[i].getEnd();
166+
boolean isSubWord = false;
167+
if (start > 0
168+
&& ("first".equals(aggregationStrategy)
169+
|| "average".equals(aggregationStrategy)
170+
|| "max".equals(aggregationStrategy))) {
171+
int pos = sentence.indexOf(' ', start - 1);
172+
if (pos < 0 || pos > start) {
173+
isSubWord = true;
174+
}
175+
}
176+
177+
NamedEntityEx item =
178+
new NamedEntityEx(prob, i, words[i], start, end, tokenIds[i], isSubWord);
179+
entities.add(item);
180+
}
181+
if ("first".equals(aggregationStrategy)
182+
|| "average".equals(aggregationStrategy)
183+
|| "max".equals(aggregationStrategy)) {
184+
entities = aggregateWords(entities);
185+
entities = groupEntities(entities);
186+
} else if ("simple".equals(aggregationStrategy)) {
187+
entities = groupEntities(entities);
188+
}
157189

158-
if (!"O".equals(entity)) {
159-
float score = probabilities.get(i).getFloat(entityIdx);
160-
String word = encoding.getTokens()[i];
161-
int start = offsetMapping[i].getStart();
162-
int end = offsetMapping[i].getEnd();
190+
return entities.stream()
191+
.filter(o -> !"O".equals(o.getEntity()))
192+
.map(NamedEntityEx::toNamedEntity)
193+
.toArray(NamedEntity[]::new);
194+
}
163195

164-
NamedEntity item = new NamedEntity(entity, score, i, word, start, end);
165-
entities.add(item);
196+
private List<NamedEntityEx> aggregateWords(List<NamedEntityEx> entities) {
197+
List<NamedEntityEx> agg = new ArrayList<>();
198+
List<NamedEntityEx> group = new ArrayList<>();
199+
for (NamedEntityEx entity : entities) {
200+
if (!entity.isSubWord && !group.isEmpty()) {
201+
agg.add(aggregateWord(group));
202+
group.clear();
166203
}
204+
group.add(entity);
167205
}
168-
return entities.toArray(new NamedEntity[0]);
206+
if (!group.isEmpty()) {
207+
agg.add(aggregateWord(group));
208+
}
209+
return agg;
210+
}
211+
212+
private NamedEntityEx aggregateWord(List<NamedEntityEx> entities) {
213+
if (entities.size() == 1) {
214+
return entities.get(0);
215+
}
216+
List<Long> tokenIds = new ArrayList<>();
217+
for (NamedEntityEx entity : entities) {
218+
tokenIds.addAll(entity.tokenIds);
219+
}
220+
NamedEntityEx first = entities.get(0);
221+
NamedEntityEx last = entities.get(entities.size() - 1);
222+
223+
String entityName;
224+
float score;
225+
226+
if ("first".equals(aggregationStrategy)) {
227+
entityName = first.getEntity();
228+
score = first.getScore();
229+
} else if ("max".equals(aggregationStrategy)) {
230+
NamedEntityEx max =
231+
entities.stream()
232+
.max(Comparator.comparingDouble(NamedEntityEx::getScore))
233+
.get();
234+
entityName = max.getEntity();
235+
score = max.getScore();
236+
} else {
237+
// average
238+
NDArray[] arrays = entities.stream().map(o -> o.prob).toArray(NDArray[]::new);
239+
NDList list = new NDList(arrays);
240+
NDArray array = NDArrays.stack(list).mean(new int[] {0});
241+
int entityIdx = (int) array.argMax().getLong();
242+
entityName = config.id2label.get(String.valueOf(entityIdx));
243+
score = array.getFloat(entityIdx);
244+
}
245+
return new NamedEntityEx(entityName, score, first.start, last.end, tokenIds);
246+
}
247+
248+
private List<NamedEntityEx> groupEntities(List<NamedEntityEx> entities) {
249+
List<NamedEntityEx> disaggregateGroup = new ArrayList<>();
250+
List<NamedEntityEx> entityGroups = new ArrayList<>();
251+
252+
for (NamedEntityEx entity : entities) {
253+
if (disaggregateGroup.isEmpty()) {
254+
disaggregateGroup.add(entity);
255+
continue;
256+
}
257+
258+
Pair<String, String> tag = getTag(entity.getEntity());
259+
NamedEntityEx lastEntity = disaggregateGroup.get(disaggregateGroup.size() - 1);
260+
Pair<String, String> lastTag = getTag(lastEntity.getEntity());
261+
if (!tag.getValue().equals(lastTag.getValue()) || "B".equals(tag.getKey())) {
262+
entityGroups.add(groupSubEntities(disaggregateGroup));
263+
disaggregateGroup.clear();
264+
}
265+
disaggregateGroup.add(entity);
266+
}
267+
268+
if (!disaggregateGroup.isEmpty()) {
269+
entityGroups.add(groupSubEntities(disaggregateGroup));
270+
}
271+
return entityGroups;
272+
}
273+
274+
private Pair<String, String> getTag(String entityName) {
275+
if (entityName.startsWith("B-")) {
276+
return new Pair<>("B", entityName.substring(2));
277+
} else if (entityName.startsWith("I-")) {
278+
return new Pair<>("I", entityName.substring(2));
279+
} else {
280+
return new Pair<>("I", entityName);
281+
}
282+
}
283+
284+
private NamedEntityEx groupSubEntities(List<NamedEntityEx> entities) {
285+
List<Long> tokens = new ArrayList<>();
286+
double[] scores = new double[entities.size()];
287+
for (int i = 0; i < scores.length; ++i) {
288+
NamedEntityEx entity = entities.get(i);
289+
tokens.addAll(entity.tokenIds);
290+
scores[i] = entity.getScore();
291+
}
292+
long[] tokenIds = tokens.stream().mapToLong(Long::longValue).toArray();
293+
String aggWord = tokenizer.decode(tokenIds);
294+
float aggScore = (float) Arrays.stream(scores).sum() / scores.length;
295+
NamedEntityEx first = entities.get(0);
296+
NamedEntityEx last = entities.get(entities.size() - 1);
297+
String entityName = first.getEntity();
298+
int pos = entityName.indexOf('-');
299+
if (pos > 0) {
300+
entityName = entityName.substring(pos + 1);
301+
}
302+
303+
return new NamedEntityEx(entityName, aggScore, aggWord, first.start, last.end);
169304
}
170305

171306
/** The builder for token classification translator. */
172307
public static final class Builder {
173308

174-
private HuggingFaceTokenizer tokenizer;
175-
private boolean includeTokenTypes;
176-
private boolean int32;
177-
private boolean softmax = true;
178-
private Batchifier batchifier = Batchifier.STACK;
309+
HuggingFaceTokenizer tokenizer;
310+
boolean includeTokenTypes;
311+
boolean int32;
312+
boolean softmax = true;
313+
String aggregationStrategy;
314+
Batchifier batchifier = Batchifier.STACK;
179315

180316
Builder(HuggingFaceTokenizer tokenizer) {
181317
this.tokenizer = tokenizer;
@@ -225,6 +361,18 @@ public Builder optBatchifier(Batchifier batchifier) {
225361
return this;
226362
}
227363

364+
/**
365+
* Sets the aggregation strategy for the {@link Translator}.
366+
*
367+
* @param aggregationStrategy the aggregation strategy, one of none, simple, first, average,
368+
* max
369+
* @return this builder
370+
*/
371+
public Builder optAggregationStrategy(String aggregationStrategy) {
372+
this.aggregationStrategy = aggregationStrategy;
373+
return this;
374+
}
375+
228376
/**
229377
* Configures the builder with the model arguments.
230378
*
@@ -234,6 +382,8 @@ public void configure(Map<String, ?> arguments) {
234382
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
235383
optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
236384
optSoftmax(ArgumentsUtil.booleanValue(arguments, "softmax", true));
385+
optAggregationStrategy(
386+
ArgumentsUtil.stringValue(arguments, "aggregation_strategy", "none"));
237387
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
238388
optBatchifier(Batchifier.fromString(batchifierStr));
239389
}
@@ -244,8 +394,82 @@ public void configure(Map<String, ?> arguments) {
244394
* @return the new translator
245395
*/
246396
public TokenClassificationTranslator build() {
247-
return new TokenClassificationTranslator(
248-
tokenizer, includeTokenTypes, int32, softmax, batchifier);
397+
return new TokenClassificationTranslator(this);
398+
}
399+
}
400+
401+
private class NamedEntityEx {
402+
403+
String entity;
404+
float score;
405+
int index;
406+
String word;
407+
int start;
408+
int end;
409+
List<Long> tokenIds;
410+
boolean isSubWord;
411+
NDArray prob;
412+
private boolean initialized;
413+
414+
NamedEntityEx(String entity, float score, String word, int start, int end) {
415+
this.entity = entity;
416+
this.score = score;
417+
this.index = -1;
418+
this.word = word;
419+
this.start = start;
420+
this.end = end;
421+
initialized = true;
422+
}
423+
424+
NamedEntityEx(String entity, float score, int start, int end, List<Long> tokenIds) {
425+
this.entity = entity;
426+
this.score = score;
427+
this.index = -1;
428+
this.start = start;
429+
this.end = end;
430+
this.tokenIds = tokenIds;
431+
initialized = true;
432+
}
433+
434+
NamedEntityEx(
435+
NDArray prob,
436+
int index,
437+
String word,
438+
int start,
439+
int end,
440+
long tokenId,
441+
boolean isSubWord) {
442+
this.prob = prob;
443+
this.index = index;
444+
this.word = word;
445+
this.start = start;
446+
this.end = end;
447+
this.tokenIds = Collections.singletonList(tokenId);
448+
this.isSubWord = isSubWord;
449+
}
450+
451+
private void init() {
452+
if (!initialized) {
453+
int entityIdx = (int) prob.argMax().getLong();
454+
entity = config.id2label.get(String.valueOf(entityIdx));
455+
score = prob.getFloat(entityIdx);
456+
initialized = true;
457+
}
458+
}
459+
460+
String getEntity() {
461+
init();
462+
return entity;
463+
}
464+
465+
float getScore() {
466+
init();
467+
return score;
468+
}
469+
470+
NamedEntity toNamedEntity() {
471+
init();
472+
return new NamedEntity(entity, score, index, word, start, end);
249473
}
250474
}
251475
}

0 commit comments

Comments
 (0)