44from functools import partial
55
66import numpy as np
7+ import jax
78from jax import numpy as jnp
89from jax import vmap
910from jax .scipy .optimize import minimize
@@ -274,21 +275,21 @@ def F_fx(self):
274275 f = partial (f , ** (self .pars_update + self .fixed_vars ))
275276 f = utils .f_without_jaxarray_return (f )
276277 f = utils .remove_return_shape (f )
277- self .analyzed_results [C .F_fx ] = bm .jit (f , device = self .jit_device )
278+ self .analyzed_results [C .F_fx ] = jax .jit (f , device = self .jit_device )
278279 return self .analyzed_results [C .F_fx ]
279280
280281 @property
281282 def F_vmap_fx (self ):
282283 if C .F_vmap_fx not in self .analyzed_results :
283- self .analyzed_results [C .F_vmap_fx ] = bm .jit (vmap (self .F_fx ), device = self .jit_device )
284+ self .analyzed_results [C .F_vmap_fx ] = jax .jit (vmap (self .F_fx ), device = self .jit_device )
284285 return self .analyzed_results [C .F_vmap_fx ]
285286
286287 @property
287288 def F_dfxdx (self ):
288289 """The function to evaluate :math:`\f rac{df_x(*\mathrm{vars}, *\mathrm{pars})}{dx}`."""
289290 if C .F_dfxdx not in self .analyzed_results :
290291 dfx = bm .vector_grad (self .F_fx , argnums = 0 )
291- self .analyzed_results [C .F_dfxdx ] = bm .jit (dfx , device = self .jit_device )
292+ self .analyzed_results [C .F_dfxdx ] = jax .jit (dfx , device = self .jit_device )
292293 return self .analyzed_results [C .F_dfxdx ]
293294
294295 @property
@@ -307,7 +308,7 @@ def F_vmap_fp_aux(self):
307308 # ---
308309 # "X": a two-dimensional matrix: (num_batch, num_var)
309310 # "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
310- self .analyzed_results [C .F_vmap_fp_aux ] = bm .jit (vmap (self .F_fixed_point_aux ))
311+ self .analyzed_results [C .F_vmap_fp_aux ] = jax .jit (vmap (self .F_fixed_point_aux ))
311312 return self .analyzed_results [C .F_vmap_fp_aux ]
312313
313314 @property
@@ -326,7 +327,7 @@ def F_vmap_fp_opt(self):
326327 # ---
327328 # "X": a two-dimensional matrix: (num_batch, num_var)
328329 # "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
329- self .analyzed_results [C .F_vmap_fp_opt ] = bm .jit (vmap (self .F_fixed_point_opt ))
330+ self .analyzed_results [C .F_vmap_fp_opt ] = jax .jit (vmap (self .F_fixed_point_opt ))
330331 return self .analyzed_results [C .F_vmap_fp_opt ]
331332
332333 def _get_fixed_points (self , candidates , * args , num_seg = None , tol_aux = 1e-7 , loss_screen = None ):
@@ -519,31 +520,31 @@ def F_y_by_x_in_fy(self):
519520 @property
520521 def F_vmap_fy (self ):
521522 if C .F_vmap_fy not in self .analyzed_results :
522- self .analyzed_results [C .F_vmap_fy ] = bm .jit (vmap (self .F_fy ), device = self .jit_device )
523+ self .analyzed_results [C .F_vmap_fy ] = jax .jit (vmap (self .F_fy ), device = self .jit_device )
523524 return self .analyzed_results [C .F_vmap_fy ]
524525
525526 @property
526527 def F_dfxdy (self ):
527528 """The function to evaluate :math:`\f rac{df_x (*\mathrm{vars}, *\mathrm{pars})}{dy}`."""
528529 if C .F_dfxdy not in self .analyzed_results :
529530 dfxdy = bm .vector_grad (self .F_fx , argnums = 1 )
530- self .analyzed_results [C .F_dfxdy ] = bm .jit (dfxdy , device = self .jit_device )
531+ self .analyzed_results [C .F_dfxdy ] = jax .jit (dfxdy , device = self .jit_device )
531532 return self .analyzed_results [C .F_dfxdy ]
532533
533534 @property
534535 def F_dfydx (self ):
535536 """The function to evaluate :math:`\f rac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dx}`."""
536537 if C .F_dfydx not in self .analyzed_results :
537538 dfydx = bm .vector_grad (self .F_fy , argnums = 0 )
538- self .analyzed_results [C .F_dfydx ] = bm .jit (dfydx , device = self .jit_device )
539+ self .analyzed_results [C .F_dfydx ] = jax .jit (dfydx , device = self .jit_device )
539540 return self .analyzed_results [C .F_dfydx ]
540541
541542 @property
542543 def F_dfydy (self ):
543544 """The function to evaluate :math:`\f rac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dy}`."""
544545 if C .F_dfydy not in self .analyzed_results :
545546 dfydy = bm .vector_grad (self .F_fy , argnums = 1 )
546- self .analyzed_results [C .F_dfydy ] = bm .jit (dfydy , device = self .jit_device )
547+ self .analyzed_results [C .F_dfydy ] = jax .jit (dfydy , device = self .jit_device )
547548 return self .analyzed_results [C .F_dfydy ]
548549
549550 @property
@@ -556,7 +557,7 @@ def f_jacobian(*var_and_pars):
556557
557558 def call (* var_and_pars ):
558559 var_and_pars = tuple ((vp .value if isinstance (vp , bm .Array ) else vp ) for vp in var_and_pars )
559- return jnp .array (bm .jit (f_jacobian , device = self .jit_device )(* var_and_pars ))
560+ return jnp .array (jax .jit (f_jacobian , device = self .jit_device )(* var_and_pars ))
560561
561562 self .analyzed_results [C .F_jacobian ] = call
562563 return self .analyzed_results [C .F_jacobian ]
@@ -681,7 +682,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
681682
682683 if self .F_x_by_y_in_fx is not None :
683684 utils .output ("I am evaluating fx-nullcline by F_x_by_y_in_fx ..." )
684- vmap_f = bm .jit (vmap (self .F_x_by_y_in_fx ), device = self .jit_device )
685+ vmap_f = jax .jit (vmap (self .F_x_by_y_in_fx ), device = self .jit_device )
685686 for j , pars in enumerate (par_seg ):
686687 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
687688 mesh_values = jnp .meshgrid (* ((ys ,) + pars ))
@@ -697,7 +698,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
697698
698699 elif self .F_y_by_x_in_fx is not None :
699700 utils .output ("I am evaluating fx-nullcline by F_y_by_x_in_fx ..." )
700- vmap_f = bm .jit (vmap (self .F_y_by_x_in_fx ), device = self .jit_device )
701+ vmap_f = jax .jit (vmap (self .F_y_by_x_in_fx ), device = self .jit_device )
701702 for j , pars in enumerate (par_seg ):
702703 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
703704 mesh_values = jnp .meshgrid (* ((xs ,) + pars ))
@@ -715,9 +716,9 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
715716 utils .output ("I am evaluating fx-nullcline by optimization ..." )
716717 # auxiliary functions
717718 f2 = lambda y , x , * pars : self .F_fx (x , y , * pars )
718- vmap_f2 = bm .jit (vmap (f2 ), device = self .jit_device )
719- vmap_brentq_f2 = bm .jit (vmap (utils .jax_brentq (f2 )), device = self .jit_device )
720- vmap_brentq_f1 = bm .jit (vmap (utils .jax_brentq (self .F_fx )), device = self .jit_device )
719+ vmap_f2 = jax .jit (vmap (f2 ), device = self .jit_device )
720+ vmap_brentq_f2 = jax .jit (vmap (utils .jax_brentq (f2 )), device = self .jit_device )
721+ vmap_brentq_f1 = jax .jit (vmap (utils .jax_brentq (self .F_fx )), device = self .jit_device )
721722
722723 # num segments
723724 for _j , Ps in enumerate (par_seg ):
@@ -774,7 +775,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
774775
775776 if self .F_x_by_y_in_fy is not None :
776777 utils .output ("I am evaluating fy-nullcline by F_x_by_y_in_fy ..." )
777- vmap_f = bm .jit (vmap (self .F_x_by_y_in_fy ), device = self .jit_device )
778+ vmap_f = jax .jit (vmap (self .F_x_by_y_in_fy ), device = self .jit_device )
778779 for j , pars in enumerate (par_seg ):
779780 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
780781 mesh_values = jnp .meshgrid (* ((ys ,) + pars ))
@@ -790,7 +791,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
790791
791792 elif self .F_y_by_x_in_fy is not None :
792793 utils .output ("I am evaluating fy-nullcline by F_y_by_x_in_fy ..." )
793- vmap_f = bm .jit (vmap (self .F_y_by_x_in_fy ), device = self .jit_device )
794+ vmap_f = jax .jit (vmap (self .F_y_by_x_in_fy ), device = self .jit_device )
794795 for j , pars in enumerate (par_seg ):
795796 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
796797 mesh_values = jnp .meshgrid (* ((xs ,) + pars ))
@@ -809,9 +810,9 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
809810
810811 # auxiliary functions
811812 f2 = lambda y , x , * pars : self .F_fy (x , y , * pars )
812- vmap_f2 = bm .jit (vmap (f2 ), device = self .jit_device )
813- vmap_brentq_f2 = bm .jit (vmap (utils .jax_brentq (f2 )), device = self .jit_device )
814- vmap_brentq_f1 = bm .jit (vmap (utils .jax_brentq (self .F_fy )), device = self .jit_device )
813+ vmap_f2 = jax .jit (vmap (f2 ), device = self .jit_device )
814+ vmap_brentq_f2 = jax .jit (vmap (utils .jax_brentq (f2 )), device = self .jit_device )
815+ vmap_brentq_f1 = jax .jit (vmap (utils .jax_brentq (self .F_fy )), device = self .jit_device )
815816
816817 for j , Ps in enumerate (par_seg ):
817818 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
@@ -859,7 +860,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100):
859860 xs = self .resolutions [self .x_var ]
860861 ys = self .resolutions [self .y_var ]
861862 P = tuple (self .resolutions [p ] for p in self .target_par_names )
862- f_select = bm .jit (vmap (lambda vals , ids : vals [ids ], in_axes = (1 , 1 )))
863+ f_select = jax .jit (vmap (lambda vals , ids : vals [ids ], in_axes = (1 , 1 )))
863864
864865 # num seguments
865866 if isinstance (num_segments , int ):
@@ -939,10 +940,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
939940
940941 if self .convert_type () == C .x_by_y :
941942 num_seg = len (self .resolutions [self .y_var ])
942- f_vmap = bm .jit (vmap (self .F_y_convert [1 ]))
943+ f_vmap = jax .jit (vmap (self .F_y_convert [1 ]))
943944 else :
944945 num_seg = len (self .resolutions [self .x_var ])
945- f_vmap = bm .jit (vmap (self .F_x_convert [1 ]))
946+ f_vmap = jax .jit (vmap (self .F_x_convert [1 ]))
946947 # get the signs
947948 signs = jnp .sign (f_vmap (candidates , * args ))
948949 signs = signs .reshape ((num_seg , - 1 ))
@@ -972,10 +973,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
972973 # get another value
973974 if self .convert_type () == C .x_by_y :
974975 y_values = fps
975- x_values = bm .jit (vmap (self .F_y_convert [0 ]))(y_values , * args )
976+ x_values = jax .jit (vmap (self .F_y_convert [0 ]))(y_values , * args )
976977 else :
977978 x_values = fps
978- y_values = bm .jit (vmap (self .F_x_convert [0 ]))(x_values , * args )
979+ y_values = jax .jit (vmap (self .F_x_convert [0 ]))(x_values , * args )
979980 fps = jnp .stack ([x_values , y_values ]).T
980981 return fps , selected_ids , args
981982
@@ -1042,7 +1043,7 @@ def F_fz(self):
10421043 wrapper = utils .std_derivative (arguments , self .target_var_names , self .target_par_names )
10431044 f = wrapper (self .model .f_derivatives [self .z_var ])
10441045 f = partial (f , ** (self .pars_update + self .fixed_vars ))
1045- self .analyzed_results [C .F_fz ] = bm .jit (f , device = self .jit_device )
1046+ self .analyzed_results [C .F_fz ] = jax .jit (f , device = self .jit_device )
10461047 return self .analyzed_results [C .F_fz ]
10471048
10481049 def fz_signs (self , pars = (), cache = False ):
0 commit comments