
Add low memory implementation of core computation #74

Open · wants to merge 2 commits into master
Changes from all commits
1 change: 1 addition & 0 deletions forestci/__init__.py
@@ -1,4 +1,5 @@
from .forestci import (calc_inbag, random_forest_error,
                       _cycore_computation,
                       _core_computation, _bias_correction)  # noqa

from .version import __version__ # noqa
64 changes: 64 additions & 0 deletions forestci/cyfci.pyx
@@ -0,0 +1,64 @@
# cython: boundscheck=False
# cython: wraparound=False
# cython: nonecheck=False
from numpy cimport ndarray
cimport numpy
import numpy
cimport cython
from cython.parallel cimport prange
cimport scipy.linalg.cython_blas as blas

def _cycore_computation(inbag, pred_centered):
    """
    Helper function that performs the core computation in Cython,
    avoiding storage of the full intermediate matrix in memory.
    Equivalent to the default in-memory computation in
    `_core_computation`.

    Parameters
    ----------
    inbag: ndarray
        The inbag matrix that fit the data, with shape
        (n_train_samples, n_trees). Must be a C-contiguous float64 array.

    pred_centered: ndarray
        Centered predictions that are an intermediate result in the
        computation, with shape (n_test_samples, n_trees). Must be a
        C-contiguous float64 array.

    Returns
    -------
    ndarray
        1d float64 array of length n_test_samples with the raw variance
        estimates.
    """
    result = numpy.zeros(pred_centered.shape[0], dtype=numpy.float64)
    # Subtracting 1 allocates a fresh C-contiguous float64 array.
    inbag = inbag - 1
    _matmul_colsum(inbag, pred_centered, result)
    return result

cdef _matmul_colsum(double[:,::1] a, double[:,::1] b, double[::1] c):
    """
    Compute the column-wise sum of the squared, scaled product of `a`
    and `b.T` without storing the intermediate `(n, m)` matrix in
    memory. Equivalent to
    `numpy.sum((numpy.dot(a, b.T) / B) ** 2, axis=0)` with
    `B = a.shape[1]`. The result is stored in `c`.

    Parameters
    ----------
    a: ndarray
        `(n, B)` 2d input array (here: `inbag - 1`)

    b: ndarray
        `(m, B)` 2d input array (here: the centered predictions)

    c: ndarray
        `(m,)` 1d output array (data overwritten with result)

    Returns
    -------
    None
    """
    cdef int i, j
    cdef int n = a.shape[0], m = b.shape[0], B = a.shape[1]
    cdef int ONE = 1
    cdef double x = 0.0
    # Rows of c are independent, so parallelize over the m test points.
    for i in prange(m, nogil=True, schedule='static'):
        x = 0.0
        for j in range(n):
            # Accumulate the squared dot product of a[j, :] and b[i, :].
            x = x + blas.ddot(&B, &(a[j,0]), &ONE, &(b[i,0]), &ONE) ** 2
        c[i] = x / B**2
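
As a review aid, a pure-NumPy sketch (not part of this PR) of the reduction that `_matmul_colsum` performs; `_cycore_computation(inbag, pred_centered)` should match it with `a = inbag - 1`:

import numpy as np

def matmul_colsum_ref(a, b):
    # Same quantity as _matmul_colsum, but materializes the (n, m)
    # intermediate matrix that the Cython kernel avoids.
    return np.sum((np.dot(a, b.T) / a.shape[1]) ** 2, axis=0)

rng = np.random.RandomState(0)
a = rng.rand(5, 3)  # plays the role of inbag - 1, shape (n_train, n_trees)
b = rng.rand(7, 3)  # plays the role of pred_centered, shape (n_test, n_trees)
# from forestci.cyfci import _cycore_computation
# np.testing.assert_allclose(_cycore_computation(a + 1, b),
#                            matmul_colsum_ref(a, b))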
65 changes: 42 additions & 23 deletions forestci/forestci.py
@@ -8,6 +8,7 @@
import numpy as np
import copy
from .calibration import calibrateEB
from .cyfci import _cycore_computation
from sklearn.ensemble.forest import _generate_sample_indices
from .due import _due, _BibTeX

@@ -68,7 +69,7 @@ def calc_inbag(n_samples, forest):


def _core_computation(X_train, X_test, inbag, pred_centered, n_trees,
low_memory=False, memory_limit=None,
test_mode=False):
"""
Helper function, that performs the core computation
@@ -92,24 +93,34 @@ def _core_computation(X_train, X_test, inbag, pred_centered, n_trees,
Centered predictions that are an intermediate result in the
computation.

    low_memory: boolean, optional
        Whether or not to use the low-memory (but slower) calculation. If
        `False`, intermediate matrices of shape (n_train_samples,
        n_test_samples) are stored in memory, which is preferable whenever
        they fit. If the matrices are too large, either:
        1. Set `low_memory=True`, which avoids storing the intermediate
           matrices entirely but is slower.
        2. Set `memory_limit`, which computes the intermediate matrices in
           chunks. Depending on the number of chunks, this may be faster
           than `low_memory=True`.

    memory_limit: int, optional
        An upper bound, in megabytes, on the memory that the intermediate
        matrices may occupy. Ignored if `low_memory=True`.
"""

# Use low memory computation
if low_memory:
return _cycore_computation(inbag, pred_centered)

"""
if not memory_constrained:
# Use full in-memory computation
elif memory_limit is None:
return np.sum((np.dot(inbag - 1, pred_centered.T) / n_trees) ** 2, 0)

if not memory_limit:
raise ValueError('If memory_constrained=True, must provide',
'memory_limit.')
# user has specified a memory limit. Use in-memory chunked computation
else:
pass

# Assumes double precision float
chunk_size = int((memory_limit * 1e6) / (8.0 * X_train.shape[0]))
@@ -127,8 +138,8 @@ def _core_computation(X_train, X_test, inbag, pred_centered, n_trees,
if test_mode:
print('Number of chunks: %d' % (len(chunks),))
V_IJ = np.concatenate([
np.sum((np.dot(inbag-1, pred_centered[chunk].T)/n_trees)**2, 0)
for chunk in chunks])
return V_IJ
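
For intuition about the `memory_limit` path: each test row of the intermediate matrix costs 8 bytes per training sample, so the chunk size is simply the number of test rows that fit in the budget. A quick check of the arithmetic with assumed numbers (not from this PR):

memory_limit = 100  # megabytes (assumed)
n_train = 10000     # training samples (assumed)
chunk_size = int((memory_limit * 1e6) / (8.0 * n_train))
assert chunk_size == 1250  # test rows per chunk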


@@ -165,7 +176,7 @@ def _bias_correction(V_IJ, inbag, pred_centered, n_trees):


def random_forest_error(forest, X_train, X_test, inbag=None,
calibrate=True, low_memory=False,
memory_limit=None):
"""
Calculate error bars from scikit-learn RandomForest estimators.
@@ -199,15 +210,21 @@
        the number of trees in the forest is too small. Default: True

    low_memory: boolean, optional
        Whether or not to use the low-memory (but slower) calculation. If
        `False`, intermediate matrices of shape (n_train_samples,
        n_test_samples) are stored in memory, which is preferable whenever
        they fit. If the matrices are too large, either:
        1. Set `low_memory=True`, which avoids storing the intermediate
           matrices entirely but is slower.
        2. Set `memory_limit`, which computes the intermediate matrices in
           chunks. Depending on the number of chunks, this may be faster
           than `low_memory=True`.

    memory_limit: int, optional
        An upper bound, in megabytes, on the memory that the intermediate
        matrices may occupy. Ignored if `low_memory=True`.

Returns
-------
@@ -231,12 +248,14 @@ def random_forest_error(forest, X_train, X_test, inbag=None,
if inbag is None:
inbag = calc_inbag(X_train.shape[0], forest)

    # Build in Fortran order so the transposed array is C-contiguous,
    # as the Cython kernel requires.
pred = np.array([tree.predict(X_test) for tree in forest], order='F').T
pred_mean = np.mean(pred, 0)
pred_centered = pred - pred_mean
n_trees = forest.n_estimators
V_IJ = _core_computation(
X_train, X_test, inbag, pred_centered, n_trees,
low_memory=low_memory, memory_limit=memory_limit)
V_IJ_unbiased = _bias_correction(V_IJ, inbag, pred_centered, n_trees)

# Correct for cases where resampling is done without replacement:
@@ -261,7 +280,7 @@

results_ss = random_forest_error(new_forest, X_train, X_test,
calibrate=False,
low_memory=low_memory,
memory_limit=memory_limit)
# Use this second set of variance estimates
# to estimate scale of Monte Carlo noise
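
For reviewers who want to exercise the branch, a minimal usage sketch (toy data with made-up shapes; with `calibrate=False` all three code paths are deterministic and should agree to floating-point precision):

import numpy as np
from sklearn.ensemble import RandomForestRegressor
import forestci as fci

rng = np.random.RandomState(0)
X_train = rng.rand(500, 4)
y_train = rng.rand(500)
X_test = rng.rand(100, 4)
forest = RandomForestRegressor(n_estimators=50).fit(X_train, y_train)

v_default = fci.random_forest_error(forest, X_train, X_test,
                                    calibrate=False)
v_lowmem = fci.random_forest_error(forest, X_train, X_test,
                                   calibrate=False, low_memory=True)
v_chunked = fci.random_forest_error(forest, X_train, X_test,
                                    calibrate=False, memory_limit=1)
np.testing.assert_allclose(v_default, v_lowmem)
np.testing.assert_allclose(v_default, v_chunked)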
46 changes: 37 additions & 9 deletions forestci/tests/test_forestci.py
@@ -2,6 +2,15 @@
import numpy.testing as npt
from sklearn.ensemble import RandomForestRegressor
import forestci as fci
import forestci.cyfci


def test_compare_cycore_computation():
    a = np.arange(1, 7, dtype=np.float64).reshape(2, 3)
    b = np.arange(1, 13, dtype=np.float64).reshape(4, 3)
    c = fci._cycore_computation(a, b)
    actual = fci._core_computation(np.zeros((2, 10)), np.zeros((4, 10)),
                                   a, b, 3)
    npt.assert_almost_equal(actual, c)


def test_random_forest_error():
@@ -57,21 +66,40 @@ def test_core_computation():
for _ in range(1000)])
n_trees = 4

pred_centered_ex = pred_centered_ex.astype(np.float64)
inbag_ex = inbag_ex.astype(np.float64)
    our_vij = fci._core_computation(
X_train_ex, X_test_ex, inbag_ex, pred_centered_ex, n_trees
)

r_vij = np.concatenate([np.array([112.5, 387.5]) for _ in range(1000)])

npt.assert_almost_equal(our_vij, r_vij)

def test_low_memory_core_computation():
inbag_ex = np.array([[1., 2., 0., 1.],
[1., 0., 2., 0.],
[1., 1., 1., 2.]])

X_train_ex = np.array([[3, 3],
[6, 4],
[6, 6]])
X_test_ex = np.vstack([np.array([[5, 2],
[5, 5]]) for _ in range(1000)])
pred_centered_ex = np.vstack([np.array([[-20, -20, 10, 30],
[-20, 30, -20, 10]])
for _ in range(1000)])
n_trees = 4

pred_centered_ex = pred_centered_ex.astype(np.float64)
inbag_ex = inbag_ex.astype(np.float64)
    our_vij = fci._core_computation(
X_train_ex, X_test_ex, inbag_ex, pred_centered_ex, n_trees, low_memory=True)

r_vij = np.concatenate([np.array([112.5, 387.5]) for _ in range(1000)])

npt.assert_almost_equal(our_vij, r_vij)


def test_bias_correction():
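
To run the new tests locally (assuming the nose runner pinned in requirements.txt):

nosetests forestci/tests/test_forestci.py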
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,3 +1,5 @@
numpy>=1.8.2
nose>=1.1.2
scikit-learn>=0.17
cython
scipy
11 changes: 10 additions & 1 deletion setup.py
@@ -1,6 +1,7 @@
from __future__ import print_function
import sys, os
import numpy  # needed below for numpy.get_include()
from setuptools import setup, find_packages
from setuptools.extension import Extension

with open('requirements.txt') as f:
INSTALL_REQUIRES = [l.strip() for l in f.readlines() if l]
@@ -22,6 +23,13 @@
with open(ver_file) as f:
exec(f.read())

ext = Extension(
'forestci.cyfci',
['forestci/cyfci.pyx'],
include_dirs=[numpy.get_include()],
extra_compile_args=['-O3', '-fopenmp'],
extra_link_args=['-fopenmp'])

opts = dict(name=NAME,
maintainer=MAINTAINER,
maintainer_email=MAINTAINER_EMAIL,
@@ -36,7 +44,8 @@
platforms=PLATFORMS,
version=VERSION,
packages=find_packages(),
install_requires=INSTALL_REQUIRES,
ext_modules=[ext])

if __name__ == '__main__':
setup(**opts)
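
One build note: compiling forestci/cyfci.pyx requires both Cython and NumPy to be importable when setup.py runs, and the -fopenmp flags assume a GCC-compatible compiler. If a given setuptools version does not pick up the .pyx source on its own, the usual fallback (a sketch, not part of this PR) is an explicit cythonize call:

from Cython.Build import cythonize

# Replace the ext_modules entry above with:
#     ext_modules=cythonize([ext])
# cythonize translates cyfci.pyx to C ahead of the normal build_ext step.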