
Commit 3d63531

[oo transform] enable new style of jit transformation to support `static_argnums` and `static_argnames` (#360)
2 parents 5874e8d + ba72bba commit 3d63531

37 files changed: +303 additions, -197 deletions
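For context, the change this commit enables looks roughly like the sketch below: `bm.jit` now accepts `static_argnums` / `static_argnames` in the same spirit as `jax.jit`. The `integrate` function, its arguments, and the values are illustrative only, not taken from the diff.

import brainpy.math as bm

# Toy function (hypothetical): `method` is a Python string, so it must be
# treated as a static (compile-time) argument rather than a traced value.
def integrate(x, dt, method):
  if method == 'euler':
    return x + dt * (-x)
  return x * bm.exp(-dt)

# After this commit, static arguments can be declared directly on bm.jit.
jit_integrate = bm.jit(integrate, static_argnames=('method',))
out = jit_integrate(bm.ones(10), 0.1, method='euler')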

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 1 addition & 3 deletions
@@ -343,8 +343,6 @@ def f_loss():
 
     grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True)
     optimizer.register_train_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})
-    dyn_vars = optimizer.vars() + (fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})
-    dyn_vars = dyn_vars.unique()
 
     def train(idx):
       gradients, loss = grad_f()
@@ -353,7 +351,7 @@ def train(idx):
       return loss
 
     def batch_train(start_i, n_batch):
-      return bm.for_loop(train, bm.arange(start_i, start_i + n_batch), dyn_vars=dyn_vars)
+      return bm.for_loop(train, bm.arange(start_i, start_i + n_batch))
 
     # Run the optimization
     if self.verbose:
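The removed `dyn_vars` bookkeeping reflects the new object-oriented transform: `bm.for_loop` now discovers the `bm.Variable`s used inside the loop body on its own. A minimal sketch of the new style, with an illustrative `counter` Variable that is not part of the diff:

import brainpy.math as bm

counter = bm.Variable(bm.zeros(1))

def body(x):
  # The Variable written here is tracked automatically; no dyn_vars needed.
  counter.value += x
  return counter.value

out = bm.for_loop(body, bm.arange(5.))  # stacks the per-step returns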

brainpy/_src/analysis/lowdim/lowdim_analyzer.py

Lines changed: 27 additions & 26 deletions
@@ -4,6 +4,7 @@
 from functools import partial
 
 import numpy as np
+import jax
 from jax import numpy as jnp
 from jax import vmap
 from jax.scipy.optimize import minimize
@@ -274,21 +275,21 @@ def F_fx(self):
       f = partial(f, **(self.pars_update + self.fixed_vars))
       f = utils.f_without_jaxarray_return(f)
       f = utils.remove_return_shape(f)
-      self.analyzed_results[C.F_fx] = bm.jit(f, device=self.jit_device)
+      self.analyzed_results[C.F_fx] = jax.jit(f, device=self.jit_device)
     return self.analyzed_results[C.F_fx]
 
   @property
   def F_vmap_fx(self):
     if C.F_vmap_fx not in self.analyzed_results:
-      self.analyzed_results[C.F_vmap_fx] = bm.jit(vmap(self.F_fx), device=self.jit_device)
+      self.analyzed_results[C.F_vmap_fx] = jax.jit(vmap(self.F_fx), device=self.jit_device)
     return self.analyzed_results[C.F_vmap_fx]
 
   @property
   def F_dfxdx(self):
     """The function to evaluate :math:`\frac{df_x(*\mathrm{vars}, *\mathrm{pars})}{dx}`."""
     if C.F_dfxdx not in self.analyzed_results:
       dfx = bm.vector_grad(self.F_fx, argnums=0)
-      self.analyzed_results[C.F_dfxdx] = bm.jit(dfx, device=self.jit_device)
+      self.analyzed_results[C.F_dfxdx] = jax.jit(dfx, device=self.jit_device)
     return self.analyzed_results[C.F_dfxdx]
 
   @property
@@ -307,7 +308,7 @@ def F_vmap_fp_aux(self):
       # ---
       # "X": a two-dimensional matrix: (num_batch, num_var)
       # "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
-      self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(vmap(self.F_fixed_point_aux))
+      self.analyzed_results[C.F_vmap_fp_aux] = jax.jit(vmap(self.F_fixed_point_aux))
     return self.analyzed_results[C.F_vmap_fp_aux]
 
   @property
@@ -326,7 +327,7 @@ def F_vmap_fp_opt(self):
       # ---
       # "X": a two-dimensional matrix: (num_batch, num_var)
       # "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
-      self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(vmap(self.F_fixed_point_opt))
+      self.analyzed_results[C.F_vmap_fp_opt] = jax.jit(vmap(self.F_fixed_point_opt))
     return self.analyzed_results[C.F_vmap_fp_opt]
 
   def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None):
@@ -519,31 +520,31 @@ def F_y_by_x_in_fy(self):
   @property
   def F_vmap_fy(self):
     if C.F_vmap_fy not in self.analyzed_results:
-      self.analyzed_results[C.F_vmap_fy] = bm.jit(vmap(self.F_fy), device=self.jit_device)
+      self.analyzed_results[C.F_vmap_fy] = jax.jit(vmap(self.F_fy), device=self.jit_device)
     return self.analyzed_results[C.F_vmap_fy]
 
   @property
   def F_dfxdy(self):
     """The function to evaluate :math:`\frac{df_x (*\mathrm{vars}, *\mathrm{pars})}{dy}`."""
     if C.F_dfxdy not in self.analyzed_results:
       dfxdy = bm.vector_grad(self.F_fx, argnums=1)
-      self.analyzed_results[C.F_dfxdy] = bm.jit(dfxdy, device=self.jit_device)
+      self.analyzed_results[C.F_dfxdy] = jax.jit(dfxdy, device=self.jit_device)
     return self.analyzed_results[C.F_dfxdy]
 
   @property
   def F_dfydx(self):
     """The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dx}`."""
     if C.F_dfydx not in self.analyzed_results:
       dfydx = bm.vector_grad(self.F_fy, argnums=0)
-      self.analyzed_results[C.F_dfydx] = bm.jit(dfydx, device=self.jit_device)
+      self.analyzed_results[C.F_dfydx] = jax.jit(dfydx, device=self.jit_device)
     return self.analyzed_results[C.F_dfydx]
 
   @property
   def F_dfydy(self):
     """The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dy}`."""
     if C.F_dfydy not in self.analyzed_results:
       dfydy = bm.vector_grad(self.F_fy, argnums=1)
-      self.analyzed_results[C.F_dfydy] = bm.jit(dfydy, device=self.jit_device)
+      self.analyzed_results[C.F_dfydy] = jax.jit(dfydy, device=self.jit_device)
     return self.analyzed_results[C.F_dfydy]
 
   @property
@@ -556,7 +557,7 @@ def f_jacobian(*var_and_pars):
 
       def call(*var_and_pars):
         var_and_pars = tuple((vp.value if isinstance(vp, bm.Array) else vp) for vp in var_and_pars)
-        return jnp.array(bm.jit(f_jacobian, device=self.jit_device)(*var_and_pars))
+        return jnp.array(jax.jit(f_jacobian, device=self.jit_device)(*var_and_pars))
 
       self.analyzed_results[C.F_jacobian] = call
     return self.analyzed_results[C.F_jacobian]
@@ -681,7 +682,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
 
     if self.F_x_by_y_in_fx is not None:
       utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...")
-      vmap_f = bm.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
+      vmap_f = jax.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
       for j, pars in enumerate(par_seg):
         if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
         mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -697,7 +698,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
 
     elif self.F_y_by_x_in_fx is not None:
       utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...")
-      vmap_f = bm.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
+      vmap_f = jax.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
       for j, pars in enumerate(par_seg):
         if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
         mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -715,9 +716,9 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
       utils.output("I am evaluating fx-nullcline by optimization ...")
       # auxiliary functions
       f2 = lambda y, x, *pars: self.F_fx(x, y, *pars)
-      vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
-      vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
-      vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
+      vmap_f2 = jax.jit(vmap(f2), device=self.jit_device)
+      vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
+      vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
 
       # num segments
       for _j, Ps in enumerate(par_seg):
@@ -774,7 +775,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
 
     if self.F_x_by_y_in_fy is not None:
      utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...")
-      vmap_f = bm.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
+      vmap_f = jax.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
       for j, pars in enumerate(par_seg):
         if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
         mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -790,7 +791,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
 
     elif self.F_y_by_x_in_fy is not None:
       utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...")
-      vmap_f = bm.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
+      vmap_f = jax.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
       for j, pars in enumerate(par_seg):
         if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
         mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -809,9 +810,9 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
 
       # auxiliary functions
       f2 = lambda y, x, *pars: self.F_fy(x, y, *pars)
-      vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
-      vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
-      vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
+      vmap_f2 = jax.jit(vmap(f2), device=self.jit_device)
+      vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
+      vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
 
       for j, Ps in enumerate(par_seg):
         if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
@@ -859,7 +860,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100):
     xs = self.resolutions[self.x_var]
     ys = self.resolutions[self.y_var]
     P = tuple(self.resolutions[p] for p in self.target_par_names)
-    f_select = bm.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
+    f_select = jax.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
 
     # num seguments
     if isinstance(num_segments, int):
@@ -939,10 +940,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
 
     if self.convert_type() == C.x_by_y:
       num_seg = len(self.resolutions[self.y_var])
-      f_vmap = bm.jit(vmap(self.F_y_convert[1]))
+      f_vmap = jax.jit(vmap(self.F_y_convert[1]))
     else:
       num_seg = len(self.resolutions[self.x_var])
-      f_vmap = bm.jit(vmap(self.F_x_convert[1]))
+      f_vmap = jax.jit(vmap(self.F_x_convert[1]))
     # get the signs
     signs = jnp.sign(f_vmap(candidates, *args))
     signs = signs.reshape((num_seg, -1))
@@ -972,10 +973,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
     # get another value
     if self.convert_type() == C.x_by_y:
       y_values = fps
-      x_values = bm.jit(vmap(self.F_y_convert[0]))(y_values, *args)
+      x_values = jax.jit(vmap(self.F_y_convert[0]))(y_values, *args)
     else:
       x_values = fps
-      y_values = bm.jit(vmap(self.F_x_convert[0]))(x_values, *args)
+      y_values = jax.jit(vmap(self.F_x_convert[0]))(x_values, *args)
     fps = jnp.stack([x_values, y_values]).T
     return fps, selected_ids, args
 
@@ -1042,7 +1043,7 @@ def F_fz(self):
      wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names)
      f = wrapper(self.model.f_derivatives[self.z_var])
      f = partial(f, **(self.pars_update + self.fixed_vars))
-     self.analyzed_results[C.F_fz] = bm.jit(f, device=self.jit_device)
+     self.analyzed_results[C.F_fz] = jax.jit(f, device=self.jit_device)
     return self.analyzed_results[C.F_fz]
 
   def fz_signs(self, pars=(), cache=False):
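Most replacements in this file swap `bm.jit` for `jax.jit`: as I read the diff, these analyzer helpers are pure array-to-array functions, so plain `jax.jit`/`vmap` is sufficient, while `bm.jit` stays reserved for callables that read or write BrainPy `Variable` state. A rough sketch of that distinction (the functions below are illustrative, not from the analyzer):

import jax
import brainpy.math as bm

# Pure numeric helper: jax.jit + jax.vmap is enough.
f = lambda x, a: a * x - x ** 3
vmap_f = jax.jit(jax.vmap(f, in_axes=(0, None)))

# Stateful callable (touches a bm.Variable): keep using bm.jit.
state = bm.Variable(bm.zeros(1))
stateful_step = bm.jit(lambda x: state.value + x)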

brainpy/_src/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ def __init__(self, model, target_pars, target_vars, fixed_vars=None,
   @property
   def F_vmap_dfxdx(self):
     if C.F_vmap_dfxdx not in self.analyzed_results:
-      f = bm.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
+      f = jax.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
       self.analyzed_results[C.F_vmap_dfxdx] = f
     return self.analyzed_results[C.F_vmap_dfxdx]
 
@@ -163,7 +163,7 @@ def F_vmap_jacobian(self):
     if C.F_vmap_jacobian not in self.analyzed_results:
       f1 = lambda xy, *args: jnp.array([self.F_fx(xy[0], xy[1], *args),
                                         self.F_fy(xy[0], xy[1], *args)])
-      f2 = bm.jit(vmap(bm.jacobian(f1)), device=self.jit_device)
+      f2 = jax.jit(vmap(bm.jacobian(f1)), device=self.jit_device)
       self.analyzed_results[C.F_vmap_jacobian] = f2
     return self.analyzed_results[C.F_vmap_jacobian]

brainpy/_src/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def __init__(self,
   @property
   def F_vmap_brentq_fy(self):
     if C.F_vmap_brentq_fy not in self.analyzed_results:
-      f_opt = bm.jit(vmap(utils.jax_brentq(self.F_fy)))
+      f_opt = jax.jit(vmap(utils.jax_brentq(self.F_fy)))
       self.analyzed_results[C.F_vmap_brentq_fy] = f_opt
     return self.analyzed_results[C.F_vmap_brentq_fy]

brainpy/_src/analysis/lowdim/tests/test_phase_plane.py

Lines changed: 5 additions & 3 deletions
@@ -7,7 +7,7 @@
 import jax.numpy as jnp
 
 
-block = False
+show = False
 
 
 class TestPhasePlane(unittest.TestCase):
@@ -27,7 +27,8 @@ def int_x(x, t, Iext):
     plt.ion()
     analyzer.plot_vector_field()
     analyzer.plot_fixed_point()
-    plt.show(block=block)
+    if show:
+      plt.show()
     plt.close()
     bp.math.disable_x64()
 
@@ -74,6 +75,7 @@ def int_s2(s2, t, s1):
     analyzer.plot_vector_field()
     analyzer.plot_nullcline(coords=dict(s2='s2-s1'))
     analyzer.plot_fixed_point()
-    plt.show(block=block)
+    if show:
+      plt.show()
     plt.close()
     bp.math.disable_x64()

brainpy/_src/analysis/utils/optimization.py

Lines changed: 3 additions & 3 deletions
@@ -197,7 +197,7 @@ def brentq_candidates(vmap_f, *values, args=()):
 
 def brentq_roots(f, starts, ends, *vmap_args, args=()):
   in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args)))
-  vmap_f_opt = bm.jit(vmap(jax_brentq(f), in_axes=in_axes))
+  vmap_f_opt = jax.jit(vmap(jax_brentq(f), in_axes=in_axes))
   all_args = vmap_args + args
   if len(all_args):
     res = vmap_f_opt(starts, ends, all_args)
@@ -397,7 +397,7 @@ def roots_of_1d_by_x(f, candidates, args=()):
     return fps
   starts = candidates[candidate_ids]
   ends = candidates[candidate_ids + 1]
-  f_opt = bm.jit(vmap(jax_brentq(f), in_axes=(0, 0, None)))
+  f_opt = jax.jit(vmap(jax_brentq(f), in_axes=(0, 0, None)))
   res = f_opt(starts, ends, args)
   valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
   fps2 = res['root'][valid_idx]
@@ -406,7 +406,7 @@ def roots_of_1d_by_x(f, candidates, args=()):
 
 def roots_of_1d_by_xy(f, starts, ends, args):
   f = f_without_jaxarray_return(f)
-  f_opt = bm.jit(vmap(jax_brentq(f)))
+  f_opt = jax.jit(vmap(jax_brentq(f)))
   res = f_opt(starts, ends, (args,))
   valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
   xs = res['root'][valid_idx]

brainpy/_src/analysis/utils/others.py

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@
 
 from typing import Union, Dict
 
+import jax
 import jax.numpy as jnp
 import numpy as np
-from jax import vmap
 from jax.tree_util import tree_map
 
 import brainpy.math as bm
@@ -80,7 +80,7 @@ def get_sign(f, xs, ys):
 
 def get_sign2(f, *xyz, args=()):
   in_axes = tuple(range(len(xyz))) + tuple([None] * len(args))
-  f = bm.jit(vmap(f_without_jaxarray_return(f), in_axes=in_axes))
+  f = jax.jit(jax.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
   xyz = tuple((v.value if isinstance(v, bm.Array) else v) for v in xyz)
   XYZ = jnp.meshgrid(*xyz)
   XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)

brainpy/_src/dyn/runners.py

Lines changed: 4 additions & 10 deletions
@@ -668,13 +668,9 @@ def _get_f_predict(self, shared_args: Dict = None):
 
     shared_kwargs_str = serialize_kwargs(shared_args)
     if shared_kwargs_str not in self._f_predict_compiled:
-      dyn_vars = self.target.vars()
-      dyn_vars.update(self._dyn_vars)
-      dyn_vars.update(self.vars(level=0))
-      dyn_vars = dyn_vars.unique()
 
       if self._memory_efficient:
-        _jit_step = bm.jit(partial(self._step_func_predict, shared_args), dyn_vars=dyn_vars)
+        _jit_step = bm.jit(partial(self._step_func_predict, shared_args))
 
         def run_func(all_inputs):
           outs = None
@@ -688,12 +684,10 @@ def run_func(all_inputs):
           return outs, None
 
       else:
-        @bm.jit(dyn_vars=dyn_vars)
+        step = partial(self._step_func_predict, shared_args)
+
         def run_func(all_inputs):
-          return bm.for_loop(partial(self._step_func_predict, shared_args),
-                             all_inputs,
-                             dyn_vars=dyn_vars,
-                             jit=self.jit['predict'])
+          return bm.for_loop(step, all_inputs, jit=self.jit['predict'])
 
       self._f_predict_compiled[shared_kwargs_str] = run_func
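With the new transform, the runner can JIT-compile and loop over the step function directly; the variables it touches no longer need to be collected by hand. A condensed, hypothetical sketch of the same pattern (`step_func`, `state`, and the shared-argument dict stand in for `self._step_func_predict`, the runner's state, and `shared_args`):

from functools import partial
import brainpy.math as bm

state = bm.Variable(bm.zeros(1))

def step_func(shared_args, x):
  state.value += x
  return state.value

jit_step = bm.jit(partial(step_func, {'fit': False}))   # memory-efficient path
outs = bm.for_loop(partial(step_func, {'fit': False}),  # scan path
                   bm.arange(4.))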

brainpy/_src/dyn/synapses/abstract_models.py

Lines changed: 1 addition & 2 deletions
@@ -977,8 +977,7 @@ def update(self, tdi):
     inp = bm.cond((a > 5) * (b > 5),
                   lambda _: self.rng.normal(a, b * p, self.target_var.shape),
                   lambda _: self.rng.binomial(self.num_input, p, self.target_var.shape),
-                  None,
-                  dyn_vars=self.rng)
+                  None)
     self.target_var += inp * self.weight
 
   def __repr__(self):
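Same story for `bm.cond`: the random state used inside the branches no longer has to be declared through `dyn_vars`. A small sketch with illustrative shapes and distributions (not the synapse model's actual parameters):

import brainpy.math as bm

rng = bm.random.RandomState(123)

pred = rng.uniform(0., 1.) < 0.5
inp = bm.cond(pred,
              lambda _: rng.normal(0., 1., (10,)),    # true branch
              lambda _: rng.uniform(0., 1., (10,)),   # false branch
              None)                                   # no operands, no dyn_vars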

brainpy/_src/dyn/synapses_v2/others.py

Lines changed: 1 addition & 2 deletions
@@ -68,8 +68,7 @@ def update(self):
     inp = bm.cond((a > 5) * (b > 5),
                   lambda _: self.rng.normal(a, b * p, self.target_shape),
                   lambda _: self.rng.binomial(self.num_input, p, self.target_shape),
-                  None,
-                  dyn_vars=self.rng)
+                  None)
     return inp * self.weight
 
   def __repr__(self):
