Skip to content

Commit df7ae4e

Browse files
committed
Predictor: remove numRecord arg, fixes
as numRecord is no longer needed for inference. API changes to bidings and tests to reflect the change. Several fixes from earlier commits in this PR, this fixes the segfaults.
1 parent 4d8708d commit df7ae4e

File tree

5 files changed

+41
-48
lines changed

5 files changed

+41
-48
lines changed

bindings/py/cpp_src/bindings/algorithms/py_SDRClassifier.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,11 @@ Example Usage:
138138
# Give the predictor partial information, and make predictions
139139
# about the future.
140140
pred.reset()
141-
A = pred.infer( 0, sequence[0] )
141+
A = pred.infer( sequence[0] )
142142
numpy.argmax( A[1] ) -> labels[1]
143143
numpy.argmax( A[2] ) -> labels[2]
144144
145-
B = pred.infer( 1, sequence[1] )
145+
B = pred.infer( sequence[1] )
146146
numpy.argmax( B[1] ) -> labels[2]
147147
numpy.argmax( B[2] ) -> labels[3]
148148
)");
@@ -162,14 +162,10 @@ R"(For use with time series datasets.)");
162162
py_Predictor.def("infer", &Predictor::infer,
163163
R"(Compute the likelihoods.
164164
165-
Argument recordNum is an incrementing integer for each record.
166-
Gaps in numbers correspond to missing records.
167-
168165
Argument pattern is the SDR containing the active input bits.
169166
170167
Returns a dictionary whos keys are prediction steps, and values are PDFs.
171168
See help(Classifier.infer) for details about PDFs.)",
172-
py::arg("recordNum"),
173169
py::arg("pattern"));
174170

175171
py_Predictor.def("learn", &Predictor::learn,

bindings/py/tests/algorithms/sdr_classifier_test.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ def testExampleUsage(self):
6969
# Give the predictor partial information, and make predictions
7070
# about the future.
7171
pred.reset()
72-
A = pred.infer( 0, sequence[0] )
72+
A = pred.infer( sequence[0] )
7373
assert( numpy.argmax( A[1] ) == labels[1] )
7474
assert( numpy.argmax( A[2] ) == labels[2] )
7575

76-
B = pred.infer( 1, sequence[1] )
76+
B = pred.infer( sequence[1] )
7777
assert( numpy.argmax( B[1] ) == labels[2] )
7878
assert( numpy.argmax( B[2] ) == labels[3] )
7979

@@ -121,7 +121,7 @@ def testSingleValue0Steps(self):
121121
for recordNum in range(10):
122122
pred.learn(recordNum, inp, 2)
123123

124-
retval = pred.infer( 10, inp )
124+
retval = pred.infer( inp )
125125
self.assertGreater(retval[0][2], 0.9)
126126

127127

@@ -131,15 +131,18 @@ def testComputeInferOrLearnOnly(self):
131131
inp.randomize( .3 )
132132

133133
# learn only
134-
c.infer(recordNum=0, pattern=inp) # Don't crash with not enough training data.
134+
with self.assertRaises(RuntimeError):
135+
c.infer(pattern=inp) # crash with not enough training data.
135136
c.learn(recordNum=0, pattern=inp, classification=4)
136-
c.infer(recordNum=1, pattern=inp) # Don't crash with not enough training data.
137+
with self.assertRaises(RuntimeError):
138+
c.infer(pattern=inp) # crash with not enough training data.
137139
c.learn(recordNum=2, pattern=inp, classification=4)
138140
c.learn(recordNum=3, pattern=inp, classification=4)
141+
c.infer(pattern=inp) # Don't crash with not enough training data.
139142

140143
# infer only
141-
retval1 = c.infer(recordNum=5, pattern=inp)
142-
retval2 = c.infer(recordNum=6, pattern=inp)
144+
retval1 = c.infer(pattern=inp)
145+
retval2 = c.infer(pattern=inp)
143146
self.assertSequenceEqual(list(retval1[1]), list(retval2[1]))
144147

145148

@@ -164,7 +167,7 @@ def testComputeComplex(self):
164167
classification=4,)
165168

166169
inp.sparse = [1, 5, 9]
167-
result = c.infer(recordNum=4, pattern=inp)
170+
result = c.infer(pattern=inp)
168171

169172
self.assertSetEqual(set(result.keys()), set([1]))
170173
self.assertEqual(len(result[1]), 6)
@@ -206,7 +209,7 @@ def testMultistepSingleValue(self):
206209
for recordNum in range(10):
207210
classifier.learn(recordNum, inp, 0)
208211

209-
retval = classifier.infer(10, inp)
212+
retval = classifier.infer(inp)
210213

211214
# Should have a probability of 100% for that bucket.
212215
self.assertEqual(retval[1], [1.])
@@ -221,7 +224,7 @@ def testMultistepSimple(self):
221224
inp.sparse = [i % 10]
222225
classifier.learn(recordNum=i, pattern=inp, classification=(i % 10))
223226

224-
retval = classifier.infer(99, inp)
227+
retval = classifier.infer(inp)
225228

226229
self.assertGreater(retval[1][0], 0.99)
227230
for i in range(1, 10):
@@ -267,15 +270,15 @@ def testMissingRecords(self):
267270
# At this point, we should have learned [1,3,5] => bucket 1
268271
# [2,4,6] => bucket 2
269272
inp.sparse = [1, 3, 5]
270-
result = c.infer(recordNum=recordNum, pattern=inp)
273+
result = c.infer(pattern=inp)
271274
c.learn(recordNum=recordNum, pattern=inp, classification=2)
272275
recordNum += 1
273276
self.assertLess(result[1][0], 0.1)
274277
self.assertGreater(result[1][1], 0.9)
275278
self.assertLess(result[1][2], 0.1)
276279

277280
inp.sparse = [2, 4, 6]
278-
result = c.infer(recordNum=recordNum, pattern=inp)
281+
result = c.infer(pattern=inp)
279282
c.learn(recordNum=recordNum, pattern=inp, classification=1)
280283
recordNum += 1
281284
self.assertLess(result[1][0], 0.1)
@@ -289,7 +292,7 @@ def testMissingRecords(self):
289292
# the previous learn associates with bucket 0
290293
recordNum += 1
291294
inp.sparse = [1, 3, 5]
292-
result = c.infer(recordNum=recordNum, pattern=inp)
295+
result = c.infer(pattern=inp)
293296
c.learn(recordNum=recordNum, pattern=inp, classification=0)
294297
recordNum += 1
295298
self.assertLess(result[1][0], 0.1)
@@ -300,7 +303,7 @@ def testMissingRecords(self):
300303
# the previous learn associates with bucket 0
301304
recordNum += 1
302305
inp.sparse = [2, 4, 6]
303-
result = c.infer(recordNum=recordNum, pattern=inp)
306+
result = c.infer(pattern=inp)
304307
c.learn(recordNum=recordNum, pattern=inp, classification=0)
305308
recordNum += 1
306309
self.assertLess(result[1][0], 0.1)
@@ -311,7 +314,7 @@ def testMissingRecords(self):
311314
# the previous learn associates with bucket 0
312315
recordNum += 1
313316
inp.sparse = [1, 3, 5]
314-
result = c.infer(recordNum=recordNum, pattern=inp)
317+
result = c.infer(pattern=inp)
315318
c.learn(recordNum=recordNum, pattern=inp, classification=0)
316319
recordNum += 1
317320
self.assertLess(result[1][0], 0.1)
@@ -548,8 +551,8 @@ def testMultiStepPredictions(self):
548551
c.learn(recordNum, pattern=SDR2, classification=1)
549552
recordNum += 1
550553

551-
result1 = c.infer(recordNum, SDR1)
552-
result2 = c.infer(recordNum, SDR2)
554+
result1 = c.infer(SDR1)
555+
result2 = c.infer(SDR2)
553556

554557
self.assertAlmostEqual(result1[0][0], 1.0, places=1)
555558
self.assertAlmostEqual(result1[0][1], 0.0, places=1)

src/htm/algorithms/SDRClassifier.cpp

+10-13
Original file line numberDiff line numberDiff line change
@@ -161,28 +161,27 @@ void Predictor::reset() {
161161
}
162162

163163

164-
Predictions Predictor::infer(const UInt recordNum, const SDR &pattern) const {
165-
checkMonotonic_(recordNum);
166-
164+
Predictions Predictor::infer(const SDR &pattern) const {
167165
Predictions result;
168166
for( const auto step : steps_ ) {
169-
result[step] = classifiers_.at(step).infer( pattern );
167+
result.insert({step, classifiers_.at(step).infer( pattern )});
170168
}
171169
return result;
172170
}
173171

174172

175-
void Predictor::learn(const UInt recordNum, const SDR &pattern,
173+
void Predictor::learn(const UInt recordNum, //TODO make recordNum optional, autoincrement as steps
174+
const SDR &pattern,
176175
const std::vector<UInt> &bucketIdxList)
177176
{
178177
checkMonotonic_(recordNum);
179178

180179
// Update pattern history if this is a new record.
181-
const UInt lastRecordNum = recordNumHistory_.empty() ? -1 : recordNumHistory_.back();
180+
const UInt lastRecordNum = recordNumHistory_.empty() ? 0 : recordNumHistory_.back();
182181
if (recordNumHistory_.size() == 0u || recordNum > lastRecordNum) {
183182
patternHistory_.emplace_back( pattern );
184183
recordNumHistory_.push_back(recordNum);
185-
if (patternHistory_.size() > steps_.back() + 1u) {
184+
if (patternHistory_.size() > steps_.back() + 1u) { //steps_ are sorted, so steps_.back() is the "oldest/deepest" N-th step (ie 10 of [1,2,10])
186185
patternHistory_.pop_front();
187186
recordNumHistory_.pop_front();
188187
}
@@ -191,22 +190,20 @@ void Predictor::learn(const UInt recordNum, const SDR &pattern,
191190
// Iterate through all recently given inputs, starting from the furthest in the past.
192191
auto pastPattern = patternHistory_.begin();
193192
auto pastRecordNum = recordNumHistory_.begin();
194-
for( ; pastRecordNum != recordNumHistory_.end(); pastPattern++, pastRecordNum++ )
193+
for( ; pastRecordNum != recordNumHistory_.cend(); pastPattern++, pastRecordNum++ )
195194
{
196195
const UInt nSteps = recordNum - *pastRecordNum;
197196

198197
// Update weights.
199198
if( binary_search( steps_.begin(), steps_.end(), nSteps )) {
200-
classifiers_[nSteps].learn( *pastPattern, bucketIdxList );
199+
classifiers_.at(nSteps).learn( *pastPattern, bucketIdxList );
201200
}
202201
}
203202
}
204203

205204

206205
void Predictor::checkMonotonic_(const UInt recordNum) const {
207206
// Ensure that recordNum increases monotonically.
208-
const UInt lastRecordNum = recordNumHistory_.empty() ? -1 : recordNumHistory_.back();
209-
if( not recordNumHistory_.empty() ) {
210-
NTA_CHECK(recordNum >= lastRecordNum) << "The record number must increase monotonically.";
211-
}
207+
const UInt lastRecordNum = recordNumHistory_.empty() ? 0 : recordNumHistory_.back();
208+
NTA_CHECK(recordNum >= lastRecordNum) << "The record number must increase monotonically.";
212209
}

src/htm/algorithms/SDRClassifier.hpp

+3-6
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,11 @@ using Predictions = std::unordered_map<UInt, PDF>;
218218
* // Give the predictor partial information, and make predictions
219219
* // about the future.
220220
* pred.reset();
221-
* Predictions A = pred.infer( 0, sequence[0] );
221+
* Predictions A = pred.infer( sequence[0] );
222222
* argmax( A[1] ) -> labels[1]
223223
* argmax( A[2] ) -> labels[2]
224224
*
225-
* Predictions B = pred.infer( 1, sequence[1] );
225+
* Predictions B = pred.infer( sequence[1] );
226226
* argmax( B[1] ) -> labels[2]
227227
* argmax( B[2] ) -> labels[3]
228228
* ```
@@ -254,14 +254,11 @@ class Predictor : public Serializable
254254
/**
255255
* Compute the likelihoods.
256256
*
257-
* @param recordNum: An incrementing integer for each record. Gaps in
258-
* numbers correspond to missing records.
259-
*
260257
* @param pattern: The active input SDR.
261258
*
262259
* @returns: A mapping from prediction step to PDF.
263260
*/
264-
Predictions infer(const UInt recordNum, const SDR &pattern) const;
261+
Predictions infer(const SDR &pattern) const;
265262

266263
/**
267264
* Learn from example data.

src/test/unit/algorithms/SDRClassifierTest.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ TEST(SDRClassifierTest, ExampleUsagePredictor)
8181
// Give the predictor partial information, and make predictions
8282
// about the future.
8383
pred.reset();
84-
Predictions A = pred.infer( 0, sequence[0] );
84+
Predictions A = pred.infer( sequence[0] );
8585
ASSERT_EQ( argmax( A[1] ), labels[1] );
8686
ASSERT_EQ( argmax( A[2] ), labels[2] );
8787

88-
Predictions B = pred.infer( 1, sequence[1] );
88+
Predictions B = pred.infer( sequence[1] );
8989
ASSERT_EQ( argmax( B[1] ), labels[2] );
9090
ASSERT_EQ( argmax( B[2] ), labels[3] );
9191
}
@@ -103,7 +103,7 @@ TEST(SDRClassifierTest, SingleValue) {
103103
for (UInt i = 0u; i < 10u; ++i) {
104104
c.learn( i, input1, bucketIdxList );
105105
}
106-
Predictions result1 = c.infer( 10u, input1 );
106+
Predictions result1 = c.infer( input1 );
107107

108108
ASSERT_EQ( argmax( result1[1u] ), 4u )
109109
<< "Incorrect prediction for bucket 4";
@@ -138,7 +138,7 @@ TEST(SDRClassifierTest, ComputeComplex) {
138138
c.learn(1, input2, bucketIdxList2);
139139
c.learn(2, input3, bucketIdxList3);
140140
c.learn(3, input1, bucketIdxList4);
141-
auto result = c.infer(4, input1);
141+
auto result = c.infer(input1);
142142

143143
// Check the one-step prediction
144144
ASSERT_EQ(result.size(), 1u)
@@ -211,7 +211,7 @@ TEST(SDRClassifierTest, SaveLoad) {
211211
// Measure and save some output.
212212
A.addNoise( 0.20f ); // Change two bits.
213213
c1.reset();
214-
const auto c1_out = c1.infer( 0u, A );
214+
const auto c1_out = c1.infer( A );
215215

216216
// Save and load.
217217
stringstream ss;
@@ -220,7 +220,7 @@ TEST(SDRClassifierTest, SaveLoad) {
220220
EXPECT_NO_THROW(c2.load(ss));
221221

222222
// Expect identical results.
223-
const auto c2_out = c2.infer( 0u, A );
223+
const auto c2_out = c2.infer( A );
224224
ASSERT_EQ(c1_out, c2_out);
225225
}
226226

0 commit comments

Comments
 (0)