
Commit 1565a50

Merge pull request #667 from htm-community/predictor_precision_fix
SDRClassifier: fix precision by using Real64 for PDF
2 parents a9fbd62 + df7ae4e commit 1565a50

5 files changed: +110 −105 lines
bindings/py/cpp_src/bindings/algorithms/py_SDRClassifier.cpp

+2 −6

@@ -138,11 +138,11 @@ Example Usage:
     # Give the predictor partial information, and make predictions
     # about the future.
     pred.reset()
-    A = pred.infer( 0, sequence[0] )
+    A = pred.infer( sequence[0] )
     numpy.argmax( A[1] )  ->  labels[1]
     numpy.argmax( A[2] )  ->  labels[2]
 
-    B = pred.infer( 1, sequence[1] )
+    B = pred.infer( sequence[1] )
     numpy.argmax( B[1] )  ->  labels[2]
     numpy.argmax( B[2] )  ->  labels[3]
 )");
@@ -162,14 +162,10 @@ R"(For use with time series datasets.)");
 py_Predictor.def("infer", &Predictor::infer,
 R"(Compute the likelihoods.
 
-Argument recordNum is an incrementing integer for each record.
-Gaps in numbers correspond to missing records.
-
 Argument pattern is the SDR containing the active input bits.
 
 Returns a dictionary whos keys are prediction steps, and values are PDFs.
 See help(Classifier.infer) for details about PDFs.)",
-    py::arg("recordNum"),
     py::arg("pattern"));
 
 py_Predictor.def("learn", &Predictor::learn,

bindings/py/tests/algorithms/sdr_classifier_test.py

+20 −17

@@ -69,11 +69,11 @@ def testExampleUsage(self):
     # Give the predictor partial information, and make predictions
     # about the future.
     pred.reset()
-    A = pred.infer( 0, sequence[0] )
+    A = pred.infer( sequence[0] )
     assert( numpy.argmax( A[1] ) == labels[1] )
     assert( numpy.argmax( A[2] ) == labels[2] )
 
-    B = pred.infer( 1, sequence[1] )
+    B = pred.infer( sequence[1] )
     assert( numpy.argmax( B[1] ) == labels[2] )
     assert( numpy.argmax( B[2] ) == labels[3] )
 
@@ -121,7 +121,7 @@ def testSingleValue0Steps(self):
     for recordNum in range(10):
         pred.learn(recordNum, inp, 2)
 
-    retval = pred.infer( 10, inp )
+    retval = pred.infer( inp )
     self.assertGreater(retval[0][2], 0.9)
 
 
@@ -131,15 +131,18 @@ def testComputeInferOrLearnOnly(self):
     inp.randomize( .3 )
 
     # learn only
-    c.infer(recordNum=0, pattern=inp) # Don't crash with not enough training data.
+    with self.assertRaises(RuntimeError):
+        c.infer(pattern=inp) # crash with not enough training data.
     c.learn(recordNum=0, pattern=inp, classification=4)
-    c.infer(recordNum=1, pattern=inp) # Don't crash with not enough training data.
+    with self.assertRaises(RuntimeError):
+        c.infer(pattern=inp) # crash with not enough training data.
     c.learn(recordNum=2, pattern=inp, classification=4)
     c.learn(recordNum=3, pattern=inp, classification=4)
+    c.infer(pattern=inp) # Don't crash with not enough training data.
 
     # infer only
-    retval1 = c.infer(recordNum=5, pattern=inp)
-    retval2 = c.infer(recordNum=6, pattern=inp)
+    retval1 = c.infer(pattern=inp)
+    retval2 = c.infer(pattern=inp)
     self.assertSequenceEqual(list(retval1[1]), list(retval2[1]))
 
 
@@ -164,7 +167,7 @@ def testComputeComplex(self):
             classification=4,)
 
     inp.sparse = [1, 5, 9]
-    result = c.infer(recordNum=4, pattern=inp)
+    result = c.infer(pattern=inp)
 
     self.assertSetEqual(set(result.keys()), set([1]))
     self.assertEqual(len(result[1]), 6)
@@ -206,7 +209,7 @@ def testMultistepSingleValue(self):
     for recordNum in range(10):
         classifier.learn(recordNum, inp, 0)
 
-    retval = classifier.infer(10, inp)
+    retval = classifier.infer(inp)
 
     # Should have a probability of 100% for that bucket.
     self.assertEqual(retval[1], [1.])
@@ -221,7 +224,7 @@ def testMultistepSimple(self):
         inp.sparse = [i % 10]
         classifier.learn(recordNum=i, pattern=inp, classification=(i % 10))
 
-    retval = classifier.infer(99, inp)
+    retval = classifier.infer(inp)
 
     self.assertGreater(retval[1][0], 0.99)
     for i in range(1, 10):
@@ -267,15 +270,15 @@ def testMissingRecords(self):
     # At this point, we should have learned [1,3,5] => bucket 1
     #                                       [2,4,6] => bucket 2
     inp.sparse = [1, 3, 5]
-    result = c.infer(recordNum=recordNum, pattern=inp)
+    result = c.infer(pattern=inp)
     c.learn(recordNum=recordNum, pattern=inp, classification=2)
     recordNum += 1
     self.assertLess(result[1][0], 0.1)
     self.assertGreater(result[1][1], 0.9)
     self.assertLess(result[1][2], 0.1)
 
     inp.sparse = [2, 4, 6]
-    result = c.infer(recordNum=recordNum, pattern=inp)
+    result = c.infer(pattern=inp)
     c.learn(recordNum=recordNum, pattern=inp, classification=1)
     recordNum += 1
     self.assertLess(result[1][0], 0.1)
@@ -289,7 +292,7 @@ def testMissingRecords(self):
     # the previous learn associates with bucket 0
     recordNum += 1
     inp.sparse = [1, 3, 5]
-    result = c.infer(recordNum=recordNum, pattern=inp)
+    result = c.infer(pattern=inp)
     c.learn(recordNum=recordNum, pattern=inp, classification=0)
     recordNum += 1
     self.assertLess(result[1][0], 0.1)
@@ -300,7 +303,7 @@ def testMissingRecords(self):
     # the previous learn associates with bucket 0
     recordNum += 1
     inp.sparse = [2, 4, 6]
-    result = c.infer(recordNum=recordNum, pattern=inp)
+    result = c.infer(pattern=inp)
     c.learn(recordNum=recordNum, pattern=inp, classification=0)
     recordNum += 1
     self.assertLess(result[1][0], 0.1)
@@ -311,7 +314,7 @@ def testMissingRecords(self):
     # the previous learn associates with bucket 0
     recordNum += 1
     inp.sparse = [1, 3, 5]
-    result = c.infer(recordNum=recordNum, pattern=inp)
+    result = c.infer(pattern=inp)
     c.learn(recordNum=recordNum, pattern=inp, classification=0)
     recordNum += 1
     self.assertLess(result[1][0], 0.1)
@@ -548,8 +551,8 @@ def testMultiStepPredictions(self):
         c.learn(recordNum, pattern=SDR2, classification=1)
         recordNum += 1
 
-    result1 = c.infer(recordNum, SDR1)
-    result2 = c.infer(recordNum, SDR2)
+    result1 = c.infer(SDR1)
+    result2 = c.infer(SDR2)
 
     self.assertAlmostEqual(result1[0][0], 1.0, places=1)
    self.assertAlmostEqual(result1[0][1], 0.0, places=1)
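
Not part of the commit: a small standalone sketch of the behavioral change that testComputeInferOrLearnOnly now exercises, namely that infer raises until the relevant step classifier has been trained, instead of silently returning zeroed likelihoods. The import paths and the steps/alpha values are assumptions for illustration.

    # Hypothetical sketch (assumed imports; steps/alpha chosen arbitrarily).
    from htm.bindings.sdr import SDR
    from htm.bindings.algorithms import Predictor

    inp = SDR( 1000 )
    inp.randomize( 0.3 )

    c = Predictor( steps=[1], alpha=1.0 )
    try:
        c.infer( pattern=inp )         # nothing learned yet for step 1 -> RuntimeError
    except RuntimeError:
        pass

    c.learn( recordNum=0, pattern=inp, classification=4 )
    c.learn( recordNum=1, pattern=inp, classification=4 )  # step-1 classifier has now seen data
    retval = c.infer( pattern=inp )    # keyed by prediction step
    assert len( retval[1] ) == 5       # one likelihood per known category, 0..4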

src/htm/algorithms/SDRClassifier.cpp

+43 −54

@@ -38,32 +38,18 @@ void Classifier::initialize(const Real alpha)
 {
   NTA_CHECK(alpha > 0.0f);
   alpha_ = alpha;
-  dimensions_.clear();
+  dimensions_ = 0;
   numCategories_ = 0u;
   weights_.clear();
 }
 
 
-PDF Classifier::infer(const SDR & pattern)
-{
-  // Check input dimensions, or if this is the first time the Classifier has
-  // been used then initialize it with the given SDR's dimensions.
-  if( dimensions_.empty() ) {
-    dimensions_ = pattern.dimensions;
-    while( weights_.size() < pattern.size ) {
-      weights_.push_back( vector<Real>( numCategories_, 0.0f ));
-    }
-  } else if( pattern.dimensions != dimensions_ ) {
-    stringstream err_msg;
-    err_msg << "Classifier input SDR.dimensions mismatch: previously given SDR with dimensions ( ";
-    for( auto dim : dimensions_ )
-      { err_msg << dim << " "; }
-    err_msg << "), now given SDR with dimensions ( ";
-    for( auto dim : pattern.dimensions )
-      { err_msg << dim << " "; }
-    err_msg << ").";
-    NTA_THROW << err_msg.str();
-  }
+PDF Classifier::infer(const SDR & pattern) const {
+  // Check input dimensions, or if this is the first time the Classifier is used and dimensions
+  // are unset, return zeroes.
+  NTA_CHECK( dimensions_ != 0 )
+    << "Classifier: must call `learn` before `infer`.";
+  NTA_ASSERT(pattern.size == dimensions_) << "Input SDR does not match previously seen size!";
 
   // Accumulate feed forward input.
   PDF probabilities( numCategories_, 0.0f );
@@ -81,8 +67,19 @@ PDF Classifier::infer(const SDR & pattern)
 
 void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
 {
+  // If this is the first time the Classifier is being used, weights are empty,
+  // so we set the dimensions to that of the input `pattern`
+  if( dimensions_ == 0 ) {
+    dimensions_ = pattern.size;
+    while( weights_.size() < pattern.size ) {
+      const auto initialEmptyWeights = PDF( numCategories_, 0.0f );
+      weights_.push_back( initialEmptyWeights );
+    }
+  }
+  NTA_ASSERT(pattern.size == dimensions_) << "Input SDR does not match previously seen size!";
+
   // Check if this is a new category & resize the weights table to hold it.
-  const auto maxCategoryIdx = *max_element(categoryIdxList.begin(), categoryIdxList.end());
+  const auto maxCategoryIdx = *max_element(categoryIdxList.cbegin(), categoryIdxList.cend());
   if( maxCategoryIdx >= numCategories_ ) {
     numCategories_ = maxCategoryIdx + 1;
     for( auto & vec : weights_ ) {
@@ -93,7 +90,7 @@ void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
   }
 
   // Compute errors and update weights.
-  const vector<Real> error = calculateError_(categoryIdxList, pattern);
+  const auto& error = calculateError_(categoryIdxList, pattern);
   for( const auto& bit : pattern.getSparse() ) {
     for(size_t i = 0u; i < numCategories_; i++) {
       weights_[bit][i] += alpha_ * error[i];
@@ -103,9 +100,8 @@ void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
 
 
 // Helper function to compute the error signal in learning.
-std::vector<Real> Classifier::calculateError_(
-    const std::vector<UInt> &categoryIdxList, const SDR &pattern)
-{
+std::vector<Real64> Classifier::calculateError_(const std::vector<UInt> &categoryIdxList,
+                                                const SDR &pattern) const {
   // compute predicted likelihoods
   auto likelihoods = infer(pattern);
 
@@ -165,56 +161,49 @@ void Predictor::reset() {
 }
 
 
-Predictions Predictor::infer(const UInt recordNum, const SDR &pattern)
-{
-  updateHistory_( recordNum, pattern );
-
+Predictions Predictor::infer(const SDR &pattern) const {
   Predictions result;
   for( const auto step : steps_ ) {
-    result[step] = classifiers_[step].infer( pattern );
+    result.insert({step, classifiers_.at(step).infer( pattern )});
   }
   return result;
 }
 
 
-void Predictor::learn(const UInt recordNum, const SDR &pattern,
+void Predictor::learn(const UInt recordNum, //TODO make recordNum optional, autoincrement as steps
+                      const SDR &pattern,
                       const std::vector<UInt> &bucketIdxList)
 {
-  updateHistory_( recordNum, pattern );
+  checkMonotonic_(recordNum);
+
+  // Update pattern history if this is a new record.
+  const UInt lastRecordNum = recordNumHistory_.empty() ? 0 : recordNumHistory_.back();
+  if (recordNumHistory_.size() == 0u || recordNum > lastRecordNum) {
+    patternHistory_.emplace_back( pattern );
+    recordNumHistory_.push_back(recordNum);
+    if (patternHistory_.size() > steps_.back() + 1u) { //steps_ are sorted, so steps_.back() is the "oldest/deepest" N-th step (ie 10 of [1,2,10])
+      patternHistory_.pop_front();
+      recordNumHistory_.pop_front();
+    }
+  }
 
   // Iterate through all recently given inputs, starting from the furthest in the past.
   auto pastPattern   = patternHistory_.begin();
   auto pastRecordNum = recordNumHistory_.begin();
-  for( ; pastRecordNum != recordNumHistory_.end(); pastPattern++, pastRecordNum++ )
+  for( ; pastRecordNum != recordNumHistory_.cend(); pastPattern++, pastRecordNum++ )
   {
     const UInt nSteps = recordNum - *pastRecordNum;
 
     // Update weights.
    if( binary_search( steps_.begin(), steps_.end(), nSteps )) {
-      classifiers_[nSteps].learn( *pastPattern, bucketIdxList );
+      classifiers_.at(nSteps).learn( *pastPattern, bucketIdxList );
    }
  }
 }
 
 
-void Predictor::updateHistory_(const UInt recordNum, const SDR & pattern)
-{
+void Predictor::checkMonotonic_(const UInt recordNum) const {
   // Ensure that recordNum increases monotonically.
-  UInt lastRecordNum = -1;
-  if( not recordNumHistory_.empty() ) {
-    lastRecordNum = recordNumHistory_.back();
-    if (recordNum < lastRecordNum) {
-      NTA_THROW << "The record number must increase monotonically.";
-    }
-  }
-
-  // Update pattern history if this is a new record.
-  if (recordNumHistory_.size() == 0u || recordNum > lastRecordNum) {
-    patternHistory_.emplace_back( pattern );
-    recordNumHistory_.push_back(recordNum);
-    if (patternHistory_.size() > steps_.back() + 1u) {
-      patternHistory_.pop_front();
-      recordNumHistory_.pop_front();
-    }
-  }
+  const UInt lastRecordNum = recordNumHistory_.empty() ? 0 : recordNumHistory_.back();
+  NTA_CHECK(recordNum >= lastRecordNum) << "The record number must increase monotonically.";
 }
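
Not part of the commit: a sketch of the recordNum semantics that checkMonotonic_ and the history update above enforce, as seen from Python. recordNum still determines which past pattern each classification is paired with (gaps count as missing records), and a decreasing recordNum is now rejected by NTA_CHECK. The imports, SDR sizes, and values below are assumptions for illustration.

    # Hypothetical sketch (assumed imports and values).
    from htm.bindings.sdr import SDR
    from htm.bindings.algorithms import Predictor

    a = SDR( 100 ); a.randomize( 0.10 )
    b = SDR( 100 ); b.randomize( 0.10 )

    pred = Predictor( steps=[1] )
    pred.learn( recordNum=0, pattern=a, classification=0 )
    pred.learn( recordNum=1, pattern=b, classification=1 )  # pairs 'a' (1 step back) with bucket 1

    try:
        pred.learn( recordNum=0, pattern=a, classification=0 )  # recordNum went backwards
    except RuntimeError:
        pass  # "The record number must increase monotonically."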
