Skip to content

Commit 13446b0

Browse files
authored
Merge pull request #566 from htm-community/hotgym-fix
Fix Hot Gym Python Example
2 parents 438228a + b5a4cfe commit 13446b0

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

py/htm/examples/hotgym.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
'sdrc_alpha': 0.1,
2222
'sp': {'boostStrength': 3.0,
2323
'columnCount': 1638,
24-
'numActiveColumnsPerInhArea': 72,
24+
'localAreaDensity': 0.04395604395604396,
2525
'potentialPct': 0.85,
2626
'synPermActiveInc': 0.04,
2727
'synPermConnected': 0.13999999999999999,
@@ -63,8 +63,7 @@ def main(parameters=default_parameters, argv=None, verbose=True):
6363
potentialPct = spParams["potentialPct"],
6464
potentialRadius = encodingWidth,
6565
globalInhibition = True,
66-
localAreaDensity = -1,
67-
numActiveColumnsPerInhArea = spParams["numActiveColumnsPerInhArea"],
66+
localAreaDensity = spParams["localAreaDensity"],
6867
synPermInactiveDec = spParams["synPermInactiveDec"],
6968
synPermActiveInc = spParams["synPermActiveInc"],
7069
synPermConnected = spParams["synPermConnected"],

py/htm/examples/mnist.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
import argparse
1919
import random
20-
import gzip
2120
import numpy as np
2221
import os
22+
import sys
2323

2424
from htm.bindings.algorithms import SpatialPooler, Classifier
2525
from htm.bindings.sdr import SDR, Metrics
@@ -36,7 +36,7 @@ def int32(b):
3636
return i
3737

3838
def load_labels(file_name):
39-
with gzip.open(file_name, 'rb') as f:
39+
with open(file_name, 'rb') as f:
4040
raw = f.read()
4141
assert(int32(raw[0:4]) == 2049) # Magic number
4242
labels = []
@@ -46,7 +46,7 @@ def load_labels(file_name):
4646
return labels
4747

4848
def load_images(file_name):
49-
with gzip.open(file_name, 'rb') as f:
49+
with open(file_name, 'rb') as f:
5050
raw = f.read()
5151
assert(int32(raw[0:4]) == 2051) # Magic number
5252
num_imgs = int32(raw[4:8])
@@ -67,32 +67,36 @@ def load_images(file_name):
6767
assert(len(raw) == data_start + img_size * num_imgs) # All data should be used.
6868
return imgs
6969

70-
train_labels = load_labels(os.path.join(path, 'train-labels-idx1-ubyte.gz'))
71-
train_images = load_images(os.path.join(path, 'train-images-idx3-ubyte.gz'))
72-
test_labels = load_labels(os.path.join(path, 't10k-labels-idx1-ubyte.gz'))
73-
test_images = load_images(os.path.join(path, 't10k-images-idx3-ubyte.gz'))
70+
train_labels = load_labels(os.path.join(path, 'train-labels-idx1-ubyte'))
71+
train_images = load_images(os.path.join(path, 'train-images-idx3-ubyte'))
72+
test_labels = load_labels(os.path.join(path, 't10k-labels-idx1-ubyte'))
73+
test_images = load_images(os.path.join(path, 't10k-images-idx3-ubyte'))
7474

7575
return train_labels, train_images, test_labels, test_images
7676

77-
77+
# These parameters can be improved using parameter optimization,
78+
# see py/htm/optimization/ae.py
79+
# For more explanation of relations between the parameters, see
80+
# src/examples/mnist/MNIST_CPP.cpp
7881
default_parameters = {
79-
'boostStrength': 7.80643753517375,
80-
'columnDimensions': (35415,1),
81-
'dutyCyclePeriod': 1321,
82-
'localAreaDensity': 0.05361688506086096,
83-
'minPctOverlapDutyCycle': 0.0016316043362658,
84-
'potentialPct': 0.06799785776775163,
85-
'stimulusThreshold': 8,
86-
'synPermActiveInc': 0.01455789388651146,
87-
'synPermConnected': 0.021649964738697944,
88-
'synPermInactiveDec': 0.006442691852205935
82+
'potentialRadius': 7,
83+
'boostStrength': 7.0,
84+
'columnDimensions': (28*28*8, 1),
85+
'dutyCyclePeriod': 1402,
86+
'localAreaDensity': 0.1,
87+
'minPctOverlapDutyCycle': 0.2,
88+
'potentialPct': 0.1,
89+
'stimulusThreshold': 6,
90+
'synPermActiveInc': 0.14,
91+
'synPermConnected': 0.5,
92+
'synPermInactiveDec': 0.02
8993
}
9094

9195

9296
def main(parameters=default_parameters, argv=None, verbose=True):
9397
parser = argparse.ArgumentParser()
9498
parser.add_argument('--data_dir', type=str,
95-
default = os.path.join( os.path.dirname(__file__), 'MNIST_data'))
99+
default = os.path.join( os.path.dirname(__file__), '..', '..', '..', 'build', 'ThirdParty', 'mnist_data', 'mnist-src'))
96100
args = parser.parse_args(args = argv)
97101

98102
# Load data.
@@ -107,11 +111,10 @@ def main(parameters=default_parameters, argv=None, verbose=True):
107111
sp = SpatialPooler(
108112
inputDimensions = enc.dimensions,
109113
columnDimensions = parameters['columnDimensions'],
110-
potentialRadius = 99999999,
114+
potentialRadius = parameters['potentialRadius'],
111115
potentialPct = parameters['potentialPct'],
112116
globalInhibition = True,
113117
localAreaDensity = parameters['localAreaDensity'],
114-
numActiveColumnsPerInhArea = -1,
115118
stimulusThreshold = int(round(parameters['stimulusThreshold'])),
116119
synPermInactiveDec = parameters['synPermInactiveDec'],
117120
synPermActiveInc = parameters['synPermActiveInc'],
@@ -143,10 +146,11 @@ def main(parameters=default_parameters, argv=None, verbose=True):
143146
sp.compute( enc, False, columns )
144147
if lbl == np.argmax( sdrc.infer( columns ) ):
145148
score += 1
149+
score = score / len(test_data)
146150

147-
print('Score:', 100 * score / len(test_data), '%')
148-
return score / len(test_data)
151+
print('Score:', 100 * score, '%')
152+
return score < 0.95
149153

150154

151155
if __name__ == '__main__':
152-
main()
156+
sys.exit( main() )

0 commit comments

Comments
 (0)