33from functools import partial
44
55import numpy as np
6+ from jax import vmap
67from jax import numpy as jnp
78from jax .scipy .optimize import minimize
89
@@ -262,7 +263,7 @@ def F_fx(self):
262263 @property
263264 def F_vmap_fx (self ):
264265 if C .F_vmap_fx not in self .analyzed_results :
265- self .analyzed_results [C .F_vmap_fx ] = bm .jit (bm . vmap (self .F_fx ), device = self .jit_device )
266+ self .analyzed_results [C .F_vmap_fx ] = bm .jit (vmap (self .F_fx ), device = self .jit_device )
266267 return self .analyzed_results [C .F_vmap_fx ]
267268
268269 @property
@@ -289,7 +290,7 @@ def F_vmap_fp_aux(self):
289290 # ---
290291 # "X": a two-dimensional matrix: (num_batch, num_var)
291292 # "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
292- self .analyzed_results [C .F_vmap_fp_aux ] = bm .jit (bm . vmap (self .F_fixed_point_aux ))
293+ self .analyzed_results [C .F_vmap_fp_aux ] = bm .jit (vmap (self .F_fixed_point_aux ))
293294 return self .analyzed_results [C .F_vmap_fp_aux ]
294295
295296 @property
@@ -308,7 +309,7 @@ def F_vmap_fp_opt(self):
308309 # ---
309310 # "X": a two-dimensional matrix: (num_batch, num_var)
310311 # "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
311- self .analyzed_results [C .F_vmap_fp_opt ] = bm .jit (bm . vmap (self .F_fixed_point_opt ))
312+ self .analyzed_results [C .F_vmap_fp_opt ] = bm .jit (vmap (self .F_fixed_point_opt ))
312313 return self .analyzed_results [C .F_vmap_fp_opt ]
313314
314315 def _get_fixed_points (self , candidates , * args , num_seg = None , tol_aux = 1e-7 , loss_screen = None ):
@@ -501,7 +502,7 @@ def F_y_by_x_in_fy(self):
501502 @property
502503 def F_vmap_fy (self ):
503504 if C .F_vmap_fy not in self .analyzed_results :
504- self .analyzed_results [C .F_vmap_fy ] = bm .jit (bm . vmap (self .F_fy ), device = self .jit_device )
505+ self .analyzed_results [C .F_vmap_fy ] = bm .jit (vmap (self .F_fy ), device = self .jit_device )
505506 return self .analyzed_results [C .F_vmap_fy ]
506507
507508 @property
@@ -663,7 +664,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
663664
664665 if self .F_x_by_y_in_fx is not None :
665666 utils .output ("I am evaluating fx-nullcline by F_x_by_y_in_fx ..." )
666- vmap_f = bm .jit (bm . vmap (self .F_x_by_y_in_fx ), device = self .jit_device )
667+ vmap_f = bm .jit (vmap (self .F_x_by_y_in_fx ), device = self .jit_device )
667668 for j , pars in enumerate (par_seg ):
668669 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
669670 mesh_values = jnp .meshgrid (* ((ys ,) + pars ))
@@ -679,7 +680,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
679680
680681 elif self .F_y_by_x_in_fx is not None :
681682 utils .output ("I am evaluating fx-nullcline by F_y_by_x_in_fx ..." )
682- vmap_f = bm .jit (bm . vmap (self .F_y_by_x_in_fx ), device = self .jit_device )
683+ vmap_f = bm .jit (vmap (self .F_y_by_x_in_fx ), device = self .jit_device )
683684 for j , pars in enumerate (par_seg ):
684685 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
685686 mesh_values = jnp .meshgrid (* ((xs ,) + pars ))
@@ -697,9 +698,9 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
697698 utils .output ("I am evaluating fx-nullcline by optimization ..." )
698699 # auxiliary functions
699700 f2 = lambda y , x , * pars : self .F_fx (x , y , * pars )
700- vmap_f2 = bm .jit (bm . vmap (f2 ), device = self .jit_device )
701- vmap_brentq_f2 = bm .jit (bm . vmap (utils .jax_brentq (f2 )), device = self .jit_device )
702- vmap_brentq_f1 = bm .jit (bm . vmap (utils .jax_brentq (self .F_fx )), device = self .jit_device )
701+ vmap_f2 = bm .jit (vmap (f2 ), device = self .jit_device )
702+ vmap_brentq_f2 = bm .jit (vmap (utils .jax_brentq (f2 )), device = self .jit_device )
703+ vmap_brentq_f1 = bm .jit (vmap (utils .jax_brentq (self .F_fx )), device = self .jit_device )
703704
704705 # num segments
705706 for _j , Ps in enumerate (par_seg ):
@@ -756,7 +757,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
756757
757758 if self .F_x_by_y_in_fy is not None :
758759 utils .output ("I am evaluating fy-nullcline by F_x_by_y_in_fy ..." )
759- vmap_f = bm .jit (bm . vmap (self .F_x_by_y_in_fy ), device = self .jit_device )
760+ vmap_f = bm .jit (vmap (self .F_x_by_y_in_fy ), device = self .jit_device )
760761 for j , pars in enumerate (par_seg ):
761762 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
762763 mesh_values = jnp .meshgrid (* ((ys ,) + pars ))
@@ -772,7 +773,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
772773
773774 elif self .F_y_by_x_in_fy is not None :
774775 utils .output ("I am evaluating fy-nullcline by F_y_by_x_in_fy ..." )
775- vmap_f = bm .jit (bm . vmap (self .F_y_by_x_in_fy ), device = self .jit_device )
776+ vmap_f = bm .jit (vmap (self .F_y_by_x_in_fy ), device = self .jit_device )
776777 for j , pars in enumerate (par_seg ):
777778 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
778779 mesh_values = jnp .meshgrid (* ((xs ,) + pars ))
@@ -791,9 +792,9 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
791792
792793 # auxiliary functions
793794 f2 = lambda y , x , * pars : self .F_fy (x , y , * pars )
794- vmap_f2 = bm .jit (bm . vmap (f2 ), device = self .jit_device )
795- vmap_brentq_f2 = bm .jit (bm . vmap (utils .jax_brentq (f2 )), device = self .jit_device )
796- vmap_brentq_f1 = bm .jit (bm . vmap (utils .jax_brentq (self .F_fy )), device = self .jit_device )
795+ vmap_f2 = bm .jit (vmap (f2 ), device = self .jit_device )
796+ vmap_brentq_f2 = bm .jit (vmap (utils .jax_brentq (f2 )), device = self .jit_device )
797+ vmap_brentq_f1 = bm .jit (vmap (utils .jax_brentq (self .F_fy )), device = self .jit_device )
797798
798799 for j , Ps in enumerate (par_seg ):
799800 if len (par_seg .arg_id_segments [0 ]) > 1 : utils .output (f"{ C .prefix } segment { j } ..." )
@@ -841,7 +842,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100):
841842 xs = self .resolutions [self .x_var ].value
842843 ys = self .resolutions [self .y_var ].value
843844 P = tuple (self .resolutions [p ].value for p in self .target_par_names )
844- f_select = bm .jit (bm . vmap (lambda vals , ids : vals [ids ], in_axes = (1 , 1 )))
845+ f_select = bm .jit (vmap (lambda vals , ids : vals [ids ], in_axes = (1 , 1 )))
845846
846847 # num seguments
847848 if isinstance (num_segments , int ):
@@ -921,10 +922,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
921922
922923 if self .convert_type () == C .x_by_y :
923924 num_seg = len (self .resolutions [self .y_var ])
924- f_vmap = bm .jit (bm . vmap (self .F_y_convert [1 ]))
925+ f_vmap = bm .jit (vmap (self .F_y_convert [1 ]))
925926 else :
926927 num_seg = len (self .resolutions [self .x_var ])
927- f_vmap = bm .jit (bm . vmap (self .F_x_convert [1 ]))
928+ f_vmap = bm .jit (vmap (self .F_x_convert [1 ]))
928929 # get the signs
929930 signs = jnp .sign (f_vmap (candidates , * args ))
930931 signs = signs .reshape ((num_seg , - 1 ))
@@ -954,10 +955,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
954955 # get another value
955956 if self .convert_type () == C .x_by_y :
956957 y_values = fps
957- x_values = bm .jit (bm . vmap (self .F_y_convert [0 ]))(y_values , * args )
958+ x_values = bm .jit (vmap (self .F_y_convert [0 ]))(y_values , * args )
958959 else :
959960 x_values = fps
960- y_values = bm .jit (bm . vmap (self .F_x_convert [0 ]))(x_values , * args )
961+ y_values = bm .jit (vmap (self .F_x_convert [0 ]))(x_values , * args )
961962 fps = jnp .stack ([x_values , y_values ]).T
962963 return fps , selected_ids , args
963964
0 commit comments