17
17
18
18
import argparse
19
19
import random
20
- import gzip
21
20
import numpy as np
22
21
import os
22
+ import sys
23
23
24
24
from htm .bindings .algorithms import SpatialPooler , Classifier
25
25
from htm .bindings .sdr import SDR , Metrics
@@ -36,7 +36,7 @@ def int32(b):
36
36
return i
37
37
38
38
def load_labels (file_name ):
39
- with gzip . open (file_name , 'rb' ) as f :
39
+ with open (file_name , 'rb' ) as f :
40
40
raw = f .read ()
41
41
assert (int32 (raw [0 :4 ]) == 2049 ) # Magic number
42
42
labels = []
@@ -46,7 +46,7 @@ def load_labels(file_name):
46
46
return labels
47
47
48
48
def load_images (file_name ):
49
- with gzip . open (file_name , 'rb' ) as f :
49
+ with open (file_name , 'rb' ) as f :
50
50
raw = f .read ()
51
51
assert (int32 (raw [0 :4 ]) == 2051 ) # Magic number
52
52
num_imgs = int32 (raw [4 :8 ])
@@ -67,32 +67,36 @@ def load_images(file_name):
67
67
assert (len (raw ) == data_start + img_size * num_imgs ) # All data should be used.
68
68
return imgs
69
69
70
- train_labels = load_labels (os .path .join (path , 'train-labels-idx1-ubyte.gz ' ))
71
- train_images = load_images (os .path .join (path , 'train-images-idx3-ubyte.gz ' ))
72
- test_labels = load_labels (os .path .join (path , 't10k-labels-idx1-ubyte.gz ' ))
73
- test_images = load_images (os .path .join (path , 't10k-images-idx3-ubyte.gz ' ))
70
+ train_labels = load_labels (os .path .join (path , 'train-labels-idx1-ubyte' ))
71
+ train_images = load_images (os .path .join (path , 'train-images-idx3-ubyte' ))
72
+ test_labels = load_labels (os .path .join (path , 't10k-labels-idx1-ubyte' ))
73
+ test_images = load_images (os .path .join (path , 't10k-images-idx3-ubyte' ))
74
74
75
75
return train_labels , train_images , test_labels , test_images
76
76
77
-
77
+ # These parameters can be improved using parameter optimization,
78
+ # see py/htm/optimization/ae.py
79
+ # For more explanation of relations between the parameters, see
80
+ # src/examples/mnist/MNIST_CPP.cpp
78
81
default_parameters = {
79
- 'boostStrength' : 7.80643753517375 ,
80
- 'columnDimensions' : (35415 ,1 ),
81
- 'dutyCyclePeriod' : 1321 ,
82
- 'localAreaDensity' : 0.05361688506086096 ,
83
- 'minPctOverlapDutyCycle' : 0.0016316043362658 ,
84
- 'potentialPct' : 0.06799785776775163 ,
85
- 'stimulusThreshold' : 8 ,
86
- 'synPermActiveInc' : 0.01455789388651146 ,
87
- 'synPermConnected' : 0.021649964738697944 ,
88
- 'synPermInactiveDec' : 0.006442691852205935
82
+ 'potentialRadius' : 7 ,
83
+ 'boostStrength' : 7.0 ,
84
+ 'columnDimensions' : (28 * 28 * 8 , 1 ),
85
+ 'dutyCyclePeriod' : 1402 ,
86
+ 'localAreaDensity' : 0.1 ,
87
+ 'minPctOverlapDutyCycle' : 0.2 ,
88
+ 'potentialPct' : 0.1 ,
89
+ 'stimulusThreshold' : 6 ,
90
+ 'synPermActiveInc' : 0.14 ,
91
+ 'synPermConnected' : 0.5 ,
92
+ 'synPermInactiveDec' : 0.02
89
93
}
90
94
91
95
92
96
def main (parameters = default_parameters , argv = None , verbose = True ):
93
97
parser = argparse .ArgumentParser ()
94
98
parser .add_argument ('--data_dir' , type = str ,
95
- default = os .path .join ( os .path .dirname (__file__ ), 'MNIST_data ' ))
99
+ default = os .path .join ( os .path .dirname (__file__ ), '..' , '..' , '..' , 'build' , 'ThirdParty' , 'mnist_data' , 'mnist-src ' ))
96
100
args = parser .parse_args (args = argv )
97
101
98
102
# Load data.
@@ -107,11 +111,10 @@ def main(parameters=default_parameters, argv=None, verbose=True):
107
111
sp = SpatialPooler (
108
112
inputDimensions = enc .dimensions ,
109
113
columnDimensions = parameters ['columnDimensions' ],
110
- potentialRadius = 99999999 ,
114
+ potentialRadius = parameters [ 'potentialRadius' ] ,
111
115
potentialPct = parameters ['potentialPct' ],
112
116
globalInhibition = True ,
113
117
localAreaDensity = parameters ['localAreaDensity' ],
114
- numActiveColumnsPerInhArea = - 1 ,
115
118
stimulusThreshold = int (round (parameters ['stimulusThreshold' ])),
116
119
synPermInactiveDec = parameters ['synPermInactiveDec' ],
117
120
synPermActiveInc = parameters ['synPermActiveInc' ],
@@ -143,10 +146,11 @@ def main(parameters=default_parameters, argv=None, verbose=True):
143
146
sp .compute ( enc , False , columns )
144
147
if lbl == np .argmax ( sdrc .infer ( columns ) ):
145
148
score += 1
149
+ score = score / len (test_data )
146
150
147
- print ('Score:' , 100 * score / len ( test_data ) , '%' )
148
- return score / len ( test_data )
151
+ print ('Score:' , 100 * score , '%' )
152
+ return score < 0.95
149
153
150
154
151
155
if __name__ == '__main__' :
152
- main ()
156
+ sys . exit ( main () )
0 commit comments