diff --git a/tvb_library/tvb/simulator/models/_dfun_stefanescu_jirsa.py b/tvb_library/tvb/simulator/models/_dfun_stefanescu_jirsa.py new file mode 100644 index 0000000000..453ff8336a --- /dev/null +++ b/tvb_library/tvb/simulator/models/_dfun_stefanescu_jirsa.py @@ -0,0 +1,82 @@ +import numba as nb +import numpy as np + + +def dfun_fitzhughnaguma(state_variables, coupling, local_coupling, tau, e_i, K11, Aik, K12, Bik, IE_i, + b, m_i, K21, Cik, II_i, f_i, n_i): + xi = state_variables[0, :] + eta = state_variables[1, :] + alpha = state_variables[2, :] + beta = state_variables[3, :] + derivative = np.empty_like(state_variables) + + # Sum the activity from the modes + c_0 = coupling[0, :].sum(axis=1)[:, np.newaxis] + + # Compute derivatives + derivative[0] = ( + tau * (xi - e_i * xi ** 3 / 3.0 - eta) + + K11 * (np.dot(xi, Aik) - xi) + - K12 * (np.dot(alpha, Bik) - xi) + + tau * (IE_i + c_0 + local_coupling * xi) + ) + + derivative[1] = (xi - b * eta + m_i) / tau + + derivative[2] = ( + tau * (alpha - f_i * alpha ** 3 / 3.0 - beta) + + K21 * (np.dot(xi, Cik) - alpha) + + tau * (II_i + c_0 + local_coupling * xi) + ) + + derivative[3] = (alpha - b * beta + n_i) / tau + + return derivative + +_dfun_fitzhughnaguma_numba = nb.njit(dfun_fitzhughnaguma) +def dfun_fitzhughnaguma_numba(state_variables, coupling, local_coupling, tau, e_i, K11, Aik, K12, Bik, IE_i, + b, m_i, K21, Cik, II_i, f_i, n_i): + return _dfun_fitzhughnaguma_numba(state_variables, coupling, local_coupling, + tau, e_i, K11, Aik, K12, Bik, IE_i, b, + m_i, K21, Cik, II_i, f_i, n_i) + +def dfun_hindmarshrose(state_variables, coupling, local_coupling, r, a_i, b_i, c_i, d_i, s, K11, A_ik, K12, B_ik, IE_i, m_i, K21, C_ik, II_i, e_i, f_i, h_i, p_i, n_i): + xi = state_variables[0, :] + eta = state_variables[1, :] + tau = state_variables[2, :] + alpha = state_variables[3, :] + beta = state_variables[4, :] + gamma = state_variables[5, :] + derivative = np.empty_like(state_variables) + + c_0 = coupling[0, :].sum(axis=1)[:, np.newaxis] + # c_1 = coupling[1, :] + + derivative[0] = (eta - a_i * xi ** 3 + b_i * xi ** 2 - tau + + K11 * (np.dot(xi, A_ik) - xi) - + K12 * (np.dot(alpha, B_ik) - xi) + + IE_i + c_0 + local_coupling * xi) + + derivative[1] = c_i - d_i * xi ** 2 - eta + + derivative[2] = r * s * xi - r * tau - m_i + + derivative[3] = (beta - e_i * alpha ** 3 + f_i * alpha ** 2 - gamma + + K21 * (np.dot(xi, C_ik) - alpha) + + II_i + c_0 + local_coupling * xi) + + derivative[4] = h_i - p_i * alpha ** 2 - beta + + derivative[5] = r * s * alpha - r * gamma - n_i + + return derivative + +_dfun_hindmarshrose_numba = nb.njit(dfun_hindmarshrose) +@nb.njit +def dfun_hindmarshrose_numba(state_variables, coupling, local_coupling, r, a_i, b_i, c_i, d_i, s, K11, A_ik, K12, B_ik, IE_i, m_i, K21, C_ik, II_i, e_i, f_i, h_i, p_i, n_i): + return _dfun_hindmarshrose_numba(state_variables, coupling, local_coupling, + r, a_i, b_i, c_i, d_i, s, K11, A_ik, + K12, B_ik, IE_i, m_i, K21, C_ik, II_i, + e_i, f_i, h_i, p_i, n_i) + + diff --git a/tvb_library/tvb/simulator/models/stefanescu_jirsa.py b/tvb_library/tvb/simulator/models/stefanescu_jirsa.py index df235e1e2d..3609825862 100644 --- a/tvb_library/tvb/simulator/models/stefanescu_jirsa.py +++ b/tvb_library/tvb/simulator/models/stefanescu_jirsa.py @@ -31,6 +31,7 @@ import numpy from scipy.integrate import trapezoid as scipy_integrate_trapz from scipy.stats import norm as scipy_stats_norm +from ._dfun_stefanescu_jirsa import dfun_fitzhughnaguma, dfun_fitzhughnaguma_numba, dfun_hindmarshrose, dfun_hindmarshrose_numba from .base import Model from tvb.basic.neotraits.api import NArray, Final, List, Range @@ -106,6 +107,7 @@ class ReducedSetFitzHughNagumo(ReducedSetBase): """ # Define traited attributes for this model, these represent possible kwargs. + use_numba = True tau = NArray( label=r":math:`\tau`", default=numpy.array([3.0]), @@ -213,32 +215,14 @@ def dfun(self, state_variables, coupling, local_coupling=0.0): \dot{\beta}_i &= \frac{1}{c}\left(\alpha_i-b\beta_i+n_i\right) """ - - xi = state_variables[0, :] - eta = state_variables[1, :] - alpha = state_variables[2, :] - beta = state_variables[3, :] - derivative = numpy.empty_like(state_variables) - # sum the activity from the modes - c_0 = coupling[0, :].sum(axis=1)[:, numpy.newaxis] - - # TODO: generalize coupling variables to a matrix form - # c_1 = coupling[1, :] # this cv represents alpha - - derivative[0] = (self.tau * (xi - self.e_i * xi ** 3 / 3.0 - eta) + - self.K11 * (numpy.dot(xi, self.Aik) - xi) - - self.K12 * (numpy.dot(alpha, self.Bik) - xi) + - self.tau * (self.IE_i + c_0 + local_coupling * xi)) - - derivative[1] = (xi - self.b * eta + self.m_i) / self.tau - - derivative[2] = (self.tau * (alpha - self.f_i * alpha ** 3 / 3.0 - beta) + - self.K21 * (numpy.dot(xi, self.Cik) - alpha) + - self.tau * (self.II_i + c_0 + local_coupling * xi)) - - derivative[3] = (alpha - self.b * beta + self.n_i) / self.tau - - return derivative + dfun_to_use = dfun_fitzhughnaguma + if self.use_numba: + dfun_to_use = dfun_fitzhughnaguma_numba + return dfun_to_use( + state_variables, coupling, local_coupling, + self.tau, self.e_i, self.K11, self.Aik, self.K12, self.Bik, self.IE_i, + self.b, self.m_i, self.K21, self.Cik, self.II_i, self.f_i, self.n_i + ) def update_derived_parameters(self): """ @@ -484,6 +468,7 @@ class ReducedSetHindmarshRose(ReducedSetBase): II_i = None m_i = None n_i = None + use_numba = True def dfun(self, state_variables, coupling, local_coupling=0.0): r""" @@ -509,36 +494,12 @@ def dfun(self, state_variables, coupling, local_coupling=0.0): \dot{\gamma}_i &= rs \alpha_i - r \gamma_i - n_i """ - - xi = state_variables[0, :] - eta = state_variables[1, :] - tau = state_variables[2, :] - alpha = state_variables[3, :] - beta = state_variables[4, :] - gamma = state_variables[5, :] - derivative = numpy.empty_like(state_variables) - - c_0 = coupling[0, :].sum(axis=1)[:, numpy.newaxis] - # c_1 = coupling[1, :] - - derivative[0] = (eta - self.a_i * xi ** 3 + self.b_i * xi ** 2 - tau + - self.K11 * (numpy.dot(xi, self.A_ik) - xi) - - self.K12 * (numpy.dot(alpha, self.B_ik) - xi) + - self.IE_i + c_0 + local_coupling * xi) - - derivative[1] = self.c_i - self.d_i * xi ** 2 - eta - - derivative[2] = self.r * self.s * xi - self.r * tau - self.m_i - - derivative[3] = (beta - self.e_i * alpha ** 3 + self.f_i * alpha ** 2 - gamma + - self.K21 * (numpy.dot(xi, self.C_ik) - alpha) + - self.II_i + c_0 + local_coupling * xi) - - derivative[4] = self.h_i - self.p_i * alpha ** 2 - beta - - derivative[5] = self.r * self.s * alpha - self.r * gamma - self.n_i - - return derivative + dfun_to_use = dfun_hindmarshrose + if self.use_numba: + dfun_to_use = dfun_hindmarshrose_numba + return dfun_to_use( + state_variables, coupling, local_coupling, + self.r, self.a_i, self.b_i, self.c_i, self.d_i, self.s, self.K11, self.A_ik, self.K12, self.B_ik, self.IE_i, self.m_i, self.K21, self.C_ik, self.II_i, self.e_i, self.f_i, self.h_i, self.p_i, self.n_i) def update_derived_parameters(self, corrected_d_p=True): """ diff --git a/tvb_library/tvb/tests/library/simulator/models_benchmark.py b/tvb_library/tvb/tests/library/simulator/models_benchmark.py new file mode 100644 index 0000000000..bd3ad9fced --- /dev/null +++ b/tvb_library/tvb/tests/library/simulator/models_benchmark.py @@ -0,0 +1,78 @@ +import numpy as np +import time +from tvb.tests.library.base_testcase import BaseTestCase +from tvb.simulator.models.stefanescu_jirsa import ReducedSetFitzHughNagumo +from tvb.simulator.models.stefanescu_jirsa import ReducedSetHindmarshRose +from tvb.simulator.models.base import Model + + +class TestBenchmarkModels(BaseTestCase): + """Test cases that benchmark the performance of models' implementation + """ + + def randn_state_for_model(self, model: Model, n_node): + shape = (model.nvar, n_node, model.number_of_modes) + state = np.random.randn(*shape) + return state + + + def randn_coupling_for_model(self, model: Model, n_node): + n_cvar = len(model.cvar) + shape = (n_cvar, n_node, model.number_of_modes) + coupling = np.random.randn(*shape) + return coupling + + def eps_for_model(self, model: Model, n_node, time_limit=0.5, state=None, coupling=None): + model.configure() + if state is None: + state = self.randn_state_for_model(model, n_node) + if coupling is None: + coupling = self.randn_coupling_for_model(model, n_node) + # throw one away in case of initialization + model.dfun(state, coupling) + # start timing + tic = time.time() + n_eval = 0 + while (time.time() - tic) < time_limit: + model.dfun(state, coupling) + n_eval += 1 + toc = time.time() + return n_eval / (toc - tic) + + def test_rsfhn_numba(self): + # create, & initialize numba & no-numba models + rs_fhn_model = ReducedSetFitzHughNagumo() + n_node = 10 + state = self.randn_state_for_model(rs_fhn_model, n_node) + coupling = self.randn_coupling_for_model(rs_fhn_model, n_node) + + rs_fhn_model.use_numba = False + no_numba_performance = self.eps_for_model(rs_fhn_model, 10, state=state, coupling=coupling) + no_numba_dfun_derivative = rs_fhn_model.dfun(state, coupling) + rs_fhn_model.use_numba = True + numba_performance = self.eps_for_model(rs_fhn_model, 10, state=state, coupling=coupling) + numba_dfun_derivative = rs_fhn_model.dfun(state, coupling) + speedup = numba_performance / no_numba_performance + print(f"speedup: {speedup}") + print(f"no_numba_performance: {no_numba_performance}, numba_performance: {numba_performance}") + assert speedup > 1 + assert np.allclose(no_numba_dfun_derivative, numba_dfun_derivative) + + def test_rshr_numba(self): + # create, & initialize numba & no-numba models + rs_hr_model = ReducedSetHindmarshRose() + n_node = 10 + state = self.randn_state_for_model(rs_hr_model, n_node) + coupling = self.randn_coupling_for_model(rs_hr_model, n_node) + + rs_hr_model.use_numba = False + no_numba_performance = self.eps_for_model(rs_hr_model, 10, state=state, coupling=coupling) + no_numba_dfun_derivative = rs_hr_model.dfun(state, coupling) + rs_hr_model.use_numba = True + numba_performance = self.eps_for_model(rs_hr_model, 10, state=state, coupling=coupling) + numba_dfun_derivative = rs_hr_model.dfun(state, coupling) + speedup = numba_performance / no_numba_performance + print(f"speedup: {speedup}") + print(f"no_numba_performance: {no_numba_performance}, numba_performance: {numba_performance}") + assert speedup > 1 + assert np.allclose(no_numba_dfun_derivative, numba_dfun_derivative)