Skip to content

Commit 174c81a

Browse files
authored
Update VariableView and analysis plotting apis (#268)
Update `VariableView` and analysis plotting apis
2 parents 99d5cb9 + 536c3a8 commit 174c81a

File tree

12 files changed

+258
-331
lines changed

12 files changed

+258
-331
lines changed

brainpy/analysis/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,4 @@
2222
from .lowdim.lowdim_bifurcation import *
2323

2424
from .constants import *
25-
from . import constants as C
26-
from . import stability
27-
from . import utils
25+
from . import constants as C, stability, plotstyle, utils

brainpy/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import brainpy.math as bm
1010
from brainpy import errors
11-
from brainpy.analysis import stability, utils, constants as C
11+
from brainpy.analysis import stability, plotstyle, utils, constants as C
1212
from brainpy.analysis.lowdim.lowdim_analyzer import *
1313

1414
pyplot = None
@@ -79,8 +79,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
7979
pyplot.figure(self.x_var)
8080
for fp_type, points in container.items():
8181
if len(points['x']):
82-
plot_style = stability.plot_scheme[fp_type]
83-
pyplot.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
82+
plot_style = plotstyle.plot_schema[fp_type]
83+
pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type)
8484
pyplot.xlabel(self.target_par_names[0])
8585
pyplot.ylabel(self.x_var)
8686

@@ -107,10 +107,11 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
107107
ax = fig.add_subplot(projection='3d')
108108
for fp_type, points in container.items():
109109
if len(points['x']):
110-
plot_style = stability.plot_scheme[fp_type]
110+
plot_style = plotstyle.plot_schema[fp_type]
111111
xs = points['p0']
112112
ys = points['p1']
113113
zs = points['x']
114+
plot_style.pop('linestyle')
114115
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)
115116

116117
ax.set_xlabel(self.target_par_names[0])
@@ -298,8 +299,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
298299
pyplot.figure(var)
299300
for fp_type, points in container.items():
300301
if len(points['p']):
301-
plot_style = stability.plot_scheme[fp_type]
302-
pyplot.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
302+
plot_style = plotstyle.plot_schema[fp_type]
303+
pyplot.plot(points['p'], points[var], **plot_style, label=fp_type)
303304
pyplot.xlabel(self.target_par_names[0])
304305
pyplot.ylabel(var)
305306

@@ -330,10 +331,11 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
330331
ax = fig.add_subplot(projection='3d')
331332
for fp_type, points in container.items():
332333
if len(points['p0']):
333-
plot_style = stability.plot_scheme[fp_type]
334+
plot_style = plotstyle.plot_schema[fp_type]
334335
xs = points['p0']
335336
ys = points['p1']
336337
zs = points[var]
338+
plot_style.pop('linestyle')
337339
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)
338340

339341
ax.set_xlabel(self.target_par_names[0])

brainpy/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import brainpy.math as bm
88
from brainpy import errors, math
9-
from brainpy.analysis import stability, constants as C, utils
9+
from brainpy.analysis import stability, plotstyle, constants as C, utils
1010
from brainpy.analysis.lowdim.lowdim_analyzer import *
1111

1212
pyplot = None
@@ -107,8 +107,8 @@ def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
107107
if with_plot:
108108
for fp_type, points in container.items():
109109
if len(points):
110-
plot_style = stability.plot_scheme[fp_type]
111-
pyplot.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type)
110+
plot_style = plotstyle.plot_schema[fp_type]
111+
pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type)
112112
pyplot.legend()
113113
if show:
114114
pyplot.show()
@@ -248,9 +248,9 @@ def plot_nullcline(self, with_plot=True, with_return=False,
248248

249249
if with_plot:
250250
if x_style is None:
251-
x_style = dict(color='cornflowerblue', alpha=.7, )
252-
fmt = x_style.pop('fmt', '.')
253-
pyplot.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline")
251+
x_style = dict(color='cornflowerblue', alpha=.7, fmt='.')
252+
line_args = (x_style.pop('fmt'), ) if 'fmt' in x_style else tuple()
253+
pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline")
254254

255255
# Nullcline of the y variable
256256
utils.output('I am computing fy-nullcline ...')
@@ -260,9 +260,9 @@ def plot_nullcline(self, with_plot=True, with_return=False,
260260

261261
if with_plot:
262262
if y_style is None:
263-
y_style = dict(color='lightcoral', alpha=.7, )
264-
fmt = y_style.pop('fmt', '.')
265-
pyplot.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline")
263+
y_style = dict(color='lightcoral', alpha=.7, fmt='.')
264+
line_args = (y_style.pop('fmt'), ) if 'fmt' in y_style else tuple()
265+
pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline")
266266

267267
if with_plot:
268268
pyplot.xlabel(self.x_var)
@@ -349,8 +349,8 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False,
349349
if with_plot:
350350
for fp_type, points in container.items():
351351
if len(points['x']):
352-
plot_style = stability.plot_scheme[fp_type]
353-
pyplot.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
352+
plot_style = plotstyle.plot_schema[fp_type]
353+
pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type)
354354
pyplot.legend()
355355
if show:
356356
pyplot.show()

brainpy/analysis/plotstyle.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
__all__ = [
5+
'plot_schema',
6+
'set_plot_schema',
7+
]
8+
9+
from .stability import (CENTER_MANIFOLD, SADDLE_NODE, STABLE_POINT_1D,
10+
UNSTABLE_POINT_1D, CENTER_2D, STABLE_NODE_2D,
11+
STABLE_FOCUS_2D, STABLE_STAR_2D, STABLE_DEGENERATE_2D,
12+
UNSTABLE_NODE_2D, UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D,
13+
UNSTABLE_DEGENERATE_2D, UNSTABLE_LINE_2D,
14+
STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D,
15+
UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, STABLE_FOCUS_3D,
16+
UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D)
17+
18+
19+
_markersize = 20
20+
21+
plot_schema = {}
22+
23+
plot_schema[CENTER_MANIFOLD] = {'color': 'orangered', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
24+
plot_schema[SADDLE_NODE] = {"color": 'tab:blue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
25+
26+
plot_schema[STABLE_POINT_1D] = {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
27+
plot_schema[UNSTABLE_POINT_1D] = {"color": 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
28+
29+
plot_schema.update({
30+
CENTER_2D: {'color': 'lime', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
31+
STABLE_NODE_2D: {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
32+
STABLE_FOCUS_2D: {"color": 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
33+
STABLE_STAR_2D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
34+
STABLE_DEGENERATE_2D: {'color': 'blueviolet', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
35+
UNSTABLE_NODE_2D: {"color": 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
36+
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
37+
UNSTABLE_STAR_2D: {'color': 'green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
38+
UNSTABLE_DEGENERATE_2D: {'color': 'springgreen', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
39+
UNSTABLE_LINE_2D: {'color': 'dodgerblue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
40+
})
41+
42+
43+
plot_schema.update({
44+
STABLE_POINT_3D: {'color': 'tab:gray', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
45+
UNSTABLE_POINT_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
46+
STABLE_NODE_3D: {'color': 'tab:green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
47+
UNSTABLE_SADDLE_3D: {'color': 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
48+
UNSTABLE_FOCUS_3D: {'color': 'tab:pink', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
49+
STABLE_FOCUS_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
50+
UNSTABLE_NODE_3D: {'color': 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
51+
UNSTABLE_CENTER_3D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
52+
UNKNOWN_3D: {'color': 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
53+
})
54+
55+
56+
def set_plot_schema(fixed_point: str, **schema):
57+
if not isinstance(fixed_point, str):
58+
raise TypeError(f'Must instance of string, but we got {type(fixed_point)}: {fixed_point}')
59+
if fixed_point not in plot_schema:
60+
raise KeyError(f'Fixed point type {fixed_point} does not found in the built-in types. ')
61+
plot_schema[fixed_point].update(**schema)
62+
63+
64+
def set_markersize(markersize):
65+
if not isinstance(markersize, int):
66+
raise TypeError(f"Must be an integer, but got {type(markersize)}: {markersize}")
67+
global _markersize
68+
__markersize = markersize
69+
for key in tuple(plot_schema.keys()):
70+
plot_schema[key]['markersize'] = markersize
71+
72+

brainpy/analysis/stability.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
'get_1d_stability_types',
77
'get_2d_stability_types',
88
'get_3d_stability_types',
9-
'plot_scheme',
9+
1010

1111
'stability_analysis',
1212

@@ -27,17 +27,13 @@
2727
'UNSTABLE_LINE_2D',
2828
]
2929

30-
plot_scheme = {}
30+
3131

3232
SADDLE_NODE = 'saddle node'
3333
CENTER_MANIFOLD = 'center manifold'
34-
plot_scheme[CENTER_MANIFOLD] = {'color': 'orangered'}
35-
plot_scheme[SADDLE_NODE] = {"color": 'tab:blue'}
3634

3735
STABLE_POINT_1D = 'stable point'
3836
UNSTABLE_POINT_1D = 'unstable point'
39-
plot_scheme[STABLE_POINT_1D] = {"color": 'tab:red'}
40-
plot_scheme[UNSTABLE_POINT_1D] = {"color": 'tab:olive'}
4137

4238
CENTER_2D = 'center'
4339
STABLE_NODE_2D = 'stable node'
@@ -49,18 +45,7 @@
4945
UNSTABLE_STAR_2D = 'unstable star'
5046
UNSTABLE_DEGENERATE_2D = 'unstable degenerate'
5147
UNSTABLE_LINE_2D = 'unstable line'
52-
plot_scheme.update({
53-
CENTER_2D: {'color': 'lime'},
54-
STABLE_NODE_2D: {"color": 'tab:red'},
55-
STABLE_FOCUS_2D: {"color": 'tab:purple'},
56-
STABLE_STAR_2D: {'color': 'tab:olive'},
57-
STABLE_DEGENERATE_2D: {'color': 'blueviolet'},
58-
UNSTABLE_NODE_2D: {"color": 'tab:orange'},
59-
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan'},
60-
UNSTABLE_STAR_2D: {'color': 'green'},
61-
UNSTABLE_DEGENERATE_2D: {'color': 'springgreen'},
62-
UNSTABLE_LINE_2D: {'color': 'dodgerblue'},
63-
})
48+
6449

6550
STABLE_POINT_3D = 'unclassified stable point'
6651
UNSTABLE_POINT_3D = 'unclassified unstable point'
@@ -71,17 +56,6 @@
7156
UNSTABLE_FOCUS_3D = 'unstable focus'
7257
UNSTABLE_CENTER_3D = 'unstable center'
7358
UNKNOWN_3D = 'unknown 3d'
74-
plot_scheme.update({
75-
STABLE_POINT_3D: {'color': 'tab:gray'},
76-
UNSTABLE_POINT_3D: {'color': 'tab:purple'},
77-
STABLE_NODE_3D: {'color': 'tab:green'},
78-
UNSTABLE_SADDLE_3D: {'color': 'tab:red'},
79-
UNSTABLE_FOCUS_3D: {'color': 'tab:pink'},
80-
STABLE_FOCUS_3D: {'color': 'tab:purple'},
81-
UNSTABLE_NODE_3D: {'color': 'tab:orange'},
82-
UNSTABLE_CENTER_3D: {'color': 'tab:olive'},
83-
UNKNOWN_3D: {'color': 'tab:cyan'},
84-
})
8559

8660

8761
def get_1d_stability_types():

brainpy/dyn/rates/populations.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def update(self, tdi, x=None):
180180
self.y.value = y
181181

182182
def clear_input(self):
183-
self.input[:] = 0.
184-
self.input_y[:] = 0.
183+
self.input.value = bm.zeros_like(self.input)
184+
self.input_y.value = bm.zeros_like(self.input_y)
185185

186186

187187
class FeedbackFHN(RateModel):
@@ -375,8 +375,8 @@ def update(self, tdi, x=None):
375375
self.y.value = y
376376

377377
def clear_input(self):
378-
self.input[:] = 0.
379-
self.input_y[:] = 0.
378+
self.input.value = bm.zeros_like(self.input)
379+
self.input_y.value = bm.zeros_like(self.input_y)
380380

381381

382382
class QIF(RateModel):
@@ -558,8 +558,8 @@ def update(self, tdi, x=None):
558558
self.y.value = y
559559

560560
def clear_input(self):
561-
self.input[:] = 0.
562-
self.input_y[:] = 0.
561+
self.input.value = bm.zeros_like(self.input)
562+
self.input_y.value = bm.zeros_like(self.input_y)
563563

564564

565565
class StuartLandauOscillator(RateModel):
@@ -700,8 +700,8 @@ def update(self, tdi, x=None):
700700
self.y.value = y
701701

702702
def clear_input(self):
703-
self.input[:] = 0.
704-
self.input_y[:] = 0.
703+
self.input.value = bm.zeros_like(self.input)
704+
self.input_y.value = bm.zeros_like(self.input_y)
705705

706706

707707
class WilsonCowanModel(RateModel):
@@ -857,8 +857,8 @@ def update(self, tdi, x=None):
857857
self.y.value = y
858858

859859
def clear_input(self):
860-
self.input[:] = 0.
861-
self.input_y[:] = 0.
860+
self.input.value = bm.zeros_like(self.input)
861+
self.input_y.value = bm.zeros_like(self.input_y)
862862

863863

864864
class JansenRitModel(RateModel):
@@ -976,5 +976,5 @@ def update(self, tdi, x=None):
976976
self.i.value = bm.maximum(self.i + di * dt, 0.)
977977

978978
def clear_input(self):
979-
self.Ie[:] = 0.
980-
self.Ii[:] = 0.
979+
self.Ie.value = bm.zeros_like(self.Ie)
980+
self.Ii.value = bm.zeros_like(self.Ii)

brainpy/math/delayvars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
435435
self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)
436436

437437
elif self.update_method == CONCAT_UPDATING:
438-
self.data.value = bm.concatenate([self.data[1:], bm.broadcast_to(value, self.delay_target_shape)], axis=0)
438+
self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])])
439439

440440
else:
441441
raise ValueError(f'Unknown updating method "{self.update_method}"')

0 commit comments

Comments
 (0)