OPENNLP-124: Maxent/Perceptron training should report progress back via an API #758

Status: Open. Wants to merge 4 commits into base: main.

Changes from 2 commits
7 changes: 7 additions & 0 deletions opennlp-tools/pom.xml
@@ -72,6 +72,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj-core.version}</version>
<scope>test</scope>
</dependency>

</dependencies>

<build>
11 changes: 11 additions & 0 deletions opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java
@@ -19,6 +19,7 @@

import java.util.Map;

import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;

/**
@@ -35,4 +36,14 @@ public interface Trainer {
*/
void init(TrainingParameters trainParams, Map<String, String> reportMap);

/**
* Conducts the initialization of a {@link Trainer} via
* {@link TrainingParameters}, a {@link Map report map} and a {@link TrainingConfiguration}.
*
* @param trainParams The {@link TrainingParameters} to use.
* @param reportMap The {@link Map} instance used as report map.
* @param config The {@link TrainingConfiguration} to use.
*/
void init(TrainingParameters trainParams, Map<String, String> reportMap, TrainingConfiguration config);

}
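
For orientation, a caller wiring progress reporting into a trainer through the new overload might look roughly like the sketch below. This is not part of the diff: it assumes the TrainingConfiguration(TrainingProgressMonitor, StopCriteria) constructor used later in this PR, TrainingParameters.defaultParams() from OpenNLP, and a concrete Trainer such as GISTrainer with an accessible no-argument constructor; the class name is hypothetical.

import java.util.HashMap;
import java.util.Map;

import opennlp.tools.commons.Trainer;
import opennlp.tools.ml.maxent.GISTrainer;
import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;

public class TrainerInitSketch {

  public static void main(String[] args) {
    TrainingParameters params = TrainingParameters.defaultParams();
    Map<String, String> reportMap = new HashMap<>();

    // Progress is reported through the default monitor; the null StopCriteria
    // leaves the stopping rule to the trainer's own default.
    TrainingConfiguration config =
        new TrainingConfiguration(new DefaultTrainingProgressMonitor(), null);

    Trainer trainer = new GISTrainer();
    trainer.init(params, reportMap, config);
    // Training then proceeds through the concrete trainer's usual train(...) entry point.
  }
}
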
24 changes: 24 additions & 0 deletions opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
@@ -22,12 +22,14 @@

import opennlp.tools.commons.Trainer;
import opennlp.tools.ml.maxent.GISTrainer;
import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;

public abstract class AbstractTrainer implements Trainer {

protected TrainingParameters trainingParameters;
protected Map<String,String> reportMap;
protected TrainingConfiguration trainingConfiguration;

public AbstractTrainer() {
}
@@ -55,6 +57,20 @@ public void init(TrainingParameters trainParams, Map<String,String> reportMap) {
this.reportMap = reportMap;
}

/**
* Initializes an {@link AbstractTrainer} using the following parameters.
*
* @param trainParams The {@link TrainingParameters} to use.
* @param reportMap The {@link Map} instance used as report map.
* @param config The {@link TrainingConfiguration} to use.
*/
@Override
public void init(TrainingParameters trainParams, Map<String, String> reportMap,
TrainingConfiguration config) {
init(trainParams, reportMap);
this.trainingConfiguration = config;
}

/**
* @return Retrieves the configured {@link TrainingParameters#ALGORITHM_PARAM} value.
*/
@@ -108,4 +124,12 @@ protected void addToReport(String key, String value) {
reportMap.put(key, value);
}

/**
* Retrieves the {@link TrainingConfiguration} associated with an {@link AbstractTrainer}.
* @return The {@link TrainingConfiguration}, or {@code null} if none was set.
*/
public TrainingConfiguration getTrainingConfiguration() {
return trainingConfiguration;
}

}
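
The stored configuration can also be read back through the accessor above. A brief sketch under the same assumptions as the previous one (GISTrainer as the concrete subclass, TrainingParameters.defaultParams() for the parameters; the class name is hypothetical):

import java.util.HashMap;

import opennlp.tools.ml.AbstractTrainer;
import opennlp.tools.ml.maxent.GISTrainer;
import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;

public class TrainingConfigurationAccessSketch {

  public static void main(String[] args) {
    AbstractTrainer trainer = new GISTrainer();
    trainer.init(TrainingParameters.defaultParams(), new HashMap<>(),
        new TrainingConfiguration(new DefaultTrainingProgressMonitor(), null));

    // Subclasses and callers can inspect what was passed in; stopCriteria() is
    // null here, so the trainer falls back to its own default criteria.
    TrainingConfiguration stored = trainer.getTrainingConfiguration();
    System.out.println("Monitor in use: " + stored.progMon().getClass().getSimpleName());
  }
}
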
25 changes: 23 additions & 2 deletions opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
@@ -26,6 +26,8 @@
import opennlp.tools.ml.naivebayes.NaiveBayesTrainer;
import opennlp.tools.ml.perceptron.PerceptronTrainer;
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.ext.ExtensionLoader;
import opennlp.tools.util.ext.ExtensionNotLoadedException;
@@ -180,6 +182,22 @@ public static <T> EventModelSequenceTrainer<T> getEventModelSequenceTrainer(
}
}

/**
* Works like {@link TrainerFactory#getEventTrainer(TrainingParameters, Map, TrainingConfiguration)},
* except that the {@link TrainingConfiguration} is initialized with a {@link DefaultTrainingProgressMonitor}
* and a {@code null} {@link opennlp.tools.monitoring.StopCriteria}.
* In that case, the actual {@link opennlp.tools.monitoring.StopCriteria}
* is decided by the {@link EventTrainer} implementation.
*/
public static EventTrainer getEventTrainer(
TrainingParameters trainParams, Map<String, String> reportMap) {

TrainingConfiguration trainingConfiguration
= new TrainingConfiguration(new DefaultTrainingProgressMonitor(), null);
return getEventTrainer(trainParams, reportMap, trainingConfiguration);
}

/**
* Retrieves an {@link EventTrainer} that fits the given parameters.
*
@@ -189,11 +207,14 @@ public static <T> EventModelSequenceTrainer<T> getEventModelSequenceTrainer(
* {@link GISTrainer#MAXENT_VALUE} will be used.
* @param reportMap A {@link Map} that shall be used during initialization of
* the {@link EventTrainer}.
* @param config The {@link TrainingConfiguration} to be used. This determines the type of
* {@link opennlp.tools.monitoring.TrainingProgressMonitor}
* and the {@link opennlp.tools.monitoring.StopCriteria} to be used.
*
* @return A valid {@link EventTrainer} for the configured {@code trainParams}.
*/
public static EventTrainer getEventTrainer(
TrainingParameters trainParams, Map<String, String> reportMap) {
TrainingParameters trainParams, Map<String, String> reportMap, TrainingConfiguration config) {

// if the trainerType is not defined -- use the GISTrainer.
String trainerType = trainParams.getStringParameter(
@@ -205,7 +226,7 @@ public static EventTrainer getEventTrainer(
} else {
trainer = ExtensionLoader.instantiateExtension(EventTrainer.class, trainerType);
}
trainer.init(trainParams, reportMap);
trainer.init(trainParams, reportMap, config);
return trainer;
}

opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
@@ -40,7 +40,13 @@
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Prior;
import opennlp.tools.ml.model.UniformPrior;
import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
import opennlp.tools.monitoring.LogLikelihoodThresholdBreached;
import opennlp.tools.monitoring.StopCriteria;
import opennlp.tools.monitoring.TrainingMeasure;
import opennlp.tools.monitoring.TrainingProgressMonitor;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;

/**
@@ -497,6 +503,11 @@ private void findParameters(int iterations, double correctionConstant) {
new ExecutorCompletionService<>(executor);
double prevLL = 0.0;
double currLL;

//Get the Training Progress Monitor and the StopCriteria.
TrainingProgressMonitor progressMonitor = getTrainingProgressMonitor(trainingConfiguration);
StopCriteria stopCriteria = getStopCriteria(trainingConfiguration);

logger.info("Performing {} iterations.", iterations);
for (int i = 1; i <= iterations; i++) {
currLL = nextIteration(correctionConstant, completionService, i);
@@ -505,13 +516,20 @@ private void findParameters(int iterations, double correctionConstant) {
logger.warn("Model Diverging: loglikelihood decreased");
break;
}
if (currLL - prevLL < llThreshold) {
if (stopCriteria.test(currLL - prevLL)) {
progressMonitor.finishedTraining(iterations, stopCriteria);
break;
}
}
prevLL = currLL;
}

//At this point, all iterations have finished successfully.
if (!progressMonitor.isTrainingFinished()) {
progressMonitor.finishedTraining(iterations, null);
}
progressMonitor.displayAndClear();

// kill a bunch of these big objects now that we don't need them
observedExpects = null;
modelExpects = null;
@@ -628,8 +646,8 @@ private double nextIteration(double correctionConstant,
}
}

logger.info("{} - loglikelihood={}\t{}",
iteration, loglikelihood, ((double) numCorrect / numEvents));
getTrainingProgressMonitor(trainingConfiguration).
finishedIteration(iteration, numCorrect, numEvents, TrainingMeasure.LOG_LIKELIHOOD, loglikelihood);

return loglikelihood;
}
@@ -709,4 +727,25 @@ synchronized double getLoglikelihood() {
return loglikelihood;
}
}

/**
* Gets the {@link StopCriteria} associated with this trainer.
* @param trainingConfig The {@link TrainingConfiguration} to read from. If it is {@code null} or
* {@link TrainingConfiguration#stopCriteria()} is {@code null}, a default {@link StopCriteria} is returned.
*/
private StopCriteria getStopCriteria(TrainingConfiguration trainingConfig) {
return trainingConfig != null && trainingConfig.stopCriteria() != null
? trainingConfig.stopCriteria() : new LogLikelihoodThresholdBreached(trainingParameters);
}

/**
* Gets the {@link TrainingProgressMonitor} associated with this trainer.
* @param trainingConfig The {@link TrainingConfiguration} to read from. If it is {@code null} or
* {@link TrainingConfiguration#progMon()} is {@code null}, a default {@link TrainingProgressMonitor} is returned.
*/
private TrainingProgressMonitor getTrainingProgressMonitor(TrainingConfiguration trainingConfig) {
return trainingConfig != null && trainingConfig.progMon() != null ?
trainingConfig.progMon() : new DefaultTrainingProgressMonitor();
}

}
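
The two private helpers above give GISTrainer its fallback behavior: with no configuration (or null members), the monitor defaults to DefaultTrainingProgressMonitor and the stop rule to LogLikelihoodThresholdBreached. The sketch below replays that fallback logic and the monitor lifecycle outside the trainer, purely for illustration; the class name, iteration counts, and log-likelihood value are made up:

import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
import opennlp.tools.monitoring.LogLikelihoodThresholdBreached;
import opennlp.tools.monitoring.StopCriteria;
import opennlp.tools.monitoring.TrainingMeasure;
import opennlp.tools.monitoring.TrainingProgressMonitor;
import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;

public class GisMonitoringSketch {

  public static void main(String[] args) {
    TrainingParameters params = TrainingParameters.defaultParams();
    TrainingConfiguration config = null; // e.g. the trainer was initialized via the two-argument init(...)

    // Same fallback rules as GISTrainer's getStopCriteria and getTrainingProgressMonitor.
    StopCriteria stop = (config != null && config.stopCriteria() != null)
        ? config.stopCriteria() : new LogLikelihoodThresholdBreached(params);
    TrainingProgressMonitor monitor = (config != null && config.progMon() != null)
        ? config.progMon() : new DefaultTrainingProgressMonitor();

    // Per-iteration reporting, as nextIteration(...) now does.
    monitor.finishedIteration(1, 9000, 10000, TrainingMeasure.LOG_LIKELIHOOD, -2345.6);

    // On an early stop (criteria met) or after the last iteration, completion is
    // recorded once and the accumulated progress is flushed.
    monitor.finishedTraining(100, stop);
    monitor.displayAndClear();
  }
}
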
opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java
@@ -28,6 +28,12 @@
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
import opennlp.tools.monitoring.IterDeltaAccuracyUnderTolerance;
import opennlp.tools.monitoring.StopCriteria;
import opennlp.tools.monitoring.TrainingMeasure;
import opennlp.tools.monitoring.TrainingProgressMonitor;
import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;

/**
@@ -293,6 +299,10 @@ private MutableContext[] findParameters(int iterations, boolean useAverage) {
}
}

//Get the Training Progress Monitor and the StopCriteria.
TrainingProgressMonitor progressMonitor = getTrainingProgressMonitor(trainingConfiguration);
StopCriteria stopCriteria = getStopCriteria(trainingConfiguration);

// Keep track of the previous three accuracies. The difference of
// the mean of these and the current training set accuracy is used
// with tolerance to decide whether to stop.
@@ -349,10 +359,12 @@ private MutableContext[] findParameters(int iterations, boolean useAverage) {
}
}

// Calculate the training accuracy and display.
// Calculate the training accuracy.
double trainingAccuracy = (double) numCorrect / numEvents;
if (i < 10 || (i % 10) == 0)
logger.info("{}: ({}/{}) {}", i, numCorrect, numEvents, trainingAccuracy);
if (i < 10 || (i % 10) == 0) {
progressMonitor.finishedIteration(i, numCorrect, numEvents,
TrainingMeasure.ACCURACY, trainingAccuracy);
}

// TODO: Make averaging configurable !!!

@@ -370,10 +382,10 @@ private MutableContext[] findParameters(int iterations, boolean useAverage) {
// If the tolerance is greater than the difference between the
// current training accuracy and all of the previous three
// training accuracies, stop training.
if (StrictMath.abs(prevAccuracy1 - trainingAccuracy) < tolerance
&& StrictMath.abs(prevAccuracy2 - trainingAccuracy) < tolerance
&& StrictMath.abs(prevAccuracy3 - trainingAccuracy) < tolerance) {
logger.warn("Stopping: change in training set accuracy less than {}", tolerance);
if (stopCriteria.test(prevAccuracy1 - trainingAccuracy)
&& stopCriteria.test(prevAccuracy2 - trainingAccuracy)
&& stopCriteria.test(prevAccuracy3 - trainingAccuracy)) {
progressMonitor.finishedTraining(iterations, stopCriteria);
break;
}

@@ -383,6 +395,12 @@ private MutableContext[] findParameters(int iterations, boolean useAverage) {
prevAccuracy3 = trainingAccuracy;
}

//At this point, all iterations have finished successfully.
if (!progressMonitor.isTrainingFinished()) {
progressMonitor.finishedTraining(iterations, null);
}
progressMonitor.displayAndClear();

// Output the final training stats.
trainingStats(evalParams);

@@ -432,4 +450,25 @@ private static boolean isPerfectSquare(int n) {
return root * root == n;
}

/**
* Gets the {@link StopCriteria} associated with this trainer.
* @param trainingConfig The {@link TrainingConfiguration} to read from. If it is {@code null} or
* {@link TrainingConfiguration#stopCriteria()} is {@code null}, a default {@link StopCriteria} is returned.
*/
private StopCriteria getStopCriteria(TrainingConfiguration trainingConfig) {
return trainingConfig != null && trainingConfig.stopCriteria() != null
? trainingConfig.stopCriteria() : new IterDeltaAccuracyUnderTolerance(trainingParameters);
}

/**
* Gets the {@link TrainingProgressMonitor} associated with this trainer.
* @param trainingConfig The {@link TrainingConfiguration} to read from. If it is {@code null} or
* {@link TrainingConfiguration#progMon()} is {@code null}, a default {@link TrainingProgressMonitor} is returned.
*/
private TrainingProgressMonitor getTrainingProgressMonitor(TrainingConfiguration trainingConfig) {
return trainingConfig != null && trainingConfig.progMon() != null ? trainingConfig.progMon() :
new DefaultTrainingProgressMonitor();
}

}
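
For the perceptron path, the default criteria is IterDeltaAccuracyUnderTolerance, which replaces the inline StrictMath.abs(...) < tolerance checks. Below is a minimal sketch of using it directly, under the assumption that its test(...) receives the raw accuracy delta exactly as findParameters(...) passes it; the class name and the accuracy values are made up:

import opennlp.tools.monitoring.IterDeltaAccuracyUnderTolerance;
import opennlp.tools.monitoring.StopCriteria;
import opennlp.tools.util.TrainingParameters;

public class PerceptronStopCriteriaSketch {

  public static void main(String[] args) {
    // The tolerance is taken from the training parameters, as in PerceptronTrainer.
    TrainingParameters params = TrainingParameters.defaultParams();
    StopCriteria underTolerance = new IterDeltaAccuracyUnderTolerance(params);

    double previousAccuracy = 0.9721;
    double currentAccuracy = 0.9722;

    // Mirrors the convergence check in findParameters(...): once the delta against
    // each of the last three accuracies satisfies the criteria, training stops.
    if (underTolerance.test(previousAccuracy - currentAccuracy)) {
      System.out.println("Accuracy change below tolerance; training would stop here.");
    }
  }
}
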