Skip to content

Commit 8824e1b

Browse files
committed
Added a lot more to the training metrics to make them give better details and allow for subclassing better
1 parent ac307c5 commit 8824e1b

12 files changed

+464
-210
lines changed

src/main/java/coursesketch/recognition/test/RecognitionMetric.java

-77
This file was deleted.

src/main/java/coursesketch/recognition/test/RecognitionTesting.java

+133-76
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
import coursesketch.recognition.framework.TemplateDatabaseInterface;
55
import coursesketch.recognition.framework.exceptions.RecognitionException;
66
import coursesketch.recognition.framework.exceptions.TemplateException;
7+
import coursesketch.recognition.test.converter.ScoreMetricsConverter;
8+
import coursesketch.recognition.test.converter.ScoreMetricsConverterFactory;
9+
import coursesketch.recognition.test.score.RecognitionScore;
10+
import coursesketch.recognition.test.score.RecognitionScoreFactory;
11+
import coursesketch.recognition.test.score.TrainingScore;
12+
import coursesketch.recognition.test.score.TrainingScoreFactory;
713
import protobuf.srl.sketch.Sketch;
814

915
import java.util.ArrayList;
@@ -31,6 +37,9 @@ public class RecognitionTesting {
3137
private int MAX_THREADS = 500;
3238

3339
ExecutorService executor;
40+
protected RecognitionScoreFactory recognitionFactory = new DefaultRecognitionScoreFactory();
41+
protected TrainingScoreFactory trainingFactory = new DefaultTrainingScoreFactory();
42+
private ScoreMetricsConverterFactory converterFactory = new DefaultScoreMetricsConverterFactory();
3443

3544
/**
3645
*
@@ -43,19 +52,31 @@ public RecognitionTesting(TemplateDatabaseInterface databaseInterface, Recogniti
4352
this.recognitionSystems = recognitionSystems;
4453
}
4554

46-
public List<RecognitionScoreMetrics> testAgainstAllTemplates() throws TemplateException {
55+
public void setRecognitionScoreFactory(RecognitionScoreFactory recognitionFactory) {
56+
this.recognitionFactory = recognitionFactory;
57+
}
58+
59+
public void setTrainingScoreFactory(TrainingScoreFactory trainingFactory) {
60+
this.trainingFactory = trainingFactory;
61+
}
62+
63+
public void setScoreMetricsConverterFactory(ScoreMetricsConverterFactory converterFactory) {
64+
this.converterFactory = converterFactory;
65+
}
66+
67+
public List<ScoreMetricsConverter> testAgainstAllTemplates() throws TemplateException {
4768
return testAgainstTemplates(databaseInterface.getAllTemplates());
4869
}
4970

50-
public List<RecognitionScoreMetrics> testAgainstInterpretation(Sketch.SrlInterpretation interpretation)
71+
public List<ScoreMetricsConverter> testAgainstInterpretation(Sketch.SrlInterpretation interpretation)
5172
throws TemplateException {
5273
return testAgainstTemplates(databaseInterface.getTemplate(interpretation));
5374
}
5475

5576
/**
5677
* This uses cross validation to test against templates.
5778
*/
58-
public List<RecognitionScoreMetrics> testAgainstTemplates(List<Sketch.RecognitionTemplate> allTemplates)
79+
public List<ScoreMetricsConverter> testAgainstTemplates(List<Sketch.RecognitionTemplate> allTemplates)
5980
throws TemplateException {
6081

6182
List<Sketch.RecognitionTemplate> testTemplates = splitTrainingAndTest(allTemplates);
@@ -65,14 +86,29 @@ public List<RecognitionScoreMetrics> testAgainstTemplates(List<Sketch.Recognitio
6586
Map<RecognitionInterface, List<RecognitionScore>> recognitionScore =
6687
recognizeAgainstTemplates(testTemplates);
6788

68-
List<RecognitionScoreMetrics> metrics = new ArrayList<>();
89+
List<ScoreMetricsConverter> metrics = new ArrayList<>();
6990
for (RecognitionInterface recognitionSystem : recognitionSystems) {
70-
metrics.add(new RecognitionScoreMetrics(recognitionSystem.getClass().getSimpleName(), trainingScores.get(recognitionSystem),
71-
recognitionScore.get(recognitionSystem)));
91+
ScoreMetricsConverter scoreMetricsConverter = converterFactory.getScoreMetricsConverter(recognitionSystem,
92+
trainingScores.get(recognitionSystem), recognitionScore.get(recognitionSystem));
93+
scoreMetricsConverter.computeRecognitionMetrics();
94+
metrics.add(scoreMetricsConverter);
7295
}
7396
return metrics;
7497
}
7598

99+
protected List<Sketch.SrlInterpretation> testTemplate(Sketch.RecognitionTemplate testTemplate,
100+
RecognitionInterface recognitionSystem,
101+
RecognitionScore score) {
102+
List<Sketch.SrlInterpretation> recognize = null;
103+
try {
104+
recognize = recognitionSystem.recognize(testTemplate.getTemplateId(), testTemplate);
105+
} catch (Exception e) {
106+
score.setNotRecognized(true);
107+
score.setFailed(e);
108+
}
109+
return recognize;
110+
}
111+
76112
public Map<RecognitionInterface, List<RecognitionScore>> recognizeAgainstTemplates(
77113
List<Sketch.RecognitionTemplate> testTemplates) {
78114
Map<RecognitionInterface, List<RecognitionScore>> scoreMap = new HashMap<>();
@@ -90,51 +126,54 @@ public Map<RecognitionInterface, List<RecognitionScore>> recognizeAgainstTemplat
90126
List<Future> taskFutures = new ArrayList<>();
91127
for (Sketch.RecognitionTemplate testTemplate : testTemplates) {
92128
final int thisCount = counter;
93-
taskFutures.add(executor.submit(new Callable(){
94-
@Override
95-
public Object call() throws Exception {
96-
RecognitionScore score = new RecognitionScore(recognitionSystem, testTemplate.getTemplateId());
97-
long startTime = System.nanoTime();
98-
try {
99-
List<Sketch.SrlInterpretation>
100-
recognize = recognitionSystem.recognize(testTemplate.getTemplateId(), testTemplate);
101-
long endTime = System.nanoTime();
102-
score.setRecognitionTime(endTime - startTime);
103-
if (recognize == null) {
104-
score.setFailed(new NullPointerException("List of returned interpretations is null"));
105-
recognitionScoreList.add(score);
106-
return null;
107-
}
108-
generateScore(score, recognize, testTemplate.getInterpretation());
109-
} catch (Exception e) {
110-
score.setFailed(e);
111-
}
112-
recognitionScoreList.add(score);
113-
if (thisCount % percent == 0) {
114-
LOG.debug("gone through {} sketches, {} left", thisCount, testTemplates.size() - thisCount);
115-
}
116-
return null;
129+
taskFutures.add(executor.submit((Callable) () -> {
130+
RecognitionScore score = recognitionFactory.createRecognitionScore(recognitionSystem,
131+
testTemplate.getTemplateId());
132+
long startTime = System.nanoTime();
133+
final List<Sketch.SrlInterpretation> interpretations =
134+
testTemplate(testTemplate, recognitionSystem, score);
135+
generateScore(score, interpretations, testTemplate.getInterpretation());
136+
long endTime = System.nanoTime();
137+
score.setRecognitionTime(endTime - startTime);
138+
recognitionScoreList.add(score);
139+
if (thisCount % percent == 0) {
140+
LOG.debug("gone through {} sketches, {} left", thisCount, testTemplates.size() - thisCount);
117141
}
142+
return null;
118143
}));
119144
counter++;
120145
}
121146

122147
LOG.debug("Waiting for all tasks to finish");
123148
// Waits for the executor to finish
124-
for (Future taskFuture : taskFutures) {
125-
try {
126-
taskFuture.get();
127-
} catch (InterruptedException e) {
128-
LOG.debug("INTERUPTIONS EXCEPTION", e);
129-
} catch (ExecutionException e) {
130-
LOG.debug("EXECUTION EXCEPTION", e);
131-
}
132-
}
149+
waitForFutures(taskFutures);
133150
LOG.debug("All recognition testing tasks have finished");
134151
}
135152
return scoreMap;
136153
}
137154

155+
private void waitForFutures(List<Future> taskFutures) {
156+
for (Future taskFuture : taskFutures) {
157+
try {
158+
taskFuture.get();
159+
} catch (InterruptedException e) {
160+
LOG.debug("INTERUPTIONS EXCEPTION", e);
161+
} catch (ExecutionException e) {
162+
LOG.debug("EXECUTION EXCEPTION", e);
163+
}
164+
}
165+
}
166+
167+
protected void trainSystem(Sketch.RecognitionTemplate template, RecognitionInterface recognitionSystem,
168+
TrainingScore score) {
169+
try {
170+
recognitionSystem.trainTemplate(template);
171+
} catch (Exception e) {
172+
score.addException(new RecognitionTestException("Error with training template " + template.getTemplateId(),
173+
e, recognitionSystem));
174+
}
175+
}
176+
138177
public Map<RecognitionInterface, List<TrainingScore>> trainAgainstTemplates(List<Sketch.RecognitionTemplate> templates) {
139178
Map<RecognitionInterface, List<TrainingScore>> scoreMap = new HashMap<>();
140179

@@ -150,63 +189,57 @@ public Map<RecognitionInterface, List<TrainingScore>> trainAgainstTemplates(List
150189
List<Future> taskFutures = new ArrayList<>();
151190
for (Sketch.RecognitionTemplate template : templates) {
152191
final int thisCount = counter;
153-
taskFutures.add(executor.submit(new Callable() {
154-
@Override
155-
public Object call() throws Exception {
156-
TrainingScore score = new TrainingScore();
157-
long startTime = System.nanoTime();
158-
try {
159-
recognitionSystem.trainTemplate(template);
160-
} catch (Exception e) {
161-
score.addException(new RecognitionTestException("Error with training template " + template.getTemplateId(),
162-
e, recognitionSystem));
163-
}
164-
long endTime = System.nanoTime();
165-
score.setTrainingTime(endTime - startTime);
166-
trainingScores.add(score);
167-
168-
if (thisCount % percent == 0) {
169-
LOG.debug("gone through {} sketches, {} left", thisCount, templates.size() - thisCount);
170-
}
171-
return null;
192+
taskFutures.add(executor.submit((Callable) () -> {
193+
TrainingScore score = trainingFactory.createTrainingScore(recognitionSystem, template.getTemplateId());
194+
long startTime = System.nanoTime();
195+
trainSystem(template, recognitionSystem, score);
196+
long endTime = System.nanoTime();
197+
score.setTrainingTime(endTime - startTime);
198+
trainingScores.add(score);
199+
if (thisCount % percent == 0) {
200+
LOG.debug("gone through {} sketches, {} left", thisCount, templates.size() - thisCount);
172201
}
202+
return null;
173203
}));
174204
counter++;
175205
}
176206

177207
LOG.debug("Waiting for all tasks to finish");
178208
// Waits for the executor to finish
179-
for (Future taskFuture : taskFutures) {
180-
try {
181-
taskFuture.get();
182-
} catch (InterruptedException e) {
183-
LOG.debug("INTERUPTIONS EXCEPTION", e);
184-
} catch (ExecutionException e) {
185-
LOG.debug("EXECUTION EXCEPTION", e);
186-
}
187-
}
188-
try {
189-
recognitionSystem.finishTraining();
190-
} catch (RecognitionException e) {
191-
LOG.debug("EXCEPTION WHEN TRAINING", e);
192-
}
209+
waitForFutures(taskFutures);
210+
finishTraining(recognitionSystem);
193211
LOG.debug("All trainings tasks have finished");
194212
}
195213
return scoreMap;
196214
}
197215

198-
private void generateScore(RecognitionScore score,
216+
protected void finishTraining(RecognitionInterface recognitionSystem) {
217+
try {
218+
recognitionSystem.finishTraining();
219+
} catch (RecognitionException e) {
220+
LOG.debug("EXCEPTION WHEN TRAINING", e);
221+
}
222+
}
223+
224+
protected void generateScore(RecognitionScore score,
199225
List<Sketch.SrlInterpretation> recognize, Sketch.SrlInterpretation interpretation) {
226+
if (recognize == null) {
227+
score.setNotRecognized(true);
228+
score.setFailed(new NullPointerException("List of returned interpretations is null"));
229+
return;
230+
}
200231
double scoreValue = 1;
201232
int topGuesses = Math.min(5, recognize.size());
202233
int subtractAmount = 1/topGuesses;
203234
score.setRecognizedInterpretations(recognize);
204235
score.setCorrectInterpretations(interpretation);
205236
for (int i = 0; i < topGuesses; i++) {
206-
if (recognize.get(i).getLabel().equals(interpretation.getLabel())) {
207-
score.setRecognized(true);
237+
// We won't consider it recognized if it has no confidence in its values
238+
if (recognize.get(i).getLabel().equals(interpretation.getLabel())
239+
&& recognize.get(i).getConfidence() > 0) {
240+
score.setRecognized(i);
208241
score.setScoreValue(scoreValue * recognize.get(i).getConfidence());
209-
return ;
242+
return;
210243
}
211244
if (i == 0) {
212245
score.setPotentialMissRecognized(true);
@@ -234,4 +267,28 @@ private List<Sketch.RecognitionTemplate> splitTrainingAndTest(List<Sketch.Recogn
234267
LOG.debug("TrainingSet: {}, TestingSet: {}", allTemplates.size(), testTemplates.size());
235268
return testTemplates;
236269
}
270+
271+
private static final class DefaultRecognitionScoreFactory implements RecognitionScoreFactory {
272+
@Override
273+
public RecognitionScore createRecognitionScore(RecognitionInterface recognitionSystem, String templateId) {
274+
return new RecognitionScore(recognitionSystem, templateId);
275+
}
276+
}
277+
278+
private static final class DefaultTrainingScoreFactory implements TrainingScoreFactory {
279+
@Override
280+
public TrainingScore createTrainingScore(RecognitionInterface recognitionSystem, String templateId) {
281+
return new TrainingScore(recognitionSystem, templateId);
282+
}
283+
}
284+
285+
private static final class DefaultScoreMetricsConverterFactory implements ScoreMetricsConverterFactory {
286+
287+
@Override
288+
public ScoreMetricsConverter getScoreMetricsConverter(RecognitionInterface recognitionSystem,
289+
List<TrainingScore> trainingScores, List<RecognitionScore> recognitionScores) {
290+
return new ScoreMetricsConverter(recognitionSystem.getClass().getSimpleName(),
291+
trainingScores, recognitionScores);
292+
}
293+
}
237294
}

0 commit comments

Comments
 (0)