diff --git a/opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java b/opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java index efd8ee76e..ce8ed0224 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java @@ -19,6 +19,7 @@ import java.util.Map; +import opennlp.tools.util.TrainingConfiguration; import opennlp.tools.util.TrainingParameters; /** @@ -35,4 +36,14 @@ public interface Trainer { */ void init(TrainingParameters trainParams, Map reportMap); + /** + * Conducts the initialization of a {@link Trainer} via + * {@link TrainingParameters}, {@link Map report map} and {@link TrainingConfiguration} + * + * @param trainParams The {@link TrainingParameters} to use. + * @param reportMap The {@link Map} instance used as report map. + * @param config The {@link TrainingConfiguration} to use. If null, suitable defaults will be used. + */ + void init(TrainingParameters trainParams, Map reportMap, TrainingConfiguration config); + } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java index 54e315c84..2401e35f4 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java @@ -22,12 +22,14 @@ import opennlp.tools.commons.Trainer; import opennlp.tools.ml.maxent.GISTrainer; +import opennlp.tools.util.TrainingConfiguration; import opennlp.tools.util.TrainingParameters; public abstract class AbstractTrainer implements Trainer { protected TrainingParameters trainingParameters; protected Map reportMap; + protected TrainingConfiguration trainingConfiguration; public AbstractTrainer() { } @@ -55,6 +57,16 @@ public void init(TrainingParameters trainParams, Map reportMap) { this.reportMap = reportMap; } + /** + * {@inheritDoc} + */ + @Override + public void init(TrainingParameters trainParams, Map 
reportMap, + TrainingConfiguration config) { + init(trainParams, reportMap); + this.trainingConfiguration = config; + } + /** * @return Retrieves the configured {@link TrainingParameters#ALGORITHM_PARAM} value. */ @@ -108,4 +120,12 @@ protected void addToReport(String key, String value) { reportMap.put(key, value); } + /** + * Retrieves the {@link TrainingConfiguration} associated with an {@link AbstractTrainer}. + * @return {@link TrainingConfiguration} + */ + public TrainingConfiguration getTrainingConfiguration() { + return trainingConfiguration; + } + } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java index b47e3a757..6aaf73b60 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java @@ -26,6 +26,8 @@ import opennlp.tools.ml.naivebayes.NaiveBayesTrainer; import opennlp.tools.ml.perceptron.PerceptronTrainer; import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer; +import opennlp.tools.monitoring.DefaultTrainingProgressMonitor; +import opennlp.tools.util.TrainingConfiguration; import opennlp.tools.util.TrainingParameters; import opennlp.tools.util.ext.ExtensionLoader; import opennlp.tools.util.ext.ExtensionNotLoadedException; @@ -62,12 +64,11 @@ public enum TrainerType { * {@link TrainingParameters#ALGORITHM_PARAM} value. * * @param trainParams - A mapping of {@link TrainingParameters training parameters}. - * * @return The {@link TrainerType} or {@code null} if the type couldn't be determined. 
*/ public static TrainerType getTrainerType(TrainingParameters trainParams) { - String algorithmValue = trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM,null); + String algorithmValue = trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM, null); // Check if it is defaulting to the MAXENT trainer if (algorithmValue == null) { @@ -80,11 +81,9 @@ public static TrainerType getTrainerType(TrainingParameters trainParams) { if (EventTrainer.class.isAssignableFrom(trainerClass)) { return TrainerType.EVENT_MODEL_TRAINER; - } - else if (EventModelSequenceTrainer.class.isAssignableFrom(trainerClass)) { + } else if (EventModelSequenceTrainer.class.isAssignableFrom(trainerClass)) { return TrainerType.EVENT_MODEL_SEQUENCE_TRAINER; - } - else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) { + } else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) { return TrainerType.SEQUENCE_TRAINER; } } @@ -94,24 +93,21 @@ else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) { try { ExtensionLoader.instantiateExtension(EventTrainer.class, algorithmValue); return TrainerType.EVENT_MODEL_TRAINER; - } - catch (ExtensionNotLoadedException ignored) { + } catch (ExtensionNotLoadedException ignored) { // this is ignored } try { ExtensionLoader.instantiateExtension(EventModelSequenceTrainer.class, algorithmValue); return TrainerType.EVENT_MODEL_SEQUENCE_TRAINER; - } - catch (ExtensionNotLoadedException ignored) { + } catch (ExtensionNotLoadedException ignored) { // this is ignored } try { ExtensionLoader.instantiateExtension(SequenceTrainer.class, algorithmValue); return TrainerType.SEQUENCE_TRAINER; - } - catch (ExtensionNotLoadedException ignored) { + } catch (ExtensionNotLoadedException ignored) { // this is ignored } @@ -124,15 +120,14 @@ else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) { * @param trainParams The {@link TrainingParameters} to check for the trainer type. 
* Note: The entry {@link TrainingParameters#ALGORITHM_PARAM} is used * to determine the type. - * @param reportMap A {@link Map} that shall be used during initialization of - * the {@link SequenceTrainer}. - * + * @param reportMap A {@link Map} that shall be used during initialization of + * the {@link SequenceTrainer}. * @return A valid {@link SequenceTrainer} for the configured {@code trainParams}. * @throws IllegalArgumentException Thrown if the trainer type could not be determined. */ public static SequenceTrainer getSequenceModelTrainer( - TrainingParameters trainParams, Map reportMap) { - String trainerType = trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM,null); + TrainingParameters trainParams, Map reportMap) { + String trainerType = trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM, null); if (trainerType != null) { final SequenceTrainer trainer; @@ -143,8 +138,7 @@ public static SequenceTrainer getSequenceModelTrainer( } trainer.init(trainParams, reportMap); return trainer; - } - else { + } else { throw new IllegalArgumentException("Trainer type couldn't be determined!"); } } @@ -155,15 +149,14 @@ public static SequenceTrainer getSequenceModelTrainer( * @param trainParams The {@link TrainingParameters} to check for the trainer type. * Note: The entry {@link TrainingParameters#ALGORITHM_PARAM} is used * to determine the type. - * @param reportMap A {@link Map} that shall be used during initialization of - * the {@link EventModelSequenceTrainer}. - * + * @param reportMap A {@link Map} that shall be used during initialization of + * the {@link EventModelSequenceTrainer}. * @return A valid {@link EventModelSequenceTrainer} for the configured {@code trainParams}. * @throws IllegalArgumentException Thrown if the trainer type could not be determined. 
*/ public static EventModelSequenceTrainer getEventModelSequenceTrainer( - TrainingParameters trainParams, Map reportMap) { - String trainerType = trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM,null); + TrainingParameters trainParams, Map reportMap) { + String trainerType = trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM, null); if (trainerType != null) { final EventModelSequenceTrainer trainer; @@ -174,12 +167,23 @@ public static EventModelSequenceTrainer getEventModelSequenceTrainer( } trainer.init(trainParams, reportMap); return trainer; - } - else { + } else { throw new IllegalArgumentException("Trainer type couldn't be determined!"); } } + /** + * Works just like {@link TrainerFactory#getEventTrainer(TrainingParameters, Map, TrainingConfiguration)} + * except that {@link TrainingConfiguration} is initialized with default values. + */ + public static EventTrainer getEventTrainer( + TrainingParameters trainParams, Map reportMap) { + + TrainingConfiguration trainingConfiguration + = new TrainingConfiguration(new DefaultTrainingProgressMonitor(), null); + return getEventTrainer(trainParams, reportMap, trainingConfiguration); + } + /** * Retrieves an {@link EventTrainer} that fits the given parameters. * @@ -187,13 +191,13 @@ public static EventModelSequenceTrainer getEventModelSequenceTrainer( * Note: The entry {@link TrainingParameters#ALGORITHM_PARAM} is used * to determine the type. If the type is not defined, the * {@link GISTrainer#MAXENT_VALUE} will be used. - * @param reportMap A {@link Map} that shall be used during initialization of - * the {@link EventTrainer}. - * + * @param reportMap A {@link Map} that shall be used during initialization of + * the {@link EventTrainer}. + * @param config The {@link TrainingConfiguration} to be used. * @return A valid {@link EventTrainer} for the configured {@code trainParams}. 
*/ public static EventTrainer getEventTrainer( - TrainingParameters trainParams, Map reportMap) { + TrainingParameters trainParams, Map reportMap, TrainingConfiguration config) { // if the trainerType is not defined -- use the GISTrainer. String trainerType = trainParams.getStringParameter( @@ -205,7 +209,7 @@ public static EventTrainer getEventTrainer( } else { trainer = ExtensionLoader.instantiateExtension(EventTrainer.class, trainerType); } - trainer.init(trainParams, reportMap); + trainer.init(trainParams, reportMap, config); return trainer; } @@ -232,8 +236,7 @@ public static boolean isValid(TrainingParameters trainParams) { TrainingParameters.CUTOFF_DEFAULT_VALUE); trainParams.getIntParameter(TrainingParameters.ITERATIONS_PARAM, TrainingParameters.ITERATIONS_DEFAULT_VALUE); - } - catch (NumberFormatException e) { + } catch (NumberFormatException e) { return false; } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java index d2eabeb91..ae71d942f 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java @@ -30,6 +30,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import opennlp.tools.commons.Trainer; import opennlp.tools.ml.AbstractEventTrainer; import opennlp.tools.ml.ArrayMath; import opennlp.tools.ml.model.DataIndexer; @@ -40,7 +41,13 @@ import opennlp.tools.ml.model.OnePassDataIndexer; import opennlp.tools.ml.model.Prior; import opennlp.tools.ml.model.UniformPrior; +import opennlp.tools.monitoring.DefaultTrainingProgressMonitor; +import opennlp.tools.monitoring.LogLikelihoodThresholdBreached; +import opennlp.tools.monitoring.StopCriteria; +import opennlp.tools.monitoring.TrainingMeasure; +import opennlp.tools.monitoring.TrainingProgressMonitor; import opennlp.tools.util.ObjectStream; +import opennlp.tools.util.TrainingConfiguration; import 
opennlp.tools.util.TrainingParameters; /** @@ -497,6 +504,11 @@ private void findParameters(int iterations, double correctionConstant) { new ExecutorCompletionService<>(executor); double prevLL = 0.0; double currLL; + + //Get the Training Progress Monitor and the StopCriteria. + TrainingProgressMonitor progressMonitor = getTrainingProgressMonitor(trainingConfiguration); + StopCriteria stopCriteria = getStopCriteria(trainingConfiguration); + logger.info("Performing {} iterations.", iterations); for (int i = 1; i <= iterations; i++) { currLL = nextIteration(correctionConstant, completionService, i); @@ -505,13 +517,20 @@ private void findParameters(int iterations, double correctionConstant) { logger.warn("Model Diverging: loglikelihood decreased"); break; } - if (currLL - prevLL < llThreshold) { + if (stopCriteria.test(currLL - prevLL)) { + progressMonitor.finishedTraining(iterations, stopCriteria); break; } } prevLL = currLL; } + //At this point, all iterations have finished successfully. + if (!progressMonitor.isTrainingFinished()) { + progressMonitor.finishedTraining(iterations, null); + } + progressMonitor.display(true); + // kill a bunch of these big objects now that we don't need them observedExpects = null; modelExpects = null; @@ -628,8 +647,8 @@ private double nextIteration(double correctionConstant, } } - logger.info("{} - loglikelihood={}\t{}", - iteration, loglikelihood, ((double) numCorrect / numEvents)); + getTrainingProgressMonitor(trainingConfiguration). + finishedIteration(iteration, numCorrect, numEvents, TrainingMeasure.LOG_LIKELIHOOD, loglikelihood); return loglikelihood; } @@ -709,4 +728,31 @@ synchronized double getLoglikelihood() { return loglikelihood; } } + + /** + * Get the {@link StopCriteria} associated with this {@link Trainer}. + * + * @param trainingConfig {@link TrainingConfiguration} + * @return {@link StopCriteria}. 
If {@link TrainingConfiguration} is {@code null} or + * {@link TrainingConfiguration#stopCriteria()} is {@code null}, + * then return the default {@link StopCriteria}. + */ + private StopCriteria getStopCriteria(TrainingConfiguration trainingConfig) { + return trainingConfig != null && trainingConfig.stopCriteria() != null + ? trainingConfig.stopCriteria() : new LogLikelihoodThresholdBreached(trainingParameters); + } + + /** + * Get the {@link TrainingProgressMonitor} associated with this {@link Trainer}. + * + * @param trainingConfig {@link TrainingConfiguration}. + * @return {@link TrainingProgressMonitor}. If {@link TrainingConfiguration} is {@code null} or + * {@link TrainingConfiguration#progMon()} is {@code null}, + * then return the default {@link TrainingProgressMonitor}. + */ + private TrainingProgressMonitor getTrainingProgressMonitor(TrainingConfiguration trainingConfig) { + return trainingConfig != null && trainingConfig.progMon() != null ? + trainingConfig.progMon() : new DefaultTrainingProgressMonitor(); + } + } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java index d4d51b6c2..9958e0359 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java @@ -22,12 +22,19 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import opennlp.tools.commons.Trainer; import opennlp.tools.ml.AbstractEventTrainer; import opennlp.tools.ml.ArrayMath; import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.DataIndexer; import opennlp.tools.ml.model.EvalParameters; import opennlp.tools.ml.model.MutableContext; +import opennlp.tools.monitoring.DefaultTrainingProgressMonitor; +import opennlp.tools.monitoring.IterDeltaAccuracyUnderTolerance; +import opennlp.tools.monitoring.StopCriteria; +import 
opennlp.tools.monitoring.TrainingMeasure; +import opennlp.tools.monitoring.TrainingProgressMonitor; +import opennlp.tools.util.TrainingConfiguration; import opennlp.tools.util.TrainingParameters; /** @@ -293,6 +300,10 @@ private MutableContext[] findParameters(int iterations, boolean useAverage) { } } + //Get the Training Progress Monitor and the StopCriteria. + TrainingProgressMonitor progressMonitor = getTrainingProgressMonitor(trainingConfiguration); + StopCriteria stopCriteria = getStopCriteria(trainingConfiguration); + // Keep track of the previous three accuracies. The difference of // the mean of these and the current training set accuracy is used // with tolerance to decide whether to stop. @@ -349,10 +360,12 @@ private MutableContext[] findParameters(int iterations, boolean useAverage) { } } - // Calculate the training accuracy and display. + // Calculate the training accuracy. double trainingAccuracy = (double) numCorrect / numEvents; - if (i < 10 || (i % 10) == 0) - logger.info("{}: ({}/{}) {}", i, numCorrect, numEvents, trainingAccuracy); + if (i < 10 || (i % 10) == 0) { + progressMonitor.finishedIteration(i, numCorrect, numEvents, + TrainingMeasure.ACCURACY, trainingAccuracy); + } // TODO: Make averaging configurable !!! @@ -370,10 +383,10 @@ private MutableContext[] findParameters(int iterations, boolean useAverage) { // If the tolerance is greater than the difference between the // current training accuracy and all of the previous three // training accuracies, stop training. 
- if (StrictMath.abs(prevAccuracy1 - trainingAccuracy) < tolerance - && StrictMath.abs(prevAccuracy2 - trainingAccuracy) < tolerance - && StrictMath.abs(prevAccuracy3 - trainingAccuracy) < tolerance) { - logger.warn("Stopping: change in training set accuracy less than {}", tolerance); + if (stopCriteria.test(prevAccuracy1 - trainingAccuracy) + && stopCriteria.test(prevAccuracy2 - trainingAccuracy) + && stopCriteria.test(prevAccuracy3 - trainingAccuracy)) { + progressMonitor.finishedTraining(iterations, stopCriteria); break; } @@ -383,6 +396,12 @@ private MutableContext[] findParameters(int iterations, boolean useAverage) { prevAccuracy3 = trainingAccuracy; } + //At this point, all iterations have finished successfully. + if (!progressMonitor.isTrainingFinished()) { + progressMonitor.finishedTraining(iterations, null); + } + progressMonitor.display(true); + // Output the final training stats. trainingStats(evalParams); @@ -432,4 +451,30 @@ private static boolean isPerfectSquare(int n) { return root * root == n; } + /** + * Get the {@link StopCriteria} associated with this {@link Trainer}. + * + * @param trainingConfig {@link TrainingConfiguration} + * @return {@link StopCriteria}. If {@link TrainingConfiguration} is {@code null} or + * {@link TrainingConfiguration#stopCriteria()} is {@code null}, + * then return the default {@link StopCriteria}. + */ + private StopCriteria getStopCriteria(TrainingConfiguration trainingConfig) { + return trainingConfig != null && trainingConfig.stopCriteria() != null + ? trainingConfig.stopCriteria() : new IterDeltaAccuracyUnderTolerance(trainingParameters); + } + + /** + * Get the {@link TrainingProgressMonitor} associated with this {@link Trainer}. + * + * @param trainingConfig {@link TrainingConfiguration}. + * @return {@link TrainingProgressMonitor}. If {@link TrainingConfiguration} is {@code null} or + * {@link TrainingConfiguration#progMon()} is {@code null}, + * then return the default {@link TrainingProgressMonitor}. 
+ */ + private TrainingProgressMonitor getTrainingProgressMonitor(TrainingConfiguration trainingConfig) { + return trainingConfig != null && trainingConfig.progMon() != null ? trainingConfig.progMon() : + new DefaultTrainingProgressMonitor(); + } + } diff --git a/opennlp-tools/src/main/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitor.java b/opennlp-tools/src/main/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitor.java new file mode 100644 index 000000000..2d1082519 --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitor.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.monitoring; + + +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static opennlp.tools.monitoring.StopCriteria.FINISHED; + +/** + * The default implementation of {@link TrainingProgressMonitor}. + * This publishes model training progress to the chosen logging destination. 
+ */ +public class DefaultTrainingProgressMonitor implements TrainingProgressMonitor { + + private static final Logger logger = LoggerFactory.getLogger(DefaultTrainingProgressMonitor.class); + + /** + * Keeps a track of whether training was already finished. + */ + private volatile boolean isTrainingFinished; + + /** + * An underlying list to capture training progress events. + */ + private final List progress; + + public DefaultTrainingProgressMonitor() { + this.progress = new LinkedList<>(); + } + + /** + * {@inheritDoc} + */ + @Override + public synchronized void finishedIteration(int iteration, int numberCorrectEvents, int totalEvents, + TrainingMeasure measure, double measureValue) { + progress.add(String.format("%s: (%s/%s) %s : %s", iteration, numberCorrectEvents, totalEvents, + measure.getMeasureName(), measureValue)); + } + + /** + * {@inheritDoc} + */ + @Override + public synchronized void finishedTraining(int iterations, StopCriteria stopCriteria) { + if (!Objects.isNull(stopCriteria)) { + progress.add(stopCriteria.getMessageIfSatisfied()); + } else { + progress.add(String.format(FINISHED, iterations)); + } + isTrainingFinished = true; + } + + /** + * {@inheritDoc} + */ + @Override + public synchronized void display(boolean clear) { + progress.stream().forEach(logger::info); + if (clear) { + progress.clear(); + } + } + + /** + * {@inheritDoc} + */ + @Override + public boolean isTrainingFinished() { + return isTrainingFinished; + } +} diff --git a/opennlp-tools/src/main/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderTolerance.java b/opennlp-tools/src/main/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderTolerance.java new file mode 100644 index 000000000..958a27b36 --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderTolerance.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.monitoring; + +import opennlp.tools.ml.perceptron.PerceptronTrainer; +import opennlp.tools.util.TrainingParameters; + +/** + * A {@link StopCriteria} implementation to identify whether the absolute + * difference between the training accuracy of current and previous iteration is under the defined tolerance. + */ +public class IterDeltaAccuracyUnderTolerance implements StopCriteria { + + public static final String STOP = "Stopping: change in training set accuracy less than {%s}"; + private final TrainingParameters trainingParameters; + + public IterDeltaAccuracyUnderTolerance(TrainingParameters trainingParameters) { + this.trainingParameters = trainingParameters; + } + + @Override + public String getMessageIfSatisfied() { + return String.format(STOP, getTolerance()); + } + + @Override + public boolean test(Double deltaAccuracy) { + return StrictMath.abs(deltaAccuracy) < getTolerance(); + } + + private double getTolerance() { + return trainingParameters != null ? 
trainingParameters.getDoubleParameter("Tolerance", + PerceptronTrainer.TOLERANCE_DEFAULT) : PerceptronTrainer.TOLERANCE_DEFAULT; + } + +} diff --git a/opennlp-tools/src/main/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreached.java b/opennlp-tools/src/main/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreached.java new file mode 100644 index 000000000..f6f348965 --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreached.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.monitoring; + +import opennlp.tools.util.TrainingParameters; + +import static opennlp.tools.ml.maxent.GISTrainer.LOG_LIKELIHOOD_THRESHOLD_DEFAULT; +import static opennlp.tools.ml.maxent.GISTrainer.LOG_LIKELIHOOD_THRESHOLD_PARAM; + +/** + * A {@link StopCriteria} implementation to identify whether the + * difference between the log likelihood of current and previous iteration is under the defined threshold. 
+ */ +public class LogLikelihoodThresholdBreached implements StopCriteria { + + public static String STOP = "Stopping: Difference between log likelihood of current" + + " and previous iteration is less than threshold %s ."; + + private final TrainingParameters trainingParameters; + + public LogLikelihoodThresholdBreached(TrainingParameters trainingParameters) { + this.trainingParameters = trainingParameters; + } + + @Override + public String getMessageIfSatisfied() { + return String.format(STOP, getThreshold()); + + } + + @Override + public boolean test(Double currVsPrevLLDiff) { + return currVsPrevLLDiff < getThreshold(); + } + + private double getThreshold() { + return trainingParameters != null ? trainingParameters.getDoubleParameter(LOG_LIKELIHOOD_THRESHOLD_PARAM, + LOG_LIKELIHOOD_THRESHOLD_DEFAULT) : LOG_LIKELIHOOD_THRESHOLD_DEFAULT; + } + +} diff --git a/opennlp-tools/src/main/java/opennlp/tools/monitoring/StopCriteria.java b/opennlp-tools/src/main/java/opennlp/tools/monitoring/StopCriteria.java new file mode 100644 index 000000000..576aa7e59 --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/monitoring/StopCriteria.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package opennlp.tools.monitoring; + +import java.util.function.Predicate; + +import opennlp.tools.ml.model.AbstractModel; + + +/** + * Stop criteria for model training. If the predicate is met, then the training is aborted. + * + * @see Predicate + * @see AbstractModel + */ +public interface StopCriteria extends Predicate { + + String FINISHED = "Training Finished after completing %s Iterations successfully."; + + /** + * @return A detailed message captured upon hitting the {@link StopCriteria} during model training. + */ + String getMessageIfSatisfied(); + +} diff --git a/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingMeasure.java b/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingMeasure.java new file mode 100644 index 000000000..9b61ed2ed --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingMeasure.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.monitoring; + +/** + * Enumeration of Training measures. 
+ */ +public enum TrainingMeasure { + ACCURACY("Training Accuracy"), + LOG_LIKELIHOOD("Log Likelihood"); + + private String measureName; + + TrainingMeasure(String measureName) { + this.measureName = measureName; + } + + public String getMeasureName() { + return measureName; + } +} diff --git a/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingProgressMonitor.java b/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingProgressMonitor.java new file mode 100644 index 000000000..be35b78a7 --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingProgressMonitor.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.monitoring; + +import opennlp.tools.ml.model.AbstractModel; + +/** + * An interface to capture training progress of an {@link AbstractModel}. + */ + +public interface TrainingProgressMonitor { + + /** + * Captures the Iteration progress. + * + * @param iteration The completed iteration number. + * @param numberCorrectEvents Number of correctly predicted events in this iteration. + * @param totalEvents Total count of events processed in this iteration. + * @param measure {@link TrainingMeasure}. 
+ * @param measureValue measure value corresponding to the applicable {@link TrainingMeasure}. + */ + void finishedIteration(int iteration, int numberCorrectEvents, int totalEvents, + TrainingMeasure measure, double measureValue); + + /** + * Captures the training completion progress. + * + * @param iterations Total number of iterations configured for the training. + * @param stopCriteria {@link StopCriteria} for the training. + */ + void finishedTraining(int iterations, StopCriteria stopCriteria); + + /** + * Checks whether the training has finished. + * + * @return {@code true} if the training has finished, {@code false} if the training is not yet completed. + */ + boolean isTrainingFinished(); + + /** + * Displays the training progress and optionally clears the recorded progress (to save memory). + * Callers of this method can invoke it periodically + * during training, to avoid holding too much progress related data in memory. + * + * @param clear Set to true to clear the recorded progress. + */ + void display(boolean clear); +} diff --git a/opennlp-tools/src/main/java/opennlp/tools/util/TrainingConfiguration.java b/opennlp-tools/src/main/java/opennlp/tools/util/TrainingConfiguration.java new file mode 100644 index 000000000..f3e05cdc4 --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/util/TrainingConfiguration.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.util; + +import opennlp.tools.ml.model.AbstractModel; +import opennlp.tools.monitoring.StopCriteria; +import opennlp.tools.monitoring.TrainingProgressMonitor; + +/** + * Configuration used for {@link AbstractModel} training. + * @param progMon {@link TrainingProgressMonitor} used to monitor the training progress. + * @param stopCriteria {@link StopCriteria} used to abort training when the criteria is met. + */ +public record TrainingConfiguration(TrainingProgressMonitor progMon, StopCriteria stopCriteria) {} diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java b/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java index 7a7b63832..56903c1b0 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java @@ -23,6 +23,7 @@ import opennlp.tools.ml.model.Event; import opennlp.tools.ml.model.MaxentModel; import opennlp.tools.util.ObjectStream; +import opennlp.tools.util.TrainingConfiguration; import opennlp.tools.util.TrainingParameters; public class MockEventTrainer implements EventTrainer { @@ -39,4 +40,10 @@ public MaxentModel train(DataIndexer indexer) { @Override public void init(TrainingParameters trainingParams, Map<String, String> reportMap) { } + + @Override + public void init(TrainingParameters trainParams, Map<String, String> reportMap, + TrainingConfiguration config) { + } + } diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
b/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java index 0d26ffbc1..26e65b508 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java @@ -22,6 +22,7 @@ import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.Event; import opennlp.tools.ml.model.SequenceStream; +import opennlp.tools.util.TrainingConfiguration; import opennlp.tools.util.TrainingParameters; public class MockSequenceTrainer implements EventModelSequenceTrainer<Event> { @@ -34,5 +35,10 @@ public AbstractModel train(SequenceStream<Event> events) { @Override public void init(TrainingParameters trainParams, Map<String, String> reportMap) { } - + + @Override + public void init(TrainingParameters trainParams, Map<String, String> reportMap, + TrainingConfiguration config) { + } + } diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java index a8f1224a6..72388b4eb 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java @@ -24,8 +24,14 @@ import opennlp.tools.ml.TrainerFactory.TrainerType; import opennlp.tools.ml.maxent.GISTrainer; import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer; +import opennlp.tools.monitoring.DefaultTrainingProgressMonitor; +import opennlp.tools.monitoring.LogLikelihoodThresholdBreached; +import opennlp.tools.util.TrainingConfiguration; import opennlp.tools.util.TrainingParameters; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertTrue; + public class TrainerFactoryTest { private TrainingParameters mlParams; @@ -78,4 +84,18 @@ void testIsSequenceTrainerFalse() { Assertions.assertNotEquals(TrainerType.EVENT_MODEL_SEQUENCE_TRAINER, trainerType); } + @Test + void testGetEventTrainerConfiguration() { +
mlParams.put(TrainingParameters.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE); + + TrainingConfiguration config = new TrainingConfiguration(new DefaultTrainingProgressMonitor(), + new LogLikelihoodThresholdBreached(mlParams)); + + AbstractTrainer trainer = (AbstractTrainer) TrainerFactory.getEventTrainer(mlParams, null, config); + + assertAll(() -> assertTrue(trainer.getTrainingConfiguration().progMon() instanceof + DefaultTrainingProgressMonitor), + () -> assertTrue(trainer.getTrainingConfiguration().stopCriteria() instanceof + LogLikelihoodThresholdBreached)); + } } diff --git a/opennlp-tools/src/test/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitorTest.java b/opennlp-tools/src/test/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitorTest.java new file mode 100644 index 000000000..59c68a6af --- /dev/null +++ b/opennlp-tools/src/test/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitorTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package opennlp.tools.monitoring; + +import java.util.List; +import java.util.Map; + +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; + +import opennlp.tools.util.TrainingParameters; + +import static java.util.stream.Collectors.toList; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class DefaultTrainingProgressMonitorTest { + + private static final String LOGGER_NAME = "opennlp"; + private static final Logger logger = (Logger) LoggerFactory.getLogger(LOGGER_NAME); + private static final Level originalLogLevel = logger != null ? logger.getLevel() : Level.OFF; + + private TrainingProgressMonitor progressMonitor; + private final ListAppender<ILoggingEvent> appender = new ListAppender<>(); + + + @BeforeAll + static void beforeAll() { + logger.setLevel(Level.INFO); + } + + @BeforeEach + public void setup() { + progressMonitor = new DefaultTrainingProgressMonitor(); + appender.list.clear(); + logger.addAppender(appender); + appender.start(); + } + + @Test + void testFinishedIteration() { + progressMonitor.finishedIteration(1, 19830, 20801, TrainingMeasure.ACCURACY, 0.953319551944618); + progressMonitor.finishedIteration(2, 19852, 20801, TrainingMeasure.ACCURACY, 0.9543771934041633); + progressMonitor.display(true); + + //Assert that two logging events are captured for two iterations. + List<String> actual = appender.list.stream().map(ILoggingEvent::getMessage).
+ collect(toList()); + List<String> expected = List.of("1: (19830/20801) Training Accuracy : 0.953319551944618", + "2: (19852/20801) Training Accuracy : 0.9543771934041633"); + assertArrayEquals(expected.toArray(), actual.toArray()); + + } + + @Test + void testFinishedTrainingWithStopCriteria() { + StopCriteria stopCriteria = new IterDeltaAccuracyUnderTolerance(new TrainingParameters(Map.of("Tolerance", + .00002))); + progressMonitor.finishedTraining(150, stopCriteria); + progressMonitor.display(true); + + //Assert that the logs captured the training completion message with StopCriteria satisfied. + List<String> actual = appender.list.stream().map(ILoggingEvent::getMessage). + collect(toList()); + List<String> expected = List.of("Stopping: change in training set accuracy less than {2.0E-5}"); + assertArrayEquals(expected.toArray(), actual.toArray()); + } + + @Test + void testFinishedTrainingWithoutStopCriteria() { + progressMonitor.finishedTraining(150, null); + progressMonitor.display(true); + + //Assert that the logs captured the training completion message when all iterations are exhausted. + List<String> actual = appender.list.stream().map(ILoggingEvent::getMessage). + collect(toList()); + List<String> expected = List.of("Training Finished after completing 150 Iterations successfully."); + assertArrayEquals(expected.toArray(), actual.toArray()); + } + + @Test + void displayAndClear() { + progressMonitor.finishedTraining(150, null); + progressMonitor.display(true); + + //Assert that the previous invocation of display has cleared the recorded training progress. + appender.list.clear(); + progressMonitor.display(true); + assertEquals(0, appender.list.size()); + } + + @Test + void displayAndKeep() { + progressMonitor.finishedTraining(150, null); + progressMonitor.display(false); + + //Assert that the previous invocation of display has not cleared the recorded training progress.
+ progressMonitor.display(false); + assertEquals(2, appender.list.size()); + } + + @AfterAll + static void afterAll() { + logger.setLevel(originalLogLevel); + } +} diff --git a/opennlp-tools/src/test/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderToleranceTest.java b/opennlp-tools/src/test/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderToleranceTest.java new file mode 100644 index 000000000..4ca7c2eb7 --- /dev/null +++ b/opennlp-tools/src/test/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderToleranceTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package opennlp.tools.monitoring; + +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import opennlp.tools.util.TrainingParameters; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class IterDeltaAccuracyUnderToleranceTest { + + private StopCriteria stopCriteria; + + @BeforeEach + public void setup() { + stopCriteria = new IterDeltaAccuracyUnderTolerance(new TrainingParameters(Map.of("Tolerance", + .00002))); + } + + @ParameterizedTest + @CsvSource( {"0.01,false", "-0.01,false", "0.00001,true", "-0.00001,true"}) + void testCriteria(double val, String expectedVal) { + assertEquals(Boolean.parseBoolean(expectedVal), stopCriteria.test(val)); + } + + @Test + void testMessageIfSatisfied() { + assertEquals("Stopping: change in training set accuracy less than {2.0E-5}", + stopCriteria.getMessageIfSatisfied()); + } +} diff --git a/opennlp-tools/src/test/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreachedTest.java b/opennlp-tools/src/test/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreachedTest.java new file mode 100644 index 000000000..7786847b6 --- /dev/null +++ b/opennlp-tools/src/test/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreachedTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.monitoring; + +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import opennlp.tools.util.TrainingParameters; + +import static opennlp.tools.ml.maxent.GISTrainer.LOG_LIKELIHOOD_THRESHOLD_PARAM; +import static org.junit.jupiter.api.Assertions.assertEquals; + + +class LogLikelihoodThresholdBreachedTest { + + private StopCriteria stopCriteria; + + @BeforeEach + public void setup() { + stopCriteria = new LogLikelihoodThresholdBreached( + new TrainingParameters(Map.of(LOG_LIKELIHOOD_THRESHOLD_PARAM,5.))); + } + + @ParameterizedTest + @CsvSource({"0.01,true", "-0.01,true", "6.0,false", "-6.0,true"}) + void testCriteria(double val, String expectedVal) { + assertEquals(Boolean.parseBoolean(expectedVal), stopCriteria.test(val)); + } + + @Test + void testMessageIfSatisfied() { + assertEquals("Stopping: Difference between log likelihood of current" + + " and previous iteration is less than threshold 5.0 .", + stopCriteria.getMessageIfSatisfied()); + } + +}