Skip to content
This repository has been archived by the owner on Dec 5, 2024. It is now read-only.

support py3.6+linux #16

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Reference implementation of a two-level RCN model on MNIST classification. See t

## Setup

Note: Python 2.7 is supported. The code was tested on OSX 10.11. It may work on other system platforms but not guaranteed.
Note: Python 2.7 and python 3.6 is supported. The code was tested on OSX 10.11. It may work on other system platforms but not guaranteed.

Before starting please make sure gcc is installed (`brew install gcc`) and up to date in order to compile the various dependencies (particularly numpy).

Expand Down
11 changes: 5 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
decorator==4.1.2
decorator>=4.1.2
networkx==1.11
numpy==1.13.3
olefile==0.44
Pillow==4.1.1
rcn-ref==1.0.0
scipy==0.19.1
numpy>=1.13.3
olefile>=0.44
Pillow>=4.1.1
scipy>=0.19.1
33 changes: 26 additions & 7 deletions science_rcn/dilation/dilation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,39 @@ using namespace std;

/* ==== Set up the methods table ====================== */
static PyMethodDef dilationmethods[] = {
{"max_filter1d", py_max_filter1d, METH_VARARGS},
{"brute_max_filter1d", py_brute_max_filter1d, METH_VARARGS},
{NULL, NULL} /* Sentinel - marks the end of this structure */
{"max_filter1d", py_max_filter1d, METH_VARARGS, "max filter1d"},
{"brute_max_filter1d", py_brute_max_filter1d, METH_VARARGS, "brute max filter1d"},
{NULL, NULL, 0, NULL}
};


/* ==== Initialize the C_test functions ====================== */
extern "C" {
void init_dilation()
/* This initiates the module using the above definitions. */
#if PY_VERSION_HEX >= 0x03000000
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"_dilation",
NULL,
-1,
dilationmethods,
NULL,
NULL,
NULL,
NULL
};

PyMODINIT_FUNC PyInit__dilation(void)
{
(void) Py_InitModule("_dilation", dilationmethods);
import_array(); // Must be present for NumPy. Called first after above line.
import_array();
return PyModule_Create(&moduledef);
}
#else
PyMODINIT_FUNC init_dilation(void)
{
(void) Py_InitModule("_dilation", dilationmethods);
import_array();
}
#endif

/// Check condition and return NULL (which will cause a python exception) if
/// it's false, and include an arbitrary format string as the error message
Expand Down
5 changes: 2 additions & 3 deletions science_rcn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Note that we use a faster implementation of 2D dilation, instead of the slower
scipy.ndimage.morphology.grey_dilation.
"""
from itertools import izip
import logging
import numpy as np
import networkx as nx
Expand Down Expand Up @@ -56,7 +55,7 @@ def test_image(img, model_factors,

# Forward pass inference
fp_scores = np.zeros(len(model_factors[0]))
for i, (frcs, _, graph) in enumerate(izip(*model_factors)):
for i, (frcs, _, graph) in enumerate(list(zip(*model_factors))):
fp_scores[i] = forward_pass(frcs,
bu_msg,
graph,
Expand Down Expand Up @@ -315,7 +314,7 @@ def infer_pbp(self):
"""Parallel loopy BP message passing, modifying state of `lat_messages`.
See bwd_pass() for parameters.
"""
for it in xrange(self.n_iters):
for it in range(self.n_iters):
new_lat_messages = self.new_messages()
delta = new_lat_messages - self.lat_messages
self.lat_messages += self.damping * delta
Expand Down
4 changes: 2 additions & 2 deletions science_rcn/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def generate_suppression_masks(filter_scale=4., num_orients=16):
# Compute for orientations [0, pi), then flip for [pi, 2*pi)
for i, angle in enumerate(np.linspace(0., np.pi, num_orients // 2, endpoint=False)):
x, y = np.cos(angle), np.sin(angle)
for r in xrange(1, int(np.sqrt(2) * size / 2)):
for r in range(1, int(np.sqrt(2) * size / 2)):
dx, dy = round(r * x), round(r * y)
if abs(dx) > cx or abs(dy) > cy:
continue
Expand Down Expand Up @@ -160,7 +160,7 @@ def local_nonmax_suppression(filtered, suppression_masks, num_orients=16):
localized = np.zeros_like(filtered)
cross_orient_max = filtered.max(0)
filtered[filtered < 0] = 0
for i, (layer, suppress_mask) in enumerate(zip(filtered, suppression_masks)):
for i, (layer, suppress_mask) in enumerate(list(zip(filtered, suppression_masks))):
competitor_maxs = maximum_filter(layer, footprint=suppress_mask, mode='nearest')
localized[i] = competitor_maxs <= layer
localized[cross_orient_max > filtered] = 0
Expand Down
4 changes: 2 additions & 2 deletions science_rcn/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def run_experiment(data_dir='data/MNIST',
train_partial = partial(train_image,
perturb_factor=perturb_factor)
train_results = pool.map_async(train_partial, [d[0] for d in train_data]).get(9999999)
all_model_factors = zip(*train_results)
all_model_factors = list(zip(*train_results))

LOG.info("Testing on {} images...".format(len(test_data)))
test_partial = partial(test_image, model_factors=all_model_factors,
Expand All @@ -92,7 +92,7 @@ def run_experiment(data_dir='data/MNIST',
correct = 0
for test_idx, (winner_idx, _) in enumerate(test_results):
correct += int(test_data[test_idx][1]) == winner_idx // (train_size // 10)
print "Total test accuracy = {}".format(float(correct) / len(test_results))
print("Total test accuracy = {}".format(float(correct) / len(test_results)))

return all_model_factors, test_results

Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _findRequirements():
# Check for MNIST data dir
if not os.path.isdir('./data/MNIST'):
if os.path.exists('./data/MNIST.zip'):
print "Extracting MNIST data..."
print("Extracting MNIST data...")
with zipfile.ZipFile('./data/MNIST.zip', 'r') as z:
z.extractall('./data/')
else:
Expand Down Expand Up @@ -62,9 +62,9 @@ def finalize_options(self):
setup_requires=['numpy>=1.13.3'],
install_requires=[
'networkx>=1.11,<1.12',
'numpy==1.13.3',
'pillow>=4.1.0,<4.2',
'scipy>=0.19.0,<0.20',
'numpy>=1.13.3',
'pillow>=4.1.0',
'scipy>=0.19.0',
'setuptools>=36.5.0'
],
ext_modules=[dilation_module],
Expand All @@ -73,6 +73,7 @@ def finalize_options(self):
'Natural Language :: English',
'Operating System :: MacOS :: MacOS X',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: C'],
keywords='rcn',
Expand Down