Skip to content

Commit 0bc71e1

Browse files
authored
Merge pull request #827 from htm-community/seeds_connections
Connections: serialization, op==
2 parents 5d36c3d + ca47f64 commit 0bc71e1

File tree

2 files changed

+174
-85
lines changed

2 files changed

+174
-85
lines changed

src/htm/algorithms/Connections.cpp

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -763,49 +763,56 @@ std::ostream& operator<< (std::ostream& stream, const Connections& self)
763763

764764

765765

766-
bool Connections::operator==(const Connections &other) const {
767-
if (cells_.size() != other.cells_.size())
768-
return false;
766+
bool Connections::operator==(const Connections &o) const {
767+
try {
768+
NTA_CHECK (cells_.size() == o.cells_.size()) << "Connections equals: cells_" << cells_.size() << " vs. " << o.cells_.size();
769+
NTA_CHECK (cells_ == o.cells_) << "Connections equals: cells_" << cells_.size() << " vs. " << o.cells_.size();
770+
771+
NTA_CHECK (segments_ == o.segments_ ) << "Connections equals: segments_";
772+
NTA_CHECK (destroyedSegments_ == o.destroyedSegments_ ) << "Connections equals: destroyedSegments_";
773+
774+
NTA_CHECK (synapses_ == o.synapses_ ) << "Connections equals: synapses_";
775+
NTA_CHECK (destroyedSynapses_ == o.destroyedSynapses_ ) << "Connections equals: destroyedSynapses_";
776+
777+
778+
//also check underlying datastructures (segments, and subsequently synapses). Can be time consuming.
779+
//1.cells:
780+
for(const auto cellD : cells_) {
781+
//2.segments:
782+
const auto& segments = cellD.segments;
783+
for(const auto seg : segments) {
784+
NTA_CHECK( dataForSegment(seg) == o.dataForSegment(seg) ) << "CellData equals: segmentData";
785+
//3.synapses:
786+
const auto& synapses = dataForSegment(seg).synapses;
787+
for(const auto syn : synapses) {
788+
NTA_CHECK(dataForSynapse(syn) == o.dataForSynapse(syn) ) << "SegmentData equals: synapseData";
789+
}
790+
}
791+
}
769792

770-
if(iteration_ != other.iteration_) return false;
771793

772-
for (CellIdx i = 0; i < static_cast<CellIdx>(cells_.size()); i++) {
773-
const CellData &cellData = cells_[i];
774-
const CellData &otherCellData = other.cells_[i];
794+
NTA_CHECK (connectedThreshold_ == o.connectedThreshold_ ) << "Connections equals: connectedThreshold_";
795+
NTA_CHECK (iteration_ == o.iteration_ ) << "Connections equals: iteration_";
775796

776-
if (cellData.segments.size() != otherCellData.segments.size()) {
777-
return false;
778-
}
797+
NTA_CHECK(potentialSynapsesForPresynapticCell_ == o.potentialSynapsesForPresynapticCell_);
798+
NTA_CHECK(connectedSynapsesForPresynapticCell_ == o.connectedSynapsesForPresynapticCell_);
799+
NTA_CHECK(potentialSegmentsForPresynapticCell_ == o.potentialSegmentsForPresynapticCell_);
800+
NTA_CHECK(connectedSegmentsForPresynapticCell_ == o.connectedSegmentsForPresynapticCell_);
779801

780-
for (SegmentIdx j = 0; j < static_cast<SegmentIdx>(cellData.segments.size()); j++) {
781-
const Segment segment = cellData.segments[j];
782-
const SegmentData &segmentData = segments_[segment];
783-
const Segment otherSegment = otherCellData.segments[j];
784-
const SegmentData &otherSegmentData = other.segments_[otherSegment];
802+
NTA_CHECK (nextSegmentOrdinal_ == o.nextSegmentOrdinal_ ) << "Connections equals: nextSegmentOrdinal_";
803+
NTA_CHECK (nextSynapseOrdinal_ == o.nextSynapseOrdinal_ ) << "Connections equals: nextSynapseOrdinal_";
785804

786-
if (segmentData.synapses.size() != otherSegmentData.synapses.size() ||
787-
segmentData.cell != otherSegmentData.cell) {
788-
return false;
789-
}
805+
NTA_CHECK (timeseries_ == o.timeseries_ ) << "Connections equals: timeseries_";
806+
NTA_CHECK (previousUpdates_ == o.previousUpdates_ ) << "Connections equals: previousUpdates_";
807+
NTA_CHECK (currentUpdates_ == o.currentUpdates_ ) << "Connections equals: currentUpdates_";
790808

791-
for (SynapseIdx k = 0; k < static_cast<SynapseIdx>(segmentData.synapses.size()); k++) {
792-
const Synapse synapse = segmentData.synapses[k];
793-
const SynapseData &synapseData = synapses_[synapse];
794-
const Synapse otherSynapse = otherSegmentData.synapses[k];
795-
const SynapseData &otherSynapseData = other.synapses_[otherSynapse];
809+
NTA_CHECK (prunedSyns_ == o.prunedSyns_ ) << "Connections equals: prunedSyns_";
810+
NTA_CHECK (prunedSegs_ == o.prunedSegs_ ) << "Connections equals: prunedSegs_";
796811

797-
if (synapseData.presynapticCell != otherSynapseData.presynapticCell ||
798-
synapseData.permanence != otherSynapseData.permanence) {
799-
return false;
800-
}
801-
802-
// Two functionally identical instances may have different flatIdxs.
803-
NTA_ASSERT(synapseData.segment == segment);
804-
NTA_ASSERT(otherSynapseData.segment == otherSegment);
805-
}
806-
}
812+
} catch(const htm::Exception& ex) {
813+
std::cout << "Connection equals: differ! " << ex.what();
814+
return false;
807815
}
808-
809816
return true;
810817
}
811818

src/htm/algorithms/Connections.hpp

Lines changed: 132 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,40 @@ struct SynapseData: public Serializable {
6969

7070
SynapseData() {}
7171

72+
//Serialization
7273
CerealAdapter;
7374
template<class Archive>
7475
void save_ar(Archive & ar) const {
75-
ar(cereal::make_nvp("perm", permanence),
76-
cereal::make_nvp("presyn", presynapticCell));
76+
ar(CEREAL_NVP(permanence),
77+
CEREAL_NVP(presynapticCell),
78+
CEREAL_NVP(segment),
79+
CEREAL_NVP(presynapticMapIndex_),
80+
CEREAL_NVP(id)
81+
);
7782
}
7883
template<class Archive>
7984
void load_ar(Archive & ar) {
80-
ar( permanence, presynapticCell);
85+
ar( permanence, presynapticCell, segment, presynapticMapIndex_, id);
8186
}
8287

88+
//operator==
89+
bool operator==(const SynapseData& o) const {
90+
try {
91+
NTA_CHECK(presynapticCell == o.presynapticCell ) << "SynapseData equals: presynapticCell";
92+
NTA_CHECK(permanence == o.permanence ) << "SynapseData equals: permanence";
93+
NTA_CHECK(segment == o.segment ) << "SynapseData equals: segment";
94+
NTA_CHECK(presynapticMapIndex_ == o.presynapticMapIndex_ ) << "SynapseData equals: presynapticMapIndex_";
95+
NTA_CHECK(id == o.id ) << "SynapseData equals: id";
96+
} catch(const htm::Exception& ex) {
97+
//NTA_WARN << "SynapseData equals: " << ex.what(); //Note: uncomment for debug, tells you
98+
//where the diff is. It's perfectly OK for the "exception" to occur, as it just denotes
99+
//that the data is NOT equal.
100+
return false;
101+
}
102+
return true;
103+
}
104+
inline bool operator!=(const SynapseData& o) const { return !operator==(o); }
105+
83106
};
84107

85108
/**
@@ -94,14 +117,48 @@ struct SynapseData: public Serializable {
94117
* @param cell
95118
* The cell that this segment is on.
96119
*/
97-
struct SegmentData {
120+
struct SegmentData: public Serializable {
98121
SegmentData(const CellIdx cell, Segment id, UInt32 lastUsed = 0) : cell(cell), numConnected(0), lastUsed(lastUsed), id(id) {} //default constructor
99122

100123
std::vector<Synapse> synapses;
101124
CellIdx cell; //mother cell that this segment originates from
102125
SynapseIdx numConnected; //number of permanences from `synapses` that are >= synPermConnected, ie connected synapses
103126
UInt32 lastUsed = 0; //last used time (iteration). Used for segment pruning by "least recently used" (LRU) in `createSegment`
104127
Segment id;
128+
129+
//Serialize
130+
SegmentData() {}; //empty constructor for serialization, do not use
131+
CerealAdapter;
132+
template<class Archive>
133+
void save_ar(Archive & ar) const {
134+
ar(CEREAL_NVP(synapses),
135+
CEREAL_NVP(cell),
136+
CEREAL_NVP(numConnected),
137+
CEREAL_NVP(lastUsed),
138+
CEREAL_NVP(id)
139+
);
140+
}
141+
template<class Archive>
142+
void load_ar(Archive & ar) {
143+
ar( synapses, cell, numConnected, lastUsed, id);
144+
}
145+
146+
//equals op==
147+
bool operator==(const SegmentData& o) const {
148+
try {
149+
NTA_CHECK(synapses == o.synapses) << "SegmentData equals: synapses";
150+
NTA_CHECK(cell == o.cell) << "SegmentData equals: cell";
151+
NTA_CHECK(numConnected == o.numConnected) << "SegmentData equals: numConnected";
152+
NTA_CHECK(lastUsed == o.lastUsed) << "SegmentData equals: lastUsed";
153+
NTA_CHECK(id == o.id) << "SegmentData equals: id";
154+
155+
} catch(const htm::Exception& ex) {
156+
//NTA_WARN << "SegmentData equals: " << ex.what();
157+
return false;
158+
}
159+
return true;
160+
}
161+
inline bool operator!=(const SegmentData& o) const { return !operator==(o); }
105162
};
106163

107164
/**
@@ -115,10 +172,35 @@ struct SegmentData {
115172
* Segments on this cell.
116173
*
117174
*/
118-
struct CellData {
175+
struct CellData : public Serializable {
119176
std::vector<Segment> segments;
177+
178+
//Serialization
179+
CerealAdapter;
180+
template<class Archive>
181+
void save_ar(Archive & ar) const {
182+
ar(CEREAL_NVP(segments)
183+
);
184+
}
185+
template<class Archive>
186+
void load_ar(Archive & ar) {
187+
ar( segments);
188+
}
189+
190+
//operator==
191+
bool operator==(const CellData& o) const {
192+
try {
193+
NTA_CHECK( segments == o.segments ) << "CellData equals: segments";
194+
} catch(const htm::Exception& ex) {
195+
//NTA_WARN << "CellData equals: " << ex.what();
196+
return false;
197+
}
198+
return true;
199+
}
200+
inline bool operator!=(const CellData& o) const { return !operator==(o); }
120201
};
121202

203+
122204
/**
123205
* A base class for Connections event handlers.
124206
*
@@ -557,58 +639,58 @@ class Connections : public Serializable
557639
CerealAdapter;
558640
template<class Archive>
559641
void save_ar(Archive & ar) const {
560-
// make this look like a queue of items to be sent.
561-
// and a queue of sizes so we can distribute the
562-
// correct number for each level when deserializing.
563-
std::deque<SynapseData> syndata;
564-
std::deque<size_t> sizes;
565-
sizes.push_back(cells_.size());
566-
for (CellData cellData : cells_) {
567-
const std::vector<Segment> &segments = cellData.segments;
568-
sizes.push_back(segments.size());
569-
for (Segment segment : segments) {
570-
const SegmentData &segmentData = segments_[segment];
571-
const std::vector<Synapse> &synapses = segmentData.synapses;
572-
sizes.push_back(synapses.size());
573-
for (Synapse synapse : synapses) {
574-
const SynapseData &synapseData = synapses_[synapse];
575-
syndata.push_back(synapseData);
576-
}
577-
}
578-
}
579642
ar(CEREAL_NVP(connectedThreshold_));
580-
//the following member must not be serialized (so is set to =0).
581-
//That is because of we serialize only active segments & synapses,
582-
//excluding the "destroyed", so those fields start empty.
583-
//! ar(CEREAL_NVP(destroyedSegments_));
584-
ar(CEREAL_NVP(sizes));
585-
ar(CEREAL_NVP(syndata));
586643
ar(CEREAL_NVP(iteration_));
644+
ar(CEREAL_NVP(cells_));
645+
ar(CEREAL_NVP(segments_));
646+
ar(CEREAL_NVP(synapses_));
647+
648+
ar(CEREAL_NVP(destroyedSynapses_));
649+
ar(CEREAL_NVP(destroyedSegments_));
650+
651+
ar(CEREAL_NVP(potentialSynapsesForPresynapticCell_));
652+
ar(CEREAL_NVP(connectedSynapsesForPresynapticCell_));
653+
ar(CEREAL_NVP(potentialSegmentsForPresynapticCell_));
654+
ar(CEREAL_NVP(connectedSegmentsForPresynapticCell_));
655+
656+
ar(CEREAL_NVP(nextSegmentOrdinal_));
657+
ar(CEREAL_NVP(nextSynapseOrdinal_));
658+
659+
ar(CEREAL_NVP(timeseries_));
660+
ar(CEREAL_NVP(previousUpdates_));
661+
ar(CEREAL_NVP(currentUpdates_));
662+
663+
ar(CEREAL_NVP(prunedSyns_));
664+
ar(CEREAL_NVP(prunedSegs_));
587665
}
588666

589667
template<class Archive>
590668
void load_ar(Archive & ar) {
591-
std::deque<size_t> sizes;
592-
std::deque<SynapseData> syndata;
593669
ar(CEREAL_NVP(connectedThreshold_));
594-
ar(CEREAL_NVP(sizes));
595-
ar(CEREAL_NVP(syndata));
596-
597-
CellIdx numCells = static_cast<CellIdx>(sizes.front()); sizes.pop_front();
598-
initialize(numCells, connectedThreshold_);
599-
for (UInt cell = 0; cell < numCells; cell++) {
600-
size_t numSegments = sizes.front(); sizes.pop_front();
601-
for (SegmentIdx j = 0; j < static_cast<SegmentIdx>(numSegments); j++) {
602-
Segment segment = createSegment( cell );
603-
604-
size_t numSynapses = sizes.front(); sizes.pop_front();
605-
for (SynapseIdx k = 0; k < static_cast<SynapseIdx>(numSynapses); k++) {
606-
SynapseData& syn = syndata.front(); syndata.pop_front();
607-
createSynapse( segment, syn.presynapticCell, syn.permanence );
608-
}
609-
}
610-
}
611670
ar(CEREAL_NVP(iteration_));
671+
//!initialize(numCells, connectedThreshold_); //initialize Connections //Note: we actually don't call Connections
672+
//initialize() as all the members are de/serialized.
673+
ar(CEREAL_NVP(cells_));
674+
ar(CEREAL_NVP(segments_));
675+
ar(CEREAL_NVP(synapses_));
676+
677+
ar(CEREAL_NVP(destroyedSynapses_));
678+
ar(CEREAL_NVP(destroyedSegments_));
679+
680+
ar(CEREAL_NVP(potentialSynapsesForPresynapticCell_));
681+
ar(CEREAL_NVP(connectedSynapsesForPresynapticCell_));
682+
ar(CEREAL_NVP(potentialSegmentsForPresynapticCell_));
683+
ar(CEREAL_NVP(connectedSegmentsForPresynapticCell_));
684+
685+
ar(CEREAL_NVP(nextSegmentOrdinal_));
686+
ar(CEREAL_NVP(nextSynapseOrdinal_));
687+
688+
ar(CEREAL_NVP(timeseries_));
689+
ar(CEREAL_NVP(previousUpdates_));
690+
ar(CEREAL_NVP(currentUpdates_));
691+
692+
ar(CEREAL_NVP(prunedSyns_));
693+
ar(CEREAL_NVP(prunedSegs_));
612694
}
613695

614696
/**
@@ -771,7 +853,7 @@ class Connections : public Serializable
771853
Synapse prunedSyns_ = 0; //how many synapses have been removed?
772854
Segment prunedSegs_ = 0;
773855

774-
//for listeners
856+
//for listeners //TODO listeners are not serialized, nor included in equals ==
775857
UInt32 nextEventToken_;
776858
std::map<UInt32, ConnectionsEventHandler *> eventHandlers_;
777859
}; // end class Connections

0 commit comments

Comments
 (0)