@@ -38,32 +38,18 @@ void Classifier::initialize(const Real alpha)
 {
   NTA_CHECK(alpha > 0.0f);
   alpha_ = alpha;
-  dimensions_.clear();
+  dimensions_ = 0;
   numCategories_ = 0u;
   weights_.clear();
 }
 
 
-PDF Classifier::infer(const SDR & pattern)
-{
-  // Check input dimensions, or if this is the first time the Classifier has
-  // been used then initialize it with the given SDR's dimensions.
-  if( dimensions_.empty() ) {
-    dimensions_ = pattern.dimensions;
-    while( weights_.size() < pattern.size ) {
-      weights_.push_back( vector<Real>( numCategories_, 0.0f ));
-    }
-  } else if( pattern.dimensions != dimensions_ ) {
-    stringstream err_msg;
-    err_msg << "Classifier input SDR.dimensions mismatch: previously given SDR with dimensions (";
-    for( auto dim : dimensions_ )
-      { err_msg << dim << " "; }
-    err_msg << "), now given SDR with dimensions (";
-    for( auto dim : pattern.dimensions )
-      { err_msg << dim << " "; }
-    err_msg << ").";
-    NTA_THROW << err_msg.str();
-  }
+PDF Classifier::infer(const SDR & pattern) const {
+  // Check input dimensions. If the Classifier has never learned, dimensions_ is
+  // still unset (0) and inference cannot proceed.
+  NTA_CHECK( dimensions_ != 0 )
+      << "Classifier: must call `learn` before `infer`.";
+  NTA_ASSERT(pattern.size == dimensions_) << "Input SDR does not match previously seen size!";
 
   // Accumulate feed forward input.
   PDF probabilities( numCategories_, 0.0f );
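
With this change, `Classifier::infer` no longer lazily initializes itself from the first SDR it sees; `learn` must run first so that `dimensions_` is set. A minimal usage sketch of the new contract follows. The include paths, the `htm` namespace, the `Classifier(alpha)` constructor, and `SDR::randomize` are assumptions about the surrounding library; only the `learn`/`infer` signatures come from this diff.

#include <htm/algorithms/SDRClassifier.hpp>  // assumed header for Classifier / PDF
#include <htm/types/Sdr.hpp>                 // assumed header for SDR

using namespace htm;

int main() {
  Classifier clsr(0.001f);     // assumed alpha constructor
  SDR pattern({1000u});
  pattern.randomize(0.02f);    // assumed helper: random SDR at 2% sparsity

  // clsr.infer(pattern);      // would now fail NTA_CHECK: "must call `learn` before `infer`."

  clsr.learn(pattern, {4u});   // first learn() fixes dimensions_ and sizes the weights table
  const PDF likelihoods = clsr.infer(pattern);  // one entry per category seen so far
  return likelihoods.empty() ? 1 : 0;
}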
@@ -81,8 +67,19 @@ PDF Classifier::infer(const SDR & pattern)
 
 void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
 {
+  // If this is the first time the Classifier is being used, weights are empty,
+  // so we set the dimensions to that of the input `pattern`.
+  if( dimensions_ == 0 ) {
+    dimensions_ = pattern.size;
+    while( weights_.size() < pattern.size ) {
+      const auto initialEmptyWeights = PDF( numCategories_, 0.0f );
+      weights_.push_back( initialEmptyWeights );
+    }
+  }
+  NTA_ASSERT(pattern.size == dimensions_) << "Input SDR does not match previously seen size!";
+
   // Check if this is a new category & resize the weights table to hold it.
-  const auto maxCategoryIdx = *max_element(categoryIdxList.begin(), categoryIdxList.end());
+  const auto maxCategoryIdx = *max_element(categoryIdxList.cbegin(), categoryIdxList.cend());
   if( maxCategoryIdx >= numCategories_ ) {
     numCategories_ = maxCategoryIdx + 1;
     for( auto & vec : weights_ ) {
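
For readers skimming the diff: the weights table is one row per input bit and one column per category, i.e. `weights_[bit][category]`. The first `learn` call fixes the row count from `pattern.size`, and a previously unseen category index widens every row. A standalone sketch of those two growth paths, using plain `std::vector` rather than the library's `PDF` type:

#include <cstddef>
#include <vector>

// Illustrative stand-in for the weights table: rows = input bits, columns = categories.
using Weights = std::vector<std::vector<float>>;

void growWeights(Weights &w, const size_t patternSize,
                 const size_t maxCategoryIdx, size_t &numCategories) {
  while (w.size() < patternSize) {         // first learn(): one zeroed row per input bit
    w.emplace_back(numCategories, 0.0f);
  }
  if (maxCategoryIdx >= numCategories) {   // unseen category: widen every existing row
    numCategories = maxCategoryIdx + 1;
    for (auto &row : w) { row.resize(numCategories, 0.0f); }
  }
}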
@@ -93,7 +90,7 @@ void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
   }
 
   // Compute errors and update weights.
-  const vector<Real> error = calculateError_(categoryIdxList, pattern);
+  const auto & error = calculateError_(categoryIdxList, pattern);
   for( const auto & bit : pattern.getSparse() ) {
     for(size_t i = 0u; i < numCategories_; i++) {
       weights_[bit][i] += alpha_ * error[i];
@@ -103,9 +100,8 @@ void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
 
 
 // Helper function to compute the error signal in learning.
-std::vector<Real> Classifier::calculateError_(
-                  const std::vector<UInt> &categoryIdxList, const SDR &pattern)
-{
+std::vector<Real64> Classifier::calculateError_(const std::vector<UInt> &categoryIdxList,
+                                                const SDR &pattern) const {
   // compute predicted likelihoods
   auto likelihoods = infer(pattern);
 
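
To make the update loop above concrete: `calculateError_` compares the predicted likelihoods from `infer` against the target categories, and every active input bit then moves its weights by `alpha_ * error[i]`. The sketch below restates that with plain vectors; the exact target encoding inside `calculateError_` is an assumption (here, the target mass is split evenly across `categoryIdxList`), and only the `weights_[bit][i] += alpha_ * error[i]` step is taken directly from the diff.

#include <cstddef>
#include <vector>

// Assumed error shape: target distribution minus predicted likelihoods.
std::vector<double> sketchError(const std::vector<double> &likelihoods,
                                const std::vector<unsigned> &categoryIdxList) {
  std::vector<double> error(likelihoods.size(), 0.0);
  const double mass = 1.0 / static_cast<double>(categoryIdxList.size());
  for (const auto idx : categoryIdxList) { error[idx] += mass; }
  for (size_t i = 0; i < error.size(); i++) { error[i] -= likelihoods[i]; }
  return error;
}

// The weight update from the diff, restated: only active (sparse) bits are touched.
void sketchUpdate(std::vector<std::vector<double>> &weights,
                  const std::vector<unsigned> &activeBits,
                  const std::vector<double> &error, const double alpha) {
  for (const auto bit : activeBits) {
    for (size_t i = 0; i < error.size(); i++) { weights[bit][i] += alpha * error[i]; }
  }
}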
@@ -165,56 +161,49 @@ void Predictor::reset() {
 }
 
 
-Predictions Predictor::infer(const UInt recordNum, const SDR &pattern)
-{
-  updateHistory_( recordNum, pattern );
-
+Predictions Predictor::infer(const SDR &pattern) const {
   Predictions result;
   for( const auto step : steps_ ) {
-    result[step] = classifiers_[step].infer( pattern );
+    result.insert({step, classifiers_.at(step).infer( pattern )});
   }
   return result;
 }
 
 
-void Predictor::learn(const UInt recordNum, const SDR &pattern,
+void Predictor::learn(const UInt recordNum, // TODO make recordNum optional, autoincrement as steps
+                      const SDR &pattern,
                       const std::vector<UInt> &bucketIdxList)
 {
-  updateHistory_( recordNum, pattern );
+  checkMonotonic_(recordNum);
+
+  // Update pattern history if this is a new record.
+  const UInt lastRecordNum = recordNumHistory_.empty() ? 0 : recordNumHistory_.back();
+  if (recordNumHistory_.size() == 0u || recordNum > lastRecordNum) {
+    patternHistory_.emplace_back( pattern );
+    recordNumHistory_.push_back(recordNum);
+    if (patternHistory_.size() > steps_.back() + 1u) { // steps_ are sorted, so steps_.back() is the "oldest/deepest" N-th step (ie 10 of [1,2,10])
+      patternHistory_.pop_front();
+      recordNumHistory_.pop_front();
+    }
+  }
 
   // Iterate through all recently given inputs, starting from the furthest in the past.
   auto pastPattern   = patternHistory_.begin();
   auto pastRecordNum = recordNumHistory_.begin();
-  for( ; pastRecordNum != recordNumHistory_.end(); pastPattern++, pastRecordNum++ )
+  for( ; pastRecordNum != recordNumHistory_.cend(); pastPattern++, pastRecordNum++ )
   {
     const UInt nSteps = recordNum - *pastRecordNum;
 
     // Update weights.
     if( binary_search( steps_.begin(), steps_.end(), nSteps )) {
-      classifiers_[nSteps].learn( *pastPattern, bucketIdxList );
+      classifiers_.at(nSteps).learn( *pastPattern, bucketIdxList );
    }
   }
 }
 
 
-void Predictor::updateHistory_(const UInt recordNum, const SDR & pattern)
-{
+void Predictor::checkMonotonic_(const UInt recordNum) const {
   // Ensure that recordNum increases monotonically.
-  UInt lastRecordNum = -1;
-  if( not recordNumHistory_.empty() ) {
-    lastRecordNum = recordNumHistory_.back();
-    if (recordNum < lastRecordNum) {
-      NTA_THROW << "The record number must increase monotonically.";
-    }
-  }
-
-  // Update pattern history if this is a new record.
-  if (recordNumHistory_.size() == 0u || recordNum > lastRecordNum) {
-    patternHistory_.emplace_back( pattern );
-    recordNumHistory_.push_back(recordNum);
-    if (patternHistory_.size() > steps_.back() + 1u) {
-      patternHistory_.pop_front();
-      recordNumHistory_.pop_front();
-    }
-  }
+  const UInt lastRecordNum = recordNumHistory_.empty() ? 0 : recordNumHistory_.back();
+  NTA_CHECK(recordNum >= lastRecordNum) << "The record number must increase monotonically.";
 }
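
Putting the Predictor changes together: `learn` still takes a `recordNum` (which must not decrease, enforced by `checkMonotonic_`), while `infer` is now const, history-free, and takes only the SDR. A hedged end-to-end sketch; the headers, `htm` namespace, `Predictor(steps, alpha)` constructor, and `SDR::randomize` are assumptions, and only the `learn`/`infer` signatures are taken from this diff.

#include <htm/algorithms/SDRClassifier.hpp>  // assumed header for Predictor / Predictions
#include <htm/types/Sdr.hpp>                 // assumed header for SDR
#include <vector>

using namespace htm;

int main() {
  Predictor pred(std::vector<UInt>{1u, 2u}, 0.001f);  // assumed: predict 1 and 2 steps ahead

  SDR pattern({1000u});
  for (UInt recordNum = 0u; recordNum < 10u; recordNum++) {
    pattern.randomize(0.02f);
    pred.learn(recordNum, pattern, {recordNum % 4u});  // recordNum must increase monotonically
  }

  // No recordNum here any more: inference no longer touches the history.
  const Predictions preds = pred.infer(pattern);       // keyed by steps ahead, e.g. preds.at(1)
  return preds.empty() ? 1 : 0;
}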