1717import ai .djl .huggingface .tokenizers .jni .CharSpan ;
1818import ai .djl .modality .nlp .translator .NamedEntity ;
1919import ai .djl .ndarray .NDArray ;
20+ import ai .djl .ndarray .NDArrays ;
2021import ai .djl .ndarray .NDList ;
2122import ai .djl .ndarray .NDManager ;
2223import ai .djl .translate .ArgumentsUtil ;
2324import ai .djl .translate .Batchifier ;
2425import ai .djl .translate .Translator ;
2526import ai .djl .translate .TranslatorContext ;
2627import ai .djl .util .JsonUtils ;
28+ import ai .djl .util .Pair ;
2729
2830import java .io .IOException ;
2931import java .io .Reader ;
3032import java .nio .file .Files ;
3133import java .nio .file .Path ;
3234import java .util .ArrayList ;
35+ import java .util .Arrays ;
36+ import java .util .Collections ;
37+ import java .util .Comparator ;
3338import java .util .List ;
3439import java .util .Map ;
3540
@@ -40,20 +45,17 @@ public class TokenClassificationTranslator implements Translator<String, NamedEn
4045 private boolean includeTokenTypes ;
4146 private boolean int32 ;
4247 private boolean softmax ;
48+ private String aggregationStrategy ;
4349 private Batchifier batchifier ;
4450 private PretrainedConfig config ;
4551
/**
 * Creates a translator from the settings collected by a {@link Builder}.
 *
 * @param builder the builder whose fields are copied into this translator
 */
TokenClassificationTranslator(Builder builder) {
    tokenizer = builder.tokenizer;
    batchifier = builder.batchifier;
    aggregationStrategy = builder.aggregationStrategy;
    softmax = builder.softmax;
    int32 = builder.int32;
    includeTokenTypes = builder.includeTokenTypes;
}
5860
5961 /** {@inheritDoc} */
@@ -77,6 +79,7 @@ public void prepare(TranslatorContext ctx) throws IOException {
7779 public NDList processInput (TranslatorContext ctx , String input ) {
7880 Encoding encoding = tokenizer .encode (input );
7981 ctx .setAttachment ("encoding" , encoding );
82+ ctx .setAttachment ("sentence" , input );
8083 return encoding .toNDList (ctx .getNDManager (), includeTokenTypes , int32 );
8184 }
8285
@@ -86,6 +89,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List<String> inputs) {
8689 NDManager manager = ctx .getNDManager ();
8790 Encoding [] encodings = tokenizer .batchEncode (inputs );
8891 ctx .setAttachment ("encodings" , encodings );
92+ ctx .setAttachment ("sentences" , inputs );
8993 NDList [] batch = new NDList [encodings .length ];
9094 for (int i = 0 ; i < encodings .length ; ++i ) {
9195 batch [i ] = encodings [i ].toNDList (manager , includeTokenTypes , int32 );
@@ -97,17 +101,20 @@ public NDList batchProcessInput(TranslatorContext ctx, List<String> inputs) {
97101 @ Override
98102 public NamedEntity [] processOutput (TranslatorContext ctx , NDList list ) {
99103 Encoding encoding = (Encoding ) ctx .getAttachment ("encoding" );
100- return toNamedEntities (encoding , list );
104+ String sentence = (String ) ctx .getAttachment ("sentence" );
105+ return toNamedEntities (encoding , list , sentence );
101106 }
102107
103108 /** {@inheritDoc} */
104109 @ Override
110+ @ SuppressWarnings ("unchecked" )
105111 public List <NamedEntity []> batchProcessOutput (TranslatorContext ctx , NDList list ) {
106112 NDList [] batch = batchifier .unbatchify (list );
107113 Encoding [] encodings = (Encoding []) ctx .getAttachment ("encodings" );
114+ List <String > sentences = (List <String >) ctx .getAttachment ("sentences" );
108115 List <NamedEntity []> ret = new ArrayList <>(batch .length );
109116 for (int i = 0 ; i < batch .length ; ++i ) {
110- ret .add (toNamedEntities (encodings [i ], batch [i ]));
117+ ret .add (toNamedEntities (encodings [i ], batch [i ], sentences . get ( i ) ));
111118 }
112119 return ret ;
113120 }
@@ -136,46 +143,175 @@ public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arg
136143 return builder ;
137144 }
138145
139- private NamedEntity [] toNamedEntities (Encoding encoding , NDList list ) {
146+ private NamedEntity [] toNamedEntities (Encoding encoding , NDList list , String sentence ) {
140147 long [] inputIds = encoding .getIds ();
141148 CharSpan [] offsetMapping = encoding .getCharTokenSpans ();
142149 long [] specialTokenMasks = encoding .getSpecialTokenMask ();
150+ String [] words = encoding .getTokens ();
151+ long [] tokenIds = encoding .getIds ();
143152 NDArray probabilities = list .get (0 );
144153 if (softmax ) {
145154 probabilities = probabilities .softmax (1 );
146155 }
147156
148- List <NamedEntity > entities = new ArrayList <>();
149-
157+ List <NamedEntityEx > entities = new ArrayList <>();
150158 for (int i = 0 ; i < inputIds .length ; ++i ) {
151159 if (specialTokenMasks [i ] != 0 ) {
152160 continue ;
153161 }
154162
155- int entityIdx = (int ) probabilities .get (i ).argMax ().getLong ();
156- String entity = config .id2label .get (String .valueOf (entityIdx ));
163+ NDArray prob = probabilities .get (i );
164+ int start = offsetMapping [i ].getStart ();
165+ int end = offsetMapping [i ].getEnd ();
166+ boolean isSubWord = false ;
167+ if (start > 0
168+ && ("first" .equals (aggregationStrategy )
169+ || "average" .equals (aggregationStrategy )
170+ || "max" .equals (aggregationStrategy ))) {
171+ int pos = sentence .indexOf (' ' , start - 1 );
172+ if (pos < 0 || pos > start ) {
173+ isSubWord = true ;
174+ }
175+ }
176+
177+ NamedEntityEx item =
178+ new NamedEntityEx (prob , i , words [i ], start , end , tokenIds [i ], isSubWord );
179+ entities .add (item );
180+ }
181+ if ("first" .equals (aggregationStrategy )
182+ || "average" .equals (aggregationStrategy )
183+ || "max" .equals (aggregationStrategy )) {
184+ entities = aggregateWords (entities );
185+ entities = groupEntities (entities );
186+ } else if ("simple" .equals (aggregationStrategy )) {
187+ entities = groupEntities (entities );
188+ }
157189
158- if (! "O" . equals ( entity )) {
159- float score = probabilities . get ( i ). getFloat ( entityIdx );
160- String word = encoding . getTokens ()[ i ];
161- int start = offsetMapping [ i ]. getStart ( );
162- int end = offsetMapping [ i ]. getEnd ();
190+ return entities . stream ()
191+ . filter ( o -> ! "O" . equals ( o . getEntity ()))
192+ . map ( NamedEntityEx :: toNamedEntity )
193+ . toArray ( NamedEntity []:: new );
194+ }
163195
164- NamedEntity item = new NamedEntity (entity , score , i , word , start , end );
165- entities .add (item );
196+ private List <NamedEntityEx > aggregateWords (List <NamedEntityEx > entities ) {
197+ List <NamedEntityEx > agg = new ArrayList <>();
198+ List <NamedEntityEx > group = new ArrayList <>();
199+ for (NamedEntityEx entity : entities ) {
200+ if (!entity .isSubWord && !group .isEmpty ()) {
201+ agg .add (aggregateWord (group ));
202+ group .clear ();
166203 }
204+ group .add (entity );
167205 }
168- return entities .toArray (new NamedEntity [0 ]);
206+ if (!group .isEmpty ()) {
207+ agg .add (aggregateWord (group ));
208+ }
209+ return agg ;
210+ }
211+
212+ private NamedEntityEx aggregateWord (List <NamedEntityEx > entities ) {
213+ if (entities .size () == 1 ) {
214+ return entities .get (0 );
215+ }
216+ List <Long > tokenIds = new ArrayList <>();
217+ for (NamedEntityEx entity : entities ) {
218+ tokenIds .addAll (entity .tokenIds );
219+ }
220+ NamedEntityEx first = entities .get (0 );
221+ NamedEntityEx last = entities .get (entities .size () - 1 );
222+
223+ String entityName ;
224+ float score ;
225+
226+ if ("first" .equals (aggregationStrategy )) {
227+ entityName = first .getEntity ();
228+ score = first .getScore ();
229+ } else if ("max" .equals (aggregationStrategy )) {
230+ NamedEntityEx max =
231+ entities .stream ()
232+ .max (Comparator .comparingDouble (NamedEntityEx ::getScore ))
233+ .get ();
234+ entityName = max .getEntity ();
235+ score = max .getScore ();
236+ } else {
237+ // average
238+ NDArray [] arrays = entities .stream ().map (o -> o .prob ).toArray (NDArray []::new );
239+ NDList list = new NDList (arrays );
240+ NDArray array = NDArrays .stack (list ).mean (new int [] {0 });
241+ int entityIdx = (int ) array .argMax ().getLong ();
242+ entityName = config .id2label .get (String .valueOf (entityIdx ));
243+ score = array .getFloat (entityIdx );
244+ }
245+ return new NamedEntityEx (entityName , score , first .start , last .end , tokenIds );
246+ }
247+
248+ private List <NamedEntityEx > groupEntities (List <NamedEntityEx > entities ) {
249+ List <NamedEntityEx > disaggregateGroup = new ArrayList <>();
250+ List <NamedEntityEx > entityGroups = new ArrayList <>();
251+
252+ for (NamedEntityEx entity : entities ) {
253+ if (disaggregateGroup .isEmpty ()) {
254+ disaggregateGroup .add (entity );
255+ continue ;
256+ }
257+
258+ Pair <String , String > tag = getTag (entity .getEntity ());
259+ NamedEntityEx lastEntity = disaggregateGroup .get (disaggregateGroup .size () - 1 );
260+ Pair <String , String > lastTag = getTag (lastEntity .getEntity ());
261+ if (!tag .getValue ().equals (lastTag .getValue ()) || "B" .equals (tag .getKey ())) {
262+ entityGroups .add (groupSubEntities (disaggregateGroup ));
263+ disaggregateGroup .clear ();
264+ }
265+ disaggregateGroup .add (entity );
266+ }
267+
268+ if (!disaggregateGroup .isEmpty ()) {
269+ entityGroups .add (groupSubEntities (disaggregateGroup ));
270+ }
271+ return entityGroups ;
272+ }
273+
274+ private Pair <String , String > getTag (String entityName ) {
275+ if (entityName .startsWith ("B-" )) {
276+ return new Pair <>("B" , entityName .substring (2 ));
277+ } else if (entityName .startsWith ("I-" )) {
278+ return new Pair <>("I" , entityName .substring (2 ));
279+ } else {
280+ return new Pair <>("I" , entityName );
281+ }
282+ }
283+
284+ private NamedEntityEx groupSubEntities (List <NamedEntityEx > entities ) {
285+ List <Long > tokens = new ArrayList <>();
286+ double [] scores = new double [entities .size ()];
287+ for (int i = 0 ; i < scores .length ; ++i ) {
288+ NamedEntityEx entity = entities .get (i );
289+ tokens .addAll (entity .tokenIds );
290+ scores [i ] = entity .getScore ();
291+ }
292+ long [] tokenIds = tokens .stream ().mapToLong (Long ::longValue ).toArray ();
293+ String aggWord = tokenizer .decode (tokenIds );
294+ float aggScore = (float ) Arrays .stream (scores ).sum () / scores .length ;
295+ NamedEntityEx first = entities .get (0 );
296+ NamedEntityEx last = entities .get (entities .size () - 1 );
297+ String entityName = first .getEntity ();
298+ int pos = entityName .indexOf ('-' );
299+ if (pos > 0 ) {
300+ entityName = entityName .substring (pos + 1 );
301+ }
302+
303+ return new NamedEntityEx (entityName , aggScore , aggWord , first .start , last .end );
169304 }
170305
171306 /** The builder for token classification translator. */
172307 public static final class Builder {
173308
174- private HuggingFaceTokenizer tokenizer ;
175- private boolean includeTokenTypes ;
176- private boolean int32 ;
177- private boolean softmax = true ;
178- private Batchifier batchifier = Batchifier .STACK ;
309+ HuggingFaceTokenizer tokenizer ;
310+ boolean includeTokenTypes ;
311+ boolean int32 ;
312+ boolean softmax = true ;
313+ String aggregationStrategy ;
314+ Batchifier batchifier = Batchifier .STACK ;
179315
180316 Builder (HuggingFaceTokenizer tokenizer ) {
181317 this .tokenizer = tokenizer ;
@@ -225,6 +361,18 @@ public Builder optBatchifier(Batchifier batchifier) {
225361 return this ;
226362 }
227363
364+ /**
365+ * Sets the aggregation strategy for the {@link Translator}.
366+ *
367+ * @param aggregationStrategy the aggregation strategy, one of none, simple, first, average,
368+ * max
369+ * @return this builder
370+ */
371+ public Builder optAggregationStrategy (String aggregationStrategy ) {
372+ this .aggregationStrategy = aggregationStrategy ;
373+ return this ;
374+ }
375+
228376 /**
229377 * Configures the builder with the model arguments.
230378 *
@@ -234,6 +382,8 @@ public void configure(Map<String, ?> arguments) {
234382 optIncludeTokenTypes (ArgumentsUtil .booleanValue (arguments , "includeTokenTypes" ));
235383 optInt32 (ArgumentsUtil .booleanValue (arguments , "int32" ));
236384 optSoftmax (ArgumentsUtil .booleanValue (arguments , "softmax" , true ));
385+ optAggregationStrategy (
386+ ArgumentsUtil .stringValue (arguments , "aggregation_strategy" , "none" ));
237387 String batchifierStr = ArgumentsUtil .stringValue (arguments , "batchifier" , "stack" );
238388 optBatchifier (Batchifier .fromString (batchifierStr ));
239389 }
@@ -244,8 +394,82 @@ public void configure(Map<String, ?> arguments) {
244394 * @return the new translator
245395 */
246396 public TokenClassificationTranslator build () {
247- return new TokenClassificationTranslator (
248- tokenizer , includeTokenTypes , int32 , softmax , batchifier );
397+ return new TokenClassificationTranslator (this );
398+ }
399+ }
400+
401+ private class NamedEntityEx {
402+
403+ String entity ;
404+ float score ;
405+ int index ;
406+ String word ;
407+ int start ;
408+ int end ;
409+ List <Long > tokenIds ;
410+ boolean isSubWord ;
411+ NDArray prob ;
412+ private boolean initialized ;
413+
414+ NamedEntityEx (String entity , float score , String word , int start , int end ) {
415+ this .entity = entity ;
416+ this .score = score ;
417+ this .index = -1 ;
418+ this .word = word ;
419+ this .start = start ;
420+ this .end = end ;
421+ initialized = true ;
422+ }
423+
424+ NamedEntityEx (String entity , float score , int start , int end , List <Long > tokenIds ) {
425+ this .entity = entity ;
426+ this .score = score ;
427+ this .index = -1 ;
428+ this .start = start ;
429+ this .end = end ;
430+ this .tokenIds = tokenIds ;
431+ initialized = true ;
432+ }
433+
434+ NamedEntityEx (
435+ NDArray prob ,
436+ int index ,
437+ String word ,
438+ int start ,
439+ int end ,
440+ long tokenId ,
441+ boolean isSubWord ) {
442+ this .prob = prob ;
443+ this .index = index ;
444+ this .word = word ;
445+ this .start = start ;
446+ this .end = end ;
447+ this .tokenIds = Collections .singletonList (tokenId );
448+ this .isSubWord = isSubWord ;
449+ }
450+
451+ private void init () {
452+ if (!initialized ) {
453+ int entityIdx = (int ) prob .argMax ().getLong ();
454+ entity = config .id2label .get (String .valueOf (entityIdx ));
455+ score = prob .getFloat (entityIdx );
456+ initialized = true ;
457+ }
458+ }
459+
460+ String getEntity () {
461+ init ();
462+ return entity ;
463+ }
464+
465+ float getScore () {
466+ init ();
467+ return score ;
468+ }
469+
470+ NamedEntity toNamedEntity () {
471+ init ();
472+ return new NamedEntity (entity , score , index , word , start , end );
249473 }
250474 }
251475}
0 commit comments