Skip to content

Commit 4d8708d

Browse files
committed
Classifier: simplify asserts
1 parent 2aedac2 commit 4d8708d

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

src/htm/algorithms/SDRClassifier.cpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ void Classifier::initialize(const Real alpha)
3838
{
3939
NTA_CHECK(alpha > 0.0f);
4040
alpha_ = alpha;
41-
dimensions_.clear();
41+
dimensions_ = 0;
4242
numCategories_ = 0u;
4343
weights_.clear();
4444
}
@@ -47,19 +47,9 @@ void Classifier::initialize(const Real alpha)
4747
PDF Classifier::infer(const SDR & pattern) const {
4848
// Check input dimensions, or if this is the first time the Classifier is used and dimensions
4949
// are unset, return zeroes.
50-
if( dimensions_.empty() ) {
51-
return PDF(numCategories_, 0.0f); //empty
52-
} else if( pattern.dimensions != dimensions_ ) {
53-
stringstream err_msg;
54-
err_msg << "Classifier input SDR.dimensions mismatch: previously given SDR with dimensions ( ";
55-
for( auto dim : dimensions_ )
56-
{ err_msg << dim << " "; }
57-
err_msg << "), now given SDR with dimensions ( ";
58-
for( auto dim : pattern.dimensions )
59-
{ err_msg << dim << " "; }
60-
err_msg << ").";
61-
NTA_THROW << err_msg.str();
62-
}
50+
NTA_CHECK( dimensions_ != 0 )
51+
<< "Classifier: must call `learn` before `infer`.";
52+
NTA_ASSERT(pattern.size == dimensions_) << "Input SDR does not match previously seen size!";
6353

6454
// Accumulate feed forward input.
6555
PDF probabilities( numCategories_, 0.0f );
@@ -79,15 +69,17 @@ void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
7969
{
8070
// If this is the first time the Classifier is being used, weights are empty,
8171
// so we set the dimensions to that of the input `pattern`
82-
if( dimensions_.empty() ) {
83-
dimensions_ = pattern.dimensions;
72+
if( dimensions_ == 0 ) {
73+
dimensions_ = pattern.size;
8474
while( weights_.size() < pattern.size ) {
8575
const auto initialEmptyWeights = PDF( numCategories_, 0.0f );
8676
weights_.push_back( initialEmptyWeights );
8777
}
8878
}
79+
NTA_ASSERT(pattern.size == dimensions_) << "Input SDR does not match previously seen size!";
80+
8981
// Check if this is a new category & resize the weights table to hold it.
90-
const auto maxCategoryIdx = *max_element(categoryIdxList.begin(), categoryIdxList.end());
82+
const auto maxCategoryIdx = *max_element(categoryIdxList.cbegin(), categoryIdxList.cend());
9183
if( maxCategoryIdx >= numCategories_ ) {
9284
numCategories_ = maxCategoryIdx + 1;
9385
for( auto & vec : weights_ ) {

src/htm/algorithms/SDRClassifier.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class Classifier : public Serializable
155155

156156
private:
157157
Real alpha_;
158-
std::vector<UInt> dimensions_;
158+
UInt dimensions_;
159159
UInt numCategories_;
160160

161161
/**

0 commit comments

Comments
 (0)