@@ -38,7 +38,7 @@ void Classifier::initialize(const Real alpha)
38
38
{
39
39
NTA_CHECK (alpha > 0 .0f );
40
40
alpha_ = alpha;
41
- dimensions_. clear () ;
41
+ dimensions_ = 0 ;
42
42
numCategories_ = 0u ;
43
43
weights_.clear ();
44
44
}
@@ -47,19 +47,9 @@ void Classifier::initialize(const Real alpha)
47
47
PDF Classifier::infer (const SDR & pattern) const {
48
48
// Check input dimensions, or if this is the first time the Classifier is used and dimensions
49
49
// are unset, return zeroes.
50
- if ( dimensions_.empty () ) {
51
- return PDF (numCategories_, 0 .0f ); // empty
52
- } else if ( pattern.dimensions != dimensions_ ) {
53
- stringstream err_msg;
54
- err_msg << " Classifier input SDR.dimensions mismatch: previously given SDR with dimensions ( " ;
55
- for ( auto dim : dimensions_ )
56
- { err_msg << dim << " " ; }
57
- err_msg << " ), now given SDR with dimensions ( " ;
58
- for ( auto dim : pattern.dimensions )
59
- { err_msg << dim << " " ; }
60
- err_msg << " )." ;
61
- NTA_THROW << err_msg.str ();
62
- }
50
+ NTA_CHECK ( dimensions_ != 0 )
51
+ << " Classifier: must call `learn` before `infer`." ;
52
+ NTA_ASSERT (pattern.size == dimensions_) << " Input SDR does not match previously seen size!" ;
63
53
64
54
// Accumulate feed forward input.
65
55
PDF probabilities ( numCategories_, 0 .0f );
@@ -79,15 +69,17 @@ void Classifier::learn(const SDR &pattern, const vector<UInt> &categoryIdxList)
79
69
{
80
70
// If this is the first time the Classifier is being used, weights are empty,
81
71
// so we set the dimensions to that of the input `pattern`
82
- if ( dimensions_. empty () ) {
83
- dimensions_ = pattern.dimensions ;
72
+ if ( dimensions_ == 0 ) {
73
+ dimensions_ = pattern.size ;
84
74
while ( weights_.size () < pattern.size ) {
85
75
const auto initialEmptyWeights = PDF ( numCategories_, 0 .0f );
86
76
weights_.push_back ( initialEmptyWeights );
87
77
}
88
78
}
79
+ NTA_ASSERT (pattern.size == dimensions_) << " Input SDR does not match previously seen size!" ;
80
+
89
81
// Check if this is a new category & resize the weights table to hold it.
90
- const auto maxCategoryIdx = *max_element (categoryIdxList.begin (), categoryIdxList.end ());
82
+ const auto maxCategoryIdx = *max_element (categoryIdxList.cbegin (), categoryIdxList.cend ());
91
83
if ( maxCategoryIdx >= numCategories_ ) {
92
84
numCategories_ = maxCategoryIdx + 1 ;
93
85
for ( auto & vec : weights_ ) {
0 commit comments