11import logging
22import typing
3+ from typing import Dict , Optional , Sequence , Tuple , Union
34
45import arviz
56import calibr8
1011try :
1112 import pytensor .tensor as pt
1213except ModuleNotFoundError :
13- import aesara .tensor as pt
14+ import aesara .tensor as pt # type: ignore
1415
1516
1617_log = logging .getLogger (__file__ )
@@ -22,13 +23,13 @@ class GrowthRateResult:
2223 def __init__ (
2324 self ,
2425 * ,
25- t_data : numpy .ndarray ,
26- t_segments : numpy .ndarray ,
27- y : numpy .ndarray ,
26+ t_data : Union [ Sequence [ float ], numpy .ndarray ] ,
27+ t_segments : Union [ Sequence [ float ], numpy .ndarray ] ,
28+ y : Union [ Sequence [ float ], numpy .ndarray ] ,
2829 calibration_model : calibr8 .CalibrationModel ,
29- switchpoints : typing . Dict [float , str ],
30+ switchpoints : Dict [float , str ],
3031 pmodel : pm .Model ,
31- theta_map : dict ,
32+ theta_map : Dict [ str , numpy . ndarray ] ,
3233 ):
3334 """Creates a result object of a growth rate analysis.
3435
@@ -47,9 +48,9 @@ def __init__(
4748 theta_map : dict
4849 the PyMC MAP estimate
4950 """
50- self ._t_data = t_data
51- self ._t_segments = t_segments
52- self ._y = y
51+ self ._t_data = numpy . asarray ( t_data )
52+ self ._t_segments = numpy . asarray ( t_segments )
53+ self ._y = numpy . asarray ( y )
5354 self ._switchpoints = switchpoints
5455 self .calibration_model = calibration_model
5556 self ._pmodel = pmodel
@@ -73,17 +74,17 @@ def y(self) -> numpy.ndarray:
7374 return self ._y
7475
7576 @property
76- def switchpoints (self ) -> typing . Dict [float , str ]:
77+ def switchpoints (self ) -> Dict [float , str ]:
7778 """Dictionary (by time) of known and detected switchpoints."""
7879 return self ._switchpoints
7980
8081 @property
81- def known_switchpoints (self ) -> typing . Tuple [float ]:
82+ def known_switchpoints (self ) -> Tuple [float , ... ]:
8283 """Time values of previously known switchpoints in the model."""
8384 return tuple (t for t , label in self .switchpoints .items () if label != "detected" )
8485
8586 @property
86- def detected_switchpoints (self ) -> typing . Tuple [float ]:
87+ def detected_switchpoints (self ) -> Tuple [float , ... ]:
8788 """Time values of switchpoints that were autodetected from the fit."""
8889 return tuple (t for t , label in self .switchpoints .items () if label == "detected" )
8990
@@ -93,12 +94,12 @@ def pmodel(self) -> pm.Model:
9394 return self ._pmodel
9495
9596 @property
96- def theta_map (self ) -> dict :
97+ def theta_map (self ) -> Dict [ str , numpy . ndarray ] :
9798 """MAP estimate of the model parameters."""
9899 return self ._theta_map
99100
100101 @property
101- def idata (self ) -> typing . Optional [arviz .InferenceData ]:
102+ def idata (self ) -> Optional [arviz .InferenceData ]:
102103 """ArviZ InferenceData object of the MCMC trace."""
103104 return self ._idata
104105
@@ -113,18 +114,20 @@ def x_map(self) -> numpy.ndarray:
113114 return self .theta_map ["X" ]
114115
115116 @property
116- def mu_mcmc (self ) -> typing . Optional [numpy .ndarray ]:
117+ def mu_mcmc (self ) -> Optional [numpy .ndarray ]:
117118 """Posterior samples of growth rates in segments between data points."""
118119 if not self .idata :
119120 return None
121+ assert hasattr (self .idata , "posterior" )
120122 return self .idata .posterior .mu_t .stack (sample = ("chain" , "draw" )).values .T
121123
122124 @property
123- def x_mcmc (self ) -> typing . Optional [numpy .ndarray ]:
125+ def x_mcmc (self ) -> Optional [numpy .ndarray ]:
124126 """Posterior samples of biomass curve."""
125- if not self .idata :
127+ if self .idata is None :
126128 return None
127- return self ._idata .posterior ["X" ].stack (sample = ("chain" , "draw" )).T
129+ assert hasattr (self .idata , "posterior" )
130+ return self .idata .posterior ["X" ].stack (sample = ("chain" , "draw" )).T
128131
129132 def sample (self , ** kwargs ) -> None :
130133 """Runs MCMC sampling with default settings on the growth model.
@@ -157,8 +160,8 @@ def _make_random_walk(
157160 nu : float = 1 ,
158161 length : int ,
159162 student_t : bool ,
160- initval : numpy .ndarray = None ,
161- dims : typing . Optional [str ] = None ,
163+ initval : Optional [ numpy .ndarray ] = None ,
164+ dims : Optional [str ] = None ,
162165):
163166 """Create a random walk with either a Normal or Student-t distribution.
164167
@@ -215,7 +218,11 @@ def _make_random_walk(
215218
216219
217220def _get_smoothed_mu (
218- t : numpy .ndarray , y : numpy .ndarray , cm_cdw : calibr8 .CalibrationModel , * , clip = 0.5
221+ t : Sequence [float ],
222+ y : Sequence [float ],
223+ cm_cdw : calibr8 .CalibrationModel ,
224+ * ,
225+ clip : float = 0.5 ,
219226) -> numpy .ndarray :
220227 """Calculate a rough estimate of the specific growth rate from smoothed observations.
221228
@@ -236,10 +243,10 @@ def _get_smoothed_mu(
236243 A vector of specific growth rates.
237244 """
238245 # apply moving average to reduce backscatter noise
239- y = numpy .convolve (y , numpy .ones (5 ) / 5 , "same" )
246+ yarr = numpy .convolve (y , numpy .ones (5 ) / 5 , "same" )
240247
241248 # convert to biomass
242- X = cm_cdw .predict_independent (y )
249+ X = cm_cdw .predict_independent (yarr )
243250
244251 # calculate growth rate
245252 dX = numpy .diff (X )
@@ -259,17 +266,17 @@ def _get_smoothed_mu(
259266
260267
261268def fit_mu_t (
262- t : typing . Sequence [float ],
263- y : typing . Sequence [float ],
269+ t : Sequence [float ],
270+ y : Sequence [float ],
264271 calibration_model : calibr8 .CalibrationModel ,
265272 * ,
266- switchpoints : typing . Optional [typing . Union [typing . Sequence [float ], typing . Dict [float , str ]]] = None ,
273+ switchpoints : Optional [Union [Sequence [float ], Dict [float , str ]]] = None ,
267274 mcmc_samples : int = 0 ,
268275 mu_prior : float = 0 ,
269276 drift_scale : float ,
270277 nu : float = 5 ,
271278 x0_prior : float = 0.25 ,
272- student_t : typing . Optional [bool ] = None ,
279+ student_t : Optional [bool ] = None ,
273280 switchpoint_prob : float = 0.01 ,
274281 replicate_id : str = "unnamed" ,
275282):
@@ -357,7 +364,7 @@ def fit_mu_t(
357364 mu_segments = []
358365 i_from = 0
359366 for i , t_switch in enumerate (t_switchpoints_known ):
360- i_to = numpy .argmax (t > t_switch )
367+ i_to = int ( numpy .argmax (t > t_switch ) )
361368 i_len = len (t [i_from :i_to ])
362369 name = f"mu_phase_{ i } "
363370 slc = slice (i_from , i_to )
@@ -460,10 +467,10 @@ def fit_mu_t(
460467
461468def detect_switchpoints (
462469 switchpoint_prob : float ,
463- t_data : typing . Sequence [float ],
470+ t_data : Sequence [float ],
464471 pmodel : pm .Model ,
465- theta_map : typing . Dict [str , numpy .ndarray ],
466- ) -> typing . Dict [float , str ]:
472+ theta_map : Dict [str , numpy .ndarray ],
473+ ) -> Dict [float , str ]:
467474 """Helper function to detect switchpoints from a fitted random walk.
468475
469476 Parameters
@@ -509,15 +516,15 @@ def detect_switchpoints(
509516 # To get our <number of segments> length vector to align with the <number of points>,
510517 # we prepend a 0.5 as a placeholder for the CDF of the initial point of the random walk.
511518 cdf_evals += [0.5 , * numpy .exp (logcdfs )]
512- cdf_evals = numpy .array (cdf_evals )
513- if len (cdf_evals ) != len (t_data ) - 1 :
519+ cdf_evals_arr = numpy .array (cdf_evals )
520+ if len (cdf_evals_arr ) != len (t_data ) - 1 :
514521 raise Exception (
515- f"Failed to find all random walk segments. Found { len (cdf_evals )} , expected { len (t_data ) - 1 } ."
522+ f"Failed to find all random walk segments. Found { len (cdf_evals_arr )} , expected { len (t_data ) - 1 } ."
516523 )
517524 # Filter for the elements that lie outside of the [0.005, 0.995] interval (if switchpoint_prob=0.01).
518525 significance_mask = numpy .logical_or (
519- cdf_evals < (switchpoint_prob / 2 ),
520- cdf_evals > (1 - switchpoint_prob / 2 ),
526+ cdf_evals_arr < (switchpoint_prob / 2 ),
527+ cdf_evals_arr > (1 - switchpoint_prob / 2 ),
521528 )
522529 # Collect switchpoint information from points with significant CDF values.
523530 # Here we don't need to filter known switchpoints, because these correspond to the first
0 commit comments