4
4
import coursesketch .recognition .framework .TemplateDatabaseInterface ;
5
5
import coursesketch .recognition .framework .exceptions .RecognitionException ;
6
6
import coursesketch .recognition .framework .exceptions .TemplateException ;
7
+ import coursesketch .recognition .test .converter .ScoreMetricsConverter ;
8
+ import coursesketch .recognition .test .converter .ScoreMetricsConverterFactory ;
9
+ import coursesketch .recognition .test .score .RecognitionScore ;
10
+ import coursesketch .recognition .test .score .RecognitionScoreFactory ;
11
+ import coursesketch .recognition .test .score .TrainingScore ;
12
+ import coursesketch .recognition .test .score .TrainingScoreFactory ;
7
13
import protobuf .srl .sketch .Sketch ;
8
14
9
15
import java .util .ArrayList ;
@@ -31,6 +37,9 @@ public class RecognitionTesting {
31
37
private int MAX_THREADS = 500 ;
32
38
33
39
ExecutorService executor ;
40
+ protected RecognitionScoreFactory recognitionFactory = new DefaultRecognitionScoreFactory ();
41
+ protected TrainingScoreFactory trainingFactory = new DefaultTrainingScoreFactory ();
42
+ private ScoreMetricsConverterFactory converterFactory = new DefaultScoreMetricsConverterFactory ();
34
43
35
44
/**
36
45
*
@@ -43,19 +52,31 @@ public RecognitionTesting(TemplateDatabaseInterface databaseInterface, Recogniti
43
52
this .recognitionSystems = recognitionSystems ;
44
53
}
45
54
46
- public List <RecognitionScoreMetrics > testAgainstAllTemplates () throws TemplateException {
55
+ public void setRecognitionScoreFactory (RecognitionScoreFactory recognitionFactory ) {
56
+ this .recognitionFactory = recognitionFactory ;
57
+ }
58
+
59
+ public void setTrainingScoreFactory (TrainingScoreFactory trainingFactory ) {
60
+ this .trainingFactory = trainingFactory ;
61
+ }
62
+
63
+ public void setScoreMetricsConverterFactory (ScoreMetricsConverterFactory converterFactory ) {
64
+ this .converterFactory = converterFactory ;
65
+ }
66
+
67
+ public List <ScoreMetricsConverter > testAgainstAllTemplates () throws TemplateException {
47
68
return testAgainstTemplates (databaseInterface .getAllTemplates ());
48
69
}
49
70
50
- public List <RecognitionScoreMetrics > testAgainstInterpretation (Sketch .SrlInterpretation interpretation )
71
+ public List <ScoreMetricsConverter > testAgainstInterpretation (Sketch .SrlInterpretation interpretation )
51
72
throws TemplateException {
52
73
return testAgainstTemplates (databaseInterface .getTemplate (interpretation ));
53
74
}
54
75
55
76
/**
56
77
* This uses cross validation to test against templates.
57
78
*/
58
- public List <RecognitionScoreMetrics > testAgainstTemplates (List <Sketch .RecognitionTemplate > allTemplates )
79
+ public List <ScoreMetricsConverter > testAgainstTemplates (List <Sketch .RecognitionTemplate > allTemplates )
59
80
throws TemplateException {
60
81
61
82
List <Sketch .RecognitionTemplate > testTemplates = splitTrainingAndTest (allTemplates );
@@ -65,14 +86,29 @@ public List<RecognitionScoreMetrics> testAgainstTemplates(List<Sketch.Recognitio
65
86
Map <RecognitionInterface , List <RecognitionScore >> recognitionScore =
66
87
recognizeAgainstTemplates (testTemplates );
67
88
68
- List <RecognitionScoreMetrics > metrics = new ArrayList <>();
89
+ List <ScoreMetricsConverter > metrics = new ArrayList <>();
69
90
for (RecognitionInterface recognitionSystem : recognitionSystems ) {
70
- metrics .add (new RecognitionScoreMetrics (recognitionSystem .getClass ().getSimpleName (), trainingScores .get (recognitionSystem ),
71
- recognitionScore .get (recognitionSystem )));
91
+ ScoreMetricsConverter scoreMetricsConverter = converterFactory .getScoreMetricsConverter (recognitionSystem ,
92
+ trainingScores .get (recognitionSystem ), recognitionScore .get (recognitionSystem ));
93
+ scoreMetricsConverter .computeRecognitionMetrics ();
94
+ metrics .add (scoreMetricsConverter );
72
95
}
73
96
return metrics ;
74
97
}
75
98
99
+ protected List <Sketch .SrlInterpretation > testTemplate (Sketch .RecognitionTemplate testTemplate ,
100
+ RecognitionInterface recognitionSystem ,
101
+ RecognitionScore score ) {
102
+ List <Sketch .SrlInterpretation > recognize = null ;
103
+ try {
104
+ recognize = recognitionSystem .recognize (testTemplate .getTemplateId (), testTemplate );
105
+ } catch (Exception e ) {
106
+ score .setNotRecognized (true );
107
+ score .setFailed (e );
108
+ }
109
+ return recognize ;
110
+ }
111
+
76
112
public Map <RecognitionInterface , List <RecognitionScore >> recognizeAgainstTemplates (
77
113
List <Sketch .RecognitionTemplate > testTemplates ) {
78
114
Map <RecognitionInterface , List <RecognitionScore >> scoreMap = new HashMap <>();
@@ -90,51 +126,54 @@ public Map<RecognitionInterface, List<RecognitionScore>> recognizeAgainstTemplat
90
126
List <Future > taskFutures = new ArrayList <>();
91
127
for (Sketch .RecognitionTemplate testTemplate : testTemplates ) {
92
128
final int thisCount = counter ;
93
- taskFutures .add (executor .submit (new Callable (){
94
- @ Override
95
- public Object call () throws Exception {
96
- RecognitionScore score = new RecognitionScore (recognitionSystem , testTemplate .getTemplateId ());
97
- long startTime = System .nanoTime ();
98
- try {
99
- List <Sketch .SrlInterpretation >
100
- recognize = recognitionSystem .recognize (testTemplate .getTemplateId (), testTemplate );
101
- long endTime = System .nanoTime ();
102
- score .setRecognitionTime (endTime - startTime );
103
- if (recognize == null ) {
104
- score .setFailed (new NullPointerException ("List of returned interpretations is null" ));
105
- recognitionScoreList .add (score );
106
- return null ;
107
- }
108
- generateScore (score , recognize , testTemplate .getInterpretation ());
109
- } catch (Exception e ) {
110
- score .setFailed (e );
111
- }
112
- recognitionScoreList .add (score );
113
- if (thisCount % percent == 0 ) {
114
- LOG .debug ("gone through {} sketches, {} left" , thisCount , testTemplates .size () - thisCount );
115
- }
116
- return null ;
129
+ taskFutures .add (executor .submit ((Callable ) () -> {
130
+ RecognitionScore score = recognitionFactory .createRecognitionScore (recognitionSystem ,
131
+ testTemplate .getTemplateId ());
132
+ long startTime = System .nanoTime ();
133
+ final List <Sketch .SrlInterpretation > interpretations =
134
+ testTemplate (testTemplate , recognitionSystem , score );
135
+ generateScore (score , interpretations , testTemplate .getInterpretation ());
136
+ long endTime = System .nanoTime ();
137
+ score .setRecognitionTime (endTime - startTime );
138
+ recognitionScoreList .add (score );
139
+ if (thisCount % percent == 0 ) {
140
+ LOG .debug ("gone through {} sketches, {} left" , thisCount , testTemplates .size () - thisCount );
117
141
}
142
+ return null ;
118
143
}));
119
144
counter ++;
120
145
}
121
146
122
147
LOG .debug ("Waiting for all tasks to finish" );
123
148
// Waits for the executor to finish
124
- for (Future taskFuture : taskFutures ) {
125
- try {
126
- taskFuture .get ();
127
- } catch (InterruptedException e ) {
128
- LOG .debug ("INTERUPTIONS EXCEPTION" , e );
129
- } catch (ExecutionException e ) {
130
- LOG .debug ("EXECUTION EXCEPTION" , e );
131
- }
132
- }
149
+ waitForFutures (taskFutures );
133
150
LOG .debug ("All recognition testing tasks have finished" );
134
151
}
135
152
return scoreMap ;
136
153
}
137
154
155
+ private void waitForFutures (List <Future > taskFutures ) {
156
+ for (Future taskFuture : taskFutures ) {
157
+ try {
158
+ taskFuture .get ();
159
+ } catch (InterruptedException e ) {
160
+ LOG .debug ("INTERUPTIONS EXCEPTION" , e );
161
+ } catch (ExecutionException e ) {
162
+ LOG .debug ("EXECUTION EXCEPTION" , e );
163
+ }
164
+ }
165
+ }
166
+
167
+ protected void trainSystem (Sketch .RecognitionTemplate template , RecognitionInterface recognitionSystem ,
168
+ TrainingScore score ) {
169
+ try {
170
+ recognitionSystem .trainTemplate (template );
171
+ } catch (Exception e ) {
172
+ score .addException (new RecognitionTestException ("Error with training template " + template .getTemplateId (),
173
+ e , recognitionSystem ));
174
+ }
175
+ }
176
+
138
177
public Map <RecognitionInterface , List <TrainingScore >> trainAgainstTemplates (List <Sketch .RecognitionTemplate > templates ) {
139
178
Map <RecognitionInterface , List <TrainingScore >> scoreMap = new HashMap <>();
140
179
@@ -150,63 +189,57 @@ public Map<RecognitionInterface, List<TrainingScore>> trainAgainstTemplates(List
150
189
List <Future > taskFutures = new ArrayList <>();
151
190
for (Sketch .RecognitionTemplate template : templates ) {
152
191
final int thisCount = counter ;
153
- taskFutures .add (executor .submit (new Callable () {
154
- @ Override
155
- public Object call () throws Exception {
156
- TrainingScore score = new TrainingScore ();
157
- long startTime = System .nanoTime ();
158
- try {
159
- recognitionSystem .trainTemplate (template );
160
- } catch (Exception e ) {
161
- score .addException (new RecognitionTestException ("Error with training template " + template .getTemplateId (),
162
- e , recognitionSystem ));
163
- }
164
- long endTime = System .nanoTime ();
165
- score .setTrainingTime (endTime - startTime );
166
- trainingScores .add (score );
167
-
168
- if (thisCount % percent == 0 ) {
169
- LOG .debug ("gone through {} sketches, {} left" , thisCount , templates .size () - thisCount );
170
- }
171
- return null ;
192
+ taskFutures .add (executor .submit ((Callable ) () -> {
193
+ TrainingScore score = trainingFactory .createTrainingScore (recognitionSystem , template .getTemplateId ());
194
+ long startTime = System .nanoTime ();
195
+ trainSystem (template , recognitionSystem , score );
196
+ long endTime = System .nanoTime ();
197
+ score .setTrainingTime (endTime - startTime );
198
+ trainingScores .add (score );
199
+ if (thisCount % percent == 0 ) {
200
+ LOG .debug ("gone through {} sketches, {} left" , thisCount , templates .size () - thisCount );
172
201
}
202
+ return null ;
173
203
}));
174
204
counter ++;
175
205
}
176
206
177
207
LOG .debug ("Waiting for all tasks to finish" );
178
208
// Waits for the executor to finish
179
- for (Future taskFuture : taskFutures ) {
180
- try {
181
- taskFuture .get ();
182
- } catch (InterruptedException e ) {
183
- LOG .debug ("INTERUPTIONS EXCEPTION" , e );
184
- } catch (ExecutionException e ) {
185
- LOG .debug ("EXECUTION EXCEPTION" , e );
186
- }
187
- }
188
- try {
189
- recognitionSystem .finishTraining ();
190
- } catch (RecognitionException e ) {
191
- LOG .debug ("EXCEPTION WHEN TRAINING" , e );
192
- }
209
+ waitForFutures (taskFutures );
210
+ finishTraining (recognitionSystem );
193
211
LOG .debug ("All trainings tasks have finished" );
194
212
}
195
213
return scoreMap ;
196
214
}
197
215
198
- private void generateScore (RecognitionScore score ,
216
+ protected void finishTraining (RecognitionInterface recognitionSystem ) {
217
+ try {
218
+ recognitionSystem .finishTraining ();
219
+ } catch (RecognitionException e ) {
220
+ LOG .debug ("EXCEPTION WHEN TRAINING" , e );
221
+ }
222
+ }
223
+
224
+ protected void generateScore (RecognitionScore score ,
199
225
List <Sketch .SrlInterpretation > recognize , Sketch .SrlInterpretation interpretation ) {
226
+ if (recognize == null ) {
227
+ score .setNotRecognized (true );
228
+ score .setFailed (new NullPointerException ("List of returned interpretations is null" ));
229
+ return ;
230
+ }
200
231
double scoreValue = 1 ;
201
232
int topGuesses = Math .min (5 , recognize .size ());
202
233
int subtractAmount = 1 /topGuesses ;
203
234
score .setRecognizedInterpretations (recognize );
204
235
score .setCorrectInterpretations (interpretation );
205
236
for (int i = 0 ; i < topGuesses ; i ++) {
206
- if (recognize .get (i ).getLabel ().equals (interpretation .getLabel ())) {
207
- score .setRecognized (true );
237
+ // We won't consider it recognized if it has no confidence in its values
238
+ if (recognize .get (i ).getLabel ().equals (interpretation .getLabel ())
239
+ && recognize .get (i ).getConfidence () > 0 ) {
240
+ score .setRecognized (i );
208
241
score .setScoreValue (scoreValue * recognize .get (i ).getConfidence ());
209
- return ;
242
+ return ;
210
243
}
211
244
if (i == 0 ) {
212
245
score .setPotentialMissRecognized (true );
@@ -234,4 +267,28 @@ private List<Sketch.RecognitionTemplate> splitTrainingAndTest(List<Sketch.Recogn
234
267
LOG .debug ("TrainingSet: {}, TestingSet: {}" , allTemplates .size (), testTemplates .size ());
235
268
return testTemplates ;
236
269
}
270
+
271
+ private static final class DefaultRecognitionScoreFactory implements RecognitionScoreFactory {
272
+ @ Override
273
+ public RecognitionScore createRecognitionScore (RecognitionInterface recognitionSystem , String templateId ) {
274
+ return new RecognitionScore (recognitionSystem , templateId );
275
+ }
276
+ }
277
+
278
+ private static final class DefaultTrainingScoreFactory implements TrainingScoreFactory {
279
+ @ Override
280
+ public TrainingScore createTrainingScore (RecognitionInterface recognitionSystem , String templateId ) {
281
+ return new TrainingScore (recognitionSystem , templateId );
282
+ }
283
+ }
284
+
285
+ private static final class DefaultScoreMetricsConverterFactory implements ScoreMetricsConverterFactory {
286
+
287
+ @ Override
288
+ public ScoreMetricsConverter getScoreMetricsConverter (RecognitionInterface recognitionSystem ,
289
+ List <TrainingScore > trainingScores , List <RecognitionScore > recognitionScores ) {
290
+ return new ScoreMetricsConverter (recognitionSystem .getClass ().getSimpleName (),
291
+ trainingScores , recognitionScores );
292
+ }
293
+ }
237
294
}
0 commit comments