diff --git a/2016-11_Seminar/Session 3 - Relation CNN/code/CNN.py b/2016-11_Seminar/Session 3 - Relation CNN/code/CNN.py index 08dc767..fb79e49 100644 --- a/2016-11_Seminar/Session 3 - Relation CNN/code/CNN.py +++ b/2016-11_Seminar/Session 3 - Relation CNN/code/CNN.py @@ -133,7 +133,7 @@ def getPrecision(pred_test, yTest, targetLabel): f1Sum = 0 f1Count = 0 - for targetLabel in xrange(1, max(yTest)): + for targetLabel in xrange(1, max(yTest)+1): prec = getPrecision(pred_test, yTest, targetLabel) rec = getPrecision(yTest, pred_test, targetLabel) f1 = 0 if (prec+rec) == 0 else 2*prec*rec/(prec+rec) @@ -143,4 +143,4 @@ def getPrecision(pred_test, yTest, targetLabel): macroF1 = f1Sum / float(f1Count) max_f1 = max(max_f1, macroF1) - print "Non-other Macro-Averaged F1: %.4f (max: %.4f)\n" % (macroF1, max_f1) \ No newline at end of file + print "Non-other Macro-Averaged F1: %.4f (max: %.4f)\n" % (macroF1, max_f1)