Skip to content

Commit d02cbaa

Browse files
committed
OPENNLP-124 : Maxent/Perceptron training should report progress back via an API
1 parent d7e097d commit d02cbaa

21 files changed

+861
-18
lines changed

opennlp-docs/src/docbkx/postagger.xml

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,114 @@ try (OutputStream modelOut = new BufferedOutputStream(new FileOutputStream(model
237237
</para>
238238
<para>
239239
The dictionary is defined in a xml format and can be created and stored with the POSDictionary class.
240-
Please for now checkout the javadoc and source code of that class.
240+
Below is an example to train a custom model using a tag dictionary.
241241
</para>
242-
<para>Note: The format should be documented and sample code should show how to use the dictionary.
243-
Any contributions are very welcome. If you want to contribute please contact us on the mailing list
244-
or comment on the jira issue <ulink url="https://issues.apache.org/jira/browse/OPENNLP-287">OPENNLP-287</ulink>.
242+
<para>
243+
Sample POS training material (file: en-custom-pos.train)
244+
<screen>
245+
<![CDATA[
246+
It_PRON is_OTHER spring_PROPN season_NOUN. The_DET flowers_NOUN are_OTHER red_ADJ and_CCONJ yellow_ADJ ._PUNCT
247+
Red_NOUN is_OTHER my_DET favourite_ADJ colour_NOUN ._PUNCT]]>
248+
</screen>
249+
</para>
250+
<para>
251+
Sample tag dictionary (file: dictionary.xml)
252+
<programlisting language="xml">
253+
<![CDATA[
254+
<?xml version="1.0" encoding="UTF-8"?>
255+
<dictionary case_sensitive="false">
256+
<entry tags="PRON">
257+
<token>It</token>
258+
</entry>
259+
<entry tags="OTHER">
260+
<token>is</token>
261+
</entry>
262+
<entry tags="PROPN">
263+
<token>Spring</token>
264+
</entry>
265+
<entry tags="NOUN">
266+
<token>season</token>
267+
</entry>
268+
<entry tags="DET">
269+
<token>the</token>
270+
</entry>
271+
<entry tags="NOUN">
272+
<token>flowers</token>
273+
</entry>
274+
<entry tags="OTHER">
275+
<token>are</token>
276+
</entry>
277+
<entry tags="NOUN">
278+
<token>red</token>
279+
</entry>
280+
<entry tags="CCONJ">
281+
<token>and</token>
282+
</entry>
283+
<entry tags="NOUN">
284+
<token>yellow</token>
285+
</entry>
286+
<entry tags="PRON">
287+
<token>my</token>
288+
</entry>
289+
<entry tags="ADJ">
290+
<token>favourite</token>
291+
</entry>
292+
<entry tags="NOUN">
293+
<token>colour</token>
294+
</entry>
295+
<entry tags="PUNCT">
296+
<token>.</token>
297+
</entry>
298+
</dictionary>]]>
299+
</programlisting>
300+
</para>
301+
<para>Sample code to train a model using the above tag dictionary
302+
<programlisting language="java">
303+
<![CDATA[
304+
POSModel model = null;
305+
try {
306+
ObjectStream<String> lineStream = new PlainTextByLineStream(
307+
new MarkableFileInputStreamFactory(new File("en-custom-pos.train")), StandardCharsets.UTF_8);
308+
309+
ObjectStream<POSSample> sampleStream = new WordTagSampleStream(lineStream);
310+
311+
TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
312+
params.put(TrainingParameters.CUTOFF_PARAM, 0);
313+
314+
POSTaggerFactory factory = new POSTaggerFactory();
315+
TagDictionary dict = factory.createTagDictionary(new File("dictionary.xml"));
316+
factory.setTagDictionary(dict);
317+
318+
model = POSTaggerME.train("eng", sampleStream, params, factory);
319+
320+
OutputStream modelOut = new BufferedOutputStream(new FileOutputStream("en-custom-pos-maxent.bin"));
321+
model.serialize(modelOut);
322+
323+
} catch (IOException e) {
324+
e.printStackTrace();
325+
}]]>
326+
</programlisting>
327+
</para>
328+
<para>
329+
The custom model is then used to tag a sequence.
330+
<programlisting language="java">
331+
<![CDATA[
332+
String[] sent = new String[]{"Spring", "is", "my", "favourite", "season", "."};
333+
String[] tags = tagger.tag(sent);
334+
Arrays.stream(tags).forEach(k -> System.out.print(k + " "));]]>
335+
</programlisting>
336+
</para>
337+
<para>
338+
<literallayout>
339+
Input
340+
Sentence: Spring is my favourite season.
341+
342+
Output
343+
POS Tags using the custom model (en-custom-pos-maxent.bin): PROPN OTHER PRON ADJ NOUN PUNCT
344+
345+
Output with the default model
346+
POS Tags using the default model (opennlp-en-ud-ewt-pos-1.2-2.5.0.bin): NOUN AUX PRON ADJ NOUN PUNCT
347+
</literallayout>
245348
</para>
246349
</section>
247350
</section>

opennlp-tools/pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@
7272
<scope>test</scope>
7373
</dependency>
7474

75+
<dependency>
76+
<groupId>org.assertj</groupId>
77+
<artifactId>assertj-core</artifactId>
78+
<version>${assertj-core.version}</version>
79+
<scope>test</scope>
80+
</dependency>
81+
7582
</dependencies>
7683

7784
<build>

opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.util.Map;
2121

22+
import opennlp.tools.util.TrainingConfiguration;
2223
import opennlp.tools.util.TrainingParameters;
2324

2425
/**
@@ -35,4 +36,14 @@ public interface Trainer {
3536
*/
3637
void init(TrainingParameters trainParams, Map<String, String> reportMap);
3738

39+
/**
40+
* Conducts the initialization of a {@link Trainer} via
41+
* {@link TrainingParameters}, {@link Map report map} and {@link TrainingConfiguration}
42+
*
43+
* @param trainParams The {@link TrainingParameters} to use.
44+
* @param reportMap The {@link Map} instance used as report map.
45+
* @param config The {@link TrainingConfiguration} to use.
46+
*/
47+
void init(TrainingParameters trainParams, Map<String, String> reportMap, TrainingConfiguration config);
48+
3849
}

opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222

2323
import opennlp.tools.commons.Trainer;
2424
import opennlp.tools.ml.maxent.GISTrainer;
25+
import opennlp.tools.util.TrainingConfiguration;
2526
import opennlp.tools.util.TrainingParameters;
2627

2728
public abstract class AbstractTrainer implements Trainer {
2829

2930
protected TrainingParameters trainingParameters;
3031
protected Map<String,String> reportMap;
32+
protected TrainingConfiguration trainingConfiguration;
3133

3234
public AbstractTrainer() {
3335
}
@@ -55,6 +57,20 @@ public void init(TrainingParameters trainParams, Map<String,String> reportMap) {
5557
this.reportMap = reportMap;
5658
}
5759

60+
/**
61+
* Initializes a {@link AbstractTrainer} using following parameters.
62+
*
63+
* @param trainParams The {@link TrainingParameters} to use.
64+
* @param reportMap The {@link Map} instance used as report map.
65+
* @param config The {@link TrainingConfiguration} to use.
66+
*/
67+
@Override
68+
public void init(TrainingParameters trainParams, Map<String, String> reportMap,
69+
TrainingConfiguration config) {
70+
init(trainParams, reportMap);
71+
this.trainingConfiguration = config;
72+
}
73+
5874
/**
5975
* @return Retrieves the configured {@link TrainingParameters#ALGORITHM_PARAM} value.
6076
*/
@@ -108,4 +124,12 @@ protected void addToReport(String key, String value) {
108124
reportMap.put(key, value);
109125
}
110126

127+
/**
128+
* Retrieves the {@link TrainingConfiguration} associated with a {@link AbstractTrainer}.
129+
* @return {@link TrainingConfiguration}
130+
*/
131+
public TrainingConfiguration getTrainingConfiguration() {
132+
return trainingConfiguration;
133+
}
134+
111135
}

opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import opennlp.tools.ml.naivebayes.NaiveBayesTrainer;
2727
import opennlp.tools.ml.perceptron.PerceptronTrainer;
2828
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
29+
import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
30+
import opennlp.tools.util.TrainingConfiguration;
2931
import opennlp.tools.util.TrainingParameters;
3032
import opennlp.tools.util.ext.ExtensionLoader;
3133
import opennlp.tools.util.ext.ExtensionNotLoadedException;
@@ -180,6 +182,22 @@ public static <T> EventModelSequenceTrainer<T> getEventModelSequenceTrainer(
180182
}
181183
}
182184

185+
/**
186+
* Works like {@link TrainerFactory#getEventTrainer(TrainingParameters, Map, TrainingConfiguration)}
187+
* except that the {@link TrainingConfiguration} is initialized with {@link DefaultTrainingProgressMonitor}
188+
* and a null {@link opennlp.tools.monitoring.StopCriteria}.
189+
* If not provided, the actual {@link opennlp.tools.monitoring.StopCriteria}
190+
* will be decided by the {@link EventTrainer} implementation.
191+
*
192+
*/
193+
public static EventTrainer getEventTrainer(
194+
TrainingParameters trainParams, Map<String, String> reportMap) {
195+
196+
TrainingConfiguration trainingConfiguration
197+
= new TrainingConfiguration(new DefaultTrainingProgressMonitor(), null);
198+
return getEventTrainer(trainParams, reportMap, trainingConfiguration);
199+
}
200+
183201
/**
184202
* Retrieves an {@link EventTrainer} that fits the given parameters.
185203
*
@@ -189,11 +207,14 @@ public static <T> EventModelSequenceTrainer<T> getEventModelSequenceTrainer(
189207
* {@link GISTrainer#MAXENT_VALUE} will be used.
190208
* @param reportMap A {@link Map} that shall be used during initialization of
191209
* the {@link EventTrainer}.
210+
* @param config The {@link TrainingConfiguration} to be used. This determines the type of
211+
* {@link opennlp.tools.monitoring.TrainingProgressMonitor}
212+
* and the {@link opennlp.tools.monitoring.StopCriteria} to be used.
192213
*
193214
* @return A valid {@link EventTrainer} for the configured {@code trainParams}.
194215
*/
195216
public static EventTrainer getEventTrainer(
196-
TrainingParameters trainParams, Map<String, String> reportMap) {
217+
TrainingParameters trainParams, Map<String, String> reportMap, TrainingConfiguration config) {
197218

198219
// if the trainerType is not defined -- use the GISTrainer.
199220
String trainerType = trainParams.getStringParameter(
@@ -205,7 +226,7 @@ public static EventTrainer getEventTrainer(
205226
} else {
206227
trainer = ExtensionLoader.instantiateExtension(EventTrainer.class, trainerType);
207228
}
208-
trainer.init(trainParams, reportMap);
229+
trainer.init(trainParams, reportMap, config);
209230
return trainer;
210231
}
211232

opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
import opennlp.tools.ml.model.OnePassDataIndexer;
4141
import opennlp.tools.ml.model.Prior;
4242
import opennlp.tools.ml.model.UniformPrior;
43+
import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
44+
import opennlp.tools.monitoring.LogLikelihoodThresholdBreached;
45+
import opennlp.tools.monitoring.StopCriteria;
46+
import opennlp.tools.monitoring.TrainingMeasure;
47+
import opennlp.tools.monitoring.TrainingProgressMonitor;
4348
import opennlp.tools.util.ObjectStream;
49+
import opennlp.tools.util.TrainingConfiguration;
4450
import opennlp.tools.util.TrainingParameters;
4551

4652
/**
@@ -497,6 +503,11 @@ private void findParameters(int iterations, double correctionConstant) {
497503
new ExecutorCompletionService<>(executor);
498504
double prevLL = 0.0;
499505
double currLL;
506+
507+
//Get the Training Progress Monitor and the StopCriteria.
508+
TrainingProgressMonitor progressMonitor = getTrainingProgressMonitor(trainingConfiguration);
509+
StopCriteria stopCriteria = getStopCriteria(trainingConfiguration);
510+
500511
logger.info("Performing {} iterations.", iterations);
501512
for (int i = 1; i <= iterations; i++) {
502513
currLL = nextIteration(correctionConstant, completionService, i);
@@ -505,13 +516,20 @@ private void findParameters(int iterations, double correctionConstant) {
505516
logger.warn("Model Diverging: loglikelihood decreased");
506517
break;
507518
}
508-
if (currLL - prevLL < llThreshold) {
519+
if (stopCriteria.test(currLL - prevLL)) {
520+
progressMonitor.finishedTraining(iterations, stopCriteria);
509521
break;
510522
}
511523
}
512524
prevLL = currLL;
513525
}
514526

527+
//At this point, all iterations have finished successfully.
528+
if (!progressMonitor.isTrainingFinished()) {
529+
progressMonitor.finishedTraining(iterations, null);
530+
}
531+
progressMonitor.displayAndClear();
532+
515533
// kill a bunch of these big objects now that we don't need them
516534
observedExpects = null;
517535
modelExpects = null;
@@ -628,8 +646,8 @@ private double nextIteration(double correctionConstant,
628646
}
629647
}
630648

631-
logger.info("{} - loglikelihood={}\t{}",
632-
iteration, loglikelihood, ((double) numCorrect / numEvents));
649+
getTrainingProgressMonitor(trainingConfiguration).
650+
finishedIteration(iteration, numCorrect, numEvents, TrainingMeasure.LOG_LIKELIHOOD, loglikelihood);
633651

634652
return loglikelihood;
635653
}
@@ -709,4 +727,25 @@ synchronized double getLoglikelihood() {
709727
return loglikelihood;
710728
}
711729
}
730+
731+
/**
732+
* Get the {@link StopCriteria} associated with this Trainer.
733+
* @param trainingConfig - If {@link TrainingConfiguration} is null or
734+
* {@link TrainingConfiguration#stopCriteria()} is null then return a default {@link StopCriteria}.
735+
*/
736+
private StopCriteria getStopCriteria(TrainingConfiguration trainingConfig) {
737+
return trainingConfig != null && trainingConfig.stopCriteria() != null
738+
? trainingConfig.stopCriteria() : new LogLikelihoodThresholdBreached(trainingParameters);
739+
}
740+
741+
/**
742+
* Get the {@link TrainingProgressMonitor} associated with this Trainer.
743+
* @param trainingConfig If {@link TrainingConfiguration} is null or
744+
* {@link TrainingConfiguration#progMon()} is null then return a default {@link TrainingProgressMonitor}.
745+
*/
746+
private TrainingProgressMonitor getTrainingProgressMonitor(TrainingConfiguration trainingConfig) {
747+
return trainingConfig != null && trainingConfig.progMon() != null ?
748+
trainingConfig.progMon() : new DefaultTrainingProgressMonitor();
749+
}
750+
712751
}

0 commit comments

Comments
 (0)