Skip to content

Commit 0733d11

Browse files
committed
Output probs to have better accurracy
1 parent 2edb97d commit 0733d11

3 files changed

Lines changed: 122 additions & 74 deletions

File tree

analysis/src/java/com/github/oeuvres/alix/lucene/analysis/PosTaggingFilter.java

Lines changed: 64 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
/*
22
* Alix, A Lucene Indexer for XML documents.
33
*
4+
* Copyright 2026 Frédéric Glorieux <frederic.glorieux@fictif.org> & Unige
5+
* Copyright 2016 Frédéric Glorieux <frederic.glorieux@fictif.org>
46
* Copyright 2009 Pierre Dittgen <pierre@dittgen.org>
57
* Frédéric Glorieux <frederic.glorieux@fictif.org>
6-
* Copyright 2016 Frédéric Glorieux <frederic.glorieux@fictif.org>
78
*
89
* Alix is a java library to index and search XML text documents
910
* with Lucene https://lucene.apache.org/core/
@@ -32,18 +33,17 @@
3233
*/
3334
package com.github.oeuvres.alix.lucene.analysis;
3435

35-
import java.util.Map;
3636

3737
import java.io.IOException;
38-
import java.io.InputStream;
39-
import java.util.Locale;
4038

4139
import org.apache.lucene.analysis.TokenFilter;
4240
import org.apache.lucene.analysis.TokenStream;
4341
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
44-
import org.apache.lucene.analysis.tokenattributes.FlagsAttribute;
4542

4643
import com.github.oeuvres.alix.common.Upos;
44+
import com.github.oeuvres.alix.lucene.analysis.tokenattributes.PosAttribute;
45+
import com.github.oeuvres.alix.lucene.analysis.tokenattributes.ProbAttribute;
46+
4747
import static com.github.oeuvres.alix.common.Upos.*;
4848

4949
import opennlp.tools.postag.POSModel;
@@ -60,44 +60,15 @@ public class PosTaggingFilter extends TokenFilter
6060
}
6161
/** The term provided by the Tokenizer */
6262
private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
63-
/** Current Flags */
64-
private final FlagsAttribute flagsAtt = addAttribute(FlagsAttribute.class);
63+
private final PosAttribute posAtt = addAttribute(PosAttribute.class);
64+
private final ProbAttribute probAtt = addAttribute(ProbAttribute.class);
6565
/** A stack of states */
6666
private TokenStateQueue queue;
6767
/** Maximum size of a sentence to send to the tagger */
6868
final static int SENTMAX = 300;
6969

7070
/** non thread safe tagger, one by instance of filter */
7171
private POSTaggerME tagger;
72-
/** tagset https://universaldependencies.org/u/pos/ */
73-
private static final Map<String, Upos> TAG_LIST = Map.ofEntries(
74-
Map.entry("ADJ", ADJ),
75-
Map.entry("ADP", ADP),
76-
Map.entry("ADP+DET", ADP_DET),
77-
Map.entry("ADP+PRON", ADP_PRON),
78-
Map.entry("ADV", ADV),
79-
Map.entry("AUX", AUX),
80-
Map.entry("CCONJ", CCONJ),
81-
Map.entry("DET", DET),
82-
Map.entry("INTJ", INTJ),
83-
Map.entry("NOUN", NOUN),
84-
Map.entry("NUM", NUM),
85-
Map.entry("PRON", PRON),
86-
Map.entry("PROPN", PROPN),
87-
Map.entry("PUNCT", TOKEN), // pun is filtered upper, tagger bug
88-
Map.entry("SCONJ", SCONJ),
89-
Map.entry("SYM", TOKEN),
90-
Map.entry("VERB", VERB),
91-
Map.entry("X", TOKEN)
92-
);
93-
/** state of the queue */
94-
private boolean tagged = false;
95-
96-
public PosTaggingFilter(TokenStream input)
97-
{
98-
super(input);
99-
throw new Error("TODO");
100-
}
10172

10273

10374
/**
@@ -114,69 +85,85 @@ public PosTaggingFilter(TokenStream input, POSModel posModel)
11485
@Override
11586
public final boolean incrementToken() throws IOException
11687
{
117-
// needed here to have all atts in queue
118-
if (queue == null) {
119-
queue = new TokenStateQueue(SENTMAX, this);
120-
}
121-
// empty the queue
88+
89+
// 0) Drain queued tokens first
12290
if (!queue.isEmpty()) {
12391
clearAttributes();
12492
queue.removeFirst(this);
12593
return true;
12694
}
127-
boolean toksLeft = true;
128-
// store states till pun
95+
96+
97+
// 2) Fill queue until boundary or SENTMAX or EOF
12998
while (queue.size() < SENTMAX) {
130-
clearAttributes(); // clear before next incrementToken
131-
toksLeft = input.incrementToken();
132-
if (!toksLeft)
99+
clearAttributes();
100+
if (!input.incrementToken())
133101
break;
102+
134103
queue.addLast(this);
135-
final int flags = flagsAtt.getFlags();
136-
if (flags == PUNCTsection.code || flags == PUNCTpara.code || flags == PUNCTsent.code)
104+
105+
final int pos = posAtt.getPos(); // structural classification from upstream
106+
if (pos == PUNCTsection.code || pos == PUNCTpara.code || pos == PUNCTsent.code) {
137107
break;
108+
}
138109
}
139-
// should be finisehd here
110+
140111
final int n = queue.size();
141112
if (n == 0)
142113
return false;
143114

144-
String[] sentence = new String[queue.size()];
145-
boolean firstToken = true;
146-
boolean needsTagging = false;
115+
// 3) Build sentence[] for the tagger + detect if we have any lexical TOKEN
116+
final String[] sentence = new String[n];
117+
147118
for (int i = 0; i < n; i++) {
148-
final FlagsAttribute flags = queue.get(i).getAttribute(FlagsAttribute.class);
149-
// those tags will not help tagger
150-
if (flags.getFlags() == PUNCTsection.code || flags.getFlags() == PUNCTpara.code) {
119+
final PosAttribute p = queue.get(i).getAttribute(PosAttribute.class);
120+
final int pos = p.getPos();
121+
122+
if (pos == PUNCTsection.code || pos == PUNCTpara.code) {
151123
sentence[i] = ".";
152124
continue;
153125
}
154-
final CharTermAttribute term = queue.get(i).getAttribute(CharTermAttribute.class);
155-
String s = new String(term.buffer(), 0, term.length());
156126

157-
// bug initial cap, Tu_NAME vas_VERB bien_ ?_PUN
158-
if (firstToken && !s.isEmpty() && Character.isUpperCase(s.charAt(0))) {
159-
s = s.toLowerCase(Locale.ROOT);
127+
final CharTermAttribute t = queue.get(i).getAttribute(CharTermAttribute.class);
128+
String s = t.toString();
129+
130+
/*
131+
if (pos == TOKEN.code) {
132+
// Your “first word” workaround: apply only once, on the first lexical token
133+
if (firstLex && !s.isEmpty() && Character.isUpperCase(s.charAt(0))) {
134+
s = s.toLowerCase(Locale.ROOT);
135+
}
136+
firstLex = false;
137+
needsTagging = true;
160138
}
139+
*/
140+
161141
sentence[i] = s;
162-
if (flags.getFlags() == TOKEN.code)
163-
needsTagging = true;
164142
}
165-
if (needsTagging) {
166-
final String[] tags = tagger.tag(sentence);
167143

168-
for (int i = 0; i < n; i++) {
169-
final FlagsAttribute f = queue.get(i).getAttribute(FlagsAttribute.class);
170-
if (f.getFlags() != TOKEN.code)
171-
continue;
144+
// 4) Tag + write back into PosAttribute (only where upstream said TOKEN)
145+
final String[] tags = tagger.tag(sentence);
146+
double[] probs = tagger.probs();
172147

173-
final Upos upos = TAG_LIST.get(tags[i]);
174-
if (upos != null) {
175-
f.setFlags(upos.code());
176-
}
177-
// else: keep TOKEN (or choose a fallback)
148+
for (int i = 0; i < n; i++) {
149+
final ProbAttribute prob = queue.get(i).getAttribute(ProbAttribute.class);
150+
prob.setProb(probs[i]);
151+
final PosAttribute pos = queue.get(i).getAttribute(PosAttribute.class);
152+
final int origPos = pos.getPos();
153+
// Upper filter provide more precise punctuation than the tagger, keep it.
154+
if (Upos.isPunct(origPos)) continue;
155+
Upos upos = Upos.get(tags[i].replace('+', '_'));
156+
if (upos == null) {
157+
// for testing only
158+
// System.out.println(tags[i]);
178159
}
160+
else {
161+
pos.setPos(upos.code());
162+
}
163+
// else keep TOKEN (or choose a fallback)
179164
}
165+
166+
// 5) Emit first token of the now-tagged queue
180167
clearAttributes();
181168
queue.removeFirst(this);
182169
return true;
@@ -186,7 +173,10 @@ public final boolean incrementToken() throws IOException
186173
public void reset() throws IOException
187174
{
188175
super.reset();
189-
queue.clear();
176+
if (queue == null)
177+
queue = new TokenStateQueue(SENTMAX, this);
178+
else
179+
queue.clear();
190180
}
191181

192182
@Override
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package com.github.oeuvres.alix.lucene.analysis.tokenattributes;
2+
3+
import org.apache.lucene.util.Attribute;
4+
5+
public interface ProbAttribute extends Attribute {
6+
7+
8+
9+
/**
10+
* Return the prob of the pos for the current token.
11+
*
12+
* @return prob
13+
*/
14+
double getProb();
15+
16+
/**
17+
* Set the prob of the pos for the current token.
18+
*
19+
* @param pos POS code, or {@link #UNKNOWN} to clear.
20+
*/
21+
void setProb(double prob);
22+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package com.github.oeuvres.alix.lucene.analysis.tokenattributes;
2+
3+
import org.apache.lucene.util.AttributeImpl;
4+
import org.apache.lucene.util.AttributeReflector;
5+
6+
public final class ProbAttributeImpl extends AttributeImpl implements ProbAttribute {
7+
private static double UNKNOWN = -1;
8+
private double prob = UNKNOWN;
9+
10+
11+
12+
@Override
13+
public double getProb() {
14+
return prob;
15+
}
16+
17+
@Override
18+
public void setProb(final double prob) {
19+
this.prob = prob;
20+
}
21+
22+
@Override
23+
public void clear() {
24+
prob = UNKNOWN;
25+
}
26+
27+
@Override
28+
public void copyTo(final AttributeImpl target) {
29+
((ProbAttribute) target).setProb(prob);
30+
}
31+
32+
@Override
33+
public void reflectWith(final AttributeReflector reflector) {
34+
reflector.reflect(ProbAttribute.class, "prob", prob);
35+
}
36+
}

0 commit comments

Comments
 (0)