Skip to content

Commit 6475fba

Browse files
authored
Merge pull request #537 from htm-community/tm_conn
TM Connections integration
2 parents 8e11cdb + 24983cf commit 6475fba

File tree

10 files changed

+336
-278
lines changed

10 files changed

+336
-278
lines changed

bindings/py/cpp_src/bindings/algorithms/py_Connections.cpp

+19-9
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,14 @@ R"(Compatibility Warning: This classes API is unstable and may change without wa
4949
[](const Connections &self) { return self.getConnectedThreshold(); });
5050

5151
py_Connections.def("createSegment", &Connections::createSegment,
52-
py::arg("cell"));
52+
py::arg("cell"),
53+
py::arg("maxSegmentsPerCell") = 0
54+
);
5355

5456
py_Connections.def("destroySegment", &Connections::destroySegment);
5557

58+
py_Connections.def("iteration", &Connections::iteration);
59+
5660
py_Connections.def("createSynapse", &Connections::createSynapse,
5761
py::arg("segment"),
5862
py::arg("presynaticCell"),
@@ -93,24 +97,26 @@ R"(Compatibility Warning: This classes API is unstable and may change without wa
9397
py_Connections.def("reset", &Connections::reset);
9498

9599
py_Connections.def("computeActivity",
96-
[](Connections &self, SDR &activePresynapticCells) {
100+
[](Connections &self, SDR &activePresynapticCells, bool learn=true) {
97101
// Allocate buffer to return & make a python destructor object for it.
98102
auto activeConnectedSynapses =
99103
new std::vector<SynapseIdx>( self.segmentFlatListLength(), 0u );
100104
auto destructor = py::capsule( activeConnectedSynapses,
101105
[](void *dataPtr) {
102106
delete reinterpret_cast<std::vector<SynapseIdx>*>(dataPtr); });
103-
// Call the C++ method.
104-
self.computeActivity(*activeConnectedSynapses, activePresynapticCells.getSparse());
105-
// Wrap vector in numpy array.
107+
108+
// Call the C++ method.
109+
self.computeActivity(*activeConnectedSynapses, activePresynapticCells.getSparse(), learn);
110+
111+
// Wrap vector in numpy array.
106112
return py::array(activeConnectedSynapses->size(),
107113
activeConnectedSynapses->data(),
108114
destructor);
109115
},
110116
R"(Returns numActiveConnectedSynapsesForSegment)");
111117

112118
py_Connections.def("computeActivityFull",
113-
[](Connections &self, SDR &activePresynapticCells) {
119+
[](Connections &self, SDR &activePresynapticCells, bool learn=true) {
114120
// Allocate buffer to return & make a python destructor object for it.
115121
auto activeConnectedSynapses =
116122
new std::vector<SynapseIdx>( self.segmentFlatListLength(), 0u );
@@ -123,9 +129,13 @@ R"(Returns numActiveConnectedSynapsesForSegment)");
123129
auto potentialDestructor = py::capsule( activePotentialSynapses,
124130
[](void *dataPtr) {
125131
delete reinterpret_cast<std::vector<SynapseIdx>*>(dataPtr); });
126-
// Call the C++ method.
127-
self.computeActivity(*activeConnectedSynapses, *activePotentialSynapses,
128-
activePresynapticCells.getSparse());
132+
133+
// Call the C++ method.
134+
self.computeActivity(*activeConnectedSynapses,
135+
*activePotentialSynapses,
136+
activePresynapticCells.getSparse(),
137+
learn);
138+
129139
// Wrap vector in numpy array.
130140
return py::make_tuple(
131141
py::array(activeConnectedSynapses->size(),

src/examples/hotgym/HelloSPTP.cpp

+24-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "htm/types/Sdr.hpp"
3030
#include "htm/utils/Random.hpp"
3131
#include "htm/utils/MovingAverage.hpp"
32+
#include "htm/utils/SdrMetrics.hpp"
3233

3334
namespace examples {
3435

@@ -79,6 +80,13 @@ Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool
7980
SDR outTM(spGlobal.getColumnDimensions());
8081
Real an = 0.0f, anLikely = 0.0f; //for anomaly:
8182
MovingAverage avgAnom10(1000); //chose the window large enough so there's (some) periodicity in the patter, so TM can learn something
83+
84+
//metrics
85+
Metrics statsInput(input, 1000);
86+
Metrics statsSPlocal(outSPlocal, 1000);
87+
Metrics statsSPglobal(outSPglobal, 1000);
88+
Metrics statsTM(outTM, 1000);
89+
8290
/*
8391
* For example: fn = sin(x) -> periodic >= 2Pi ~ 6.3 && x+=0.01 -> 630 steps to 1st period -> window >= 630
8492
*/
@@ -147,13 +155,26 @@ Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool
147155
if (e == EPOCHS - 1) {
148156
tAll.stop();
149157

158+
//print connections stats
159+
cout << "\nInput :\n" << statsInput
160+
<< "\nSP(local) " << spLocal.connections
161+
<< "\nSP(local) " << statsSPlocal
162+
<< "\nSP(global) " << spGlobal.connections
163+
<< "\nSP(global) " << statsSPglobal
164+
<< "\nTM " << tm.connections
165+
<< "\nTM " << statsTM
166+
<< "\n";
167+
168+
// output values
150169
cout << "Epoch = " << e << endl;
151170
cout << "Anomaly = " << an << endl;
152171
cout << "Anomaly (avg) = " << avgAnom10.getCurrentAvg() << endl;
153172
cout << "Anomaly (Likelihood) = " << anLikely << endl;
154173
cout << "SP (g)= " << outSP << endl;
155174
cout << "SP (l)= " << outSPlocal <<endl;
156175
cout << "TM= " << outTM << endl;
176+
177+
//timers
157178
cout << "==============TIMERS============" << endl;
158179
cout << "Init:\t" << tInit.getElapsed() << endl;
159180
cout << "Random:\t" << tRng.getElapsed() << endl;
@@ -184,12 +205,12 @@ Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool
184205

185206
SDR goldTM({COLS});
186207
const SDR_sparse_t deterministicTM{
187-
51, 62, 72, 77, 102, 155, 287, 306, 337, 340, 370, 493, 542, 952, 1089, 1110, 1115, 1193, 1463, 1488, 1507, 1518, 1547, 1626, 1668, 1694, 1781, 1803, 1805, 1827, 1841, 1858,1859, 1860, 1861, 1862, 1878, 1881, 1915, 1918, 1923, 1929, 1933, 1939, 1941, 1953, 1955, 1956, 1958, 1961, 1965, 1968, 1975, 1976, 1980, 1981, 1985, 1986, 1987, 1991, 1992, 1994, 1997, 2002, 2006, 2008, 2012, 2013, 2040, 2042
208+
62, 77, 85, 322, 340, 432, 952, 1120, 1488, 1502, 1512, 1518, 1547, 1627, 1633, 1668, 1727, 1729, 1797, 1803, 1805, 1812, 1858, 1859, 1896, 1918, 1923, 1925, 1929, 1931, 1939, 1941, 1942, 1944, 1950, 1953, 1955, 1956, 1965, 1966, 1967, 1968, 1974, 1980, 1987, 1996, 2006, 2008, 2011, 2027, 2030, 2042, 2046
188209
};
189210
goldTM.setSparse(deterministicTM);
190211

191-
const float goldAn = 0.745098f;
192-
const float goldAnAvg = 0.408286f;
212+
const float goldAn = 0.627451f;
213+
const float goldAnAvg = 0.407265f;
193214

194215
if(EPOCHS == 5000) { //these hand-written values are only valid for EPOCHS = 5000 (default), but not for debug and custom runs.
195216
NTA_CHECK(input == goldEnc) << "Deterministic output of Encoder failed!\n" << input << "should be:\n" << goldEnc;

0 commit comments

Comments
 (0)