44import arviz
55import calibr8
66import numpy
7- import scipy .stats
8- from calibr8 .utils import pm
7+ import pymc as pm
98from packaging import version
109
11- # Use the new ConstantData container if available,
12- # because it gives superior computational performance.
13- if hasattr (pm , "ConstantData" ):
14- pmData = pm .ConstantData
15- else :
16- pmData = pm .Data
17-
18-
1910try :
20- import aesara .tensor as at
11+ import pytensor .tensor as pt
2112except ModuleNotFoundError :
22- import theano .tensor as at
13+ import aesara .tensor as pt
2314
2415
2516_log = logging .getLogger (__file__ )
@@ -147,7 +138,7 @@ def sample(self, **kwargs) -> None:
147138 return_inferencedata = True ,
148139 target_accept = 0.95 ,
149140 init = "adapt_diag" ,
150- start = self .theta_map ,
141+ initvals = self .theta_map ,
151142 tune = 500 ,
152143 draws = 500 ,
153144 )
@@ -160,12 +151,14 @@ def sample(self, **kwargs) -> None:
160151def _make_random_walk (
161152 name : str ,
162153 * ,
154+ init_dist : pt .TensorVariable ,
163155 mu : float = 0 ,
164156 sigma : float ,
165157 nu : float = 1 ,
166158 length : int ,
167159 student_t : bool ,
168160 initval : numpy .ndarray = None ,
161+ dims : typing .Optional [str ] = None ,
169162):
170163 """Create a random walk with either a Normal or Student-t distribution.
171164
@@ -176,13 +169,12 @@ def _make_random_walk(
176169 ----------
177170 name : str
178171 Name of the random walk variable.
172+ init_dist
173+ A random variable to use as the prior for the initial point of the random walk.
179174 mu : float, array-like
180175 Mean of the random walk.
181- If a vector is passed, only the first element should be nonzero,
182- otherwise the random walk will drift systematically.
183176 sigma : float, array-like
184177 Standard deviation (Normal) or scale (StudentT) parameter.
185- A vector may be passed to customize, for example the prior at the start.
186178 nu : float, array-like
187179 Degree of freedom for the StudentT distribution - only used when `student_t == True`.
188180 length : int
@@ -193,6 +185,8 @@ def _make_random_walk(
193185 initval : numpy.ndarray
194186 Initial values for the RandomWalk variable.
195187 If set, PyMC uses these values as start points for MAP optimization and MCMC sampling.
188+ dims
189+ Optional dims to be forwarded to the `RandomWalk`.
196190
197191 Returns
198192 -------
@@ -201,43 +195,23 @@ def _make_random_walk(
201195 """
202196 pmversion = version .parse (pm .__version__ )
203197
204- # Adapt to rename of the testval→initval kwarg
205- if pmversion <= version .parse ("3.11.4" ):
206- initval_kwarg = "testval"
207- else :
208- initval_kwarg = "initval"
209-
210- if pmversion < version .parse ("4.0.0b1" ) and not student_t :
211- # Use the gaussian random walk distribution directly.
212- return pm .GaussianRandomWalk (
213- ** {
214- "name" : name ,
215- "mu" : mu ,
216- "sigma" : sigma ,
217- "shape" : (length ,),
218- initval_kwarg : initval ,
219- }
220- )
221- else :
222- # Create the random walk manually.
223- rv_kwargs = {
224- "name" : f"{ name } __diff_" ,
225- "mu" : mu ,
226- "sigma" : sigma ,
227- "shape" : (length ,),
228- # Since the initval refers to the random walk, but we're creating it
229- # using the cumsum of an RV, we need to do numpy.diff to get an initial
230- # value for the RV from the initial value of the random walk.
231- initval_kwarg : numpy .diff (initval , prepend = 0 ) if initval is not None else None ,
232- }
233-
234- if student_t :
235- rv_cls = pm .StudentT
236- rv_kwargs ["nu" ] = nu
237- else :
238- rv_cls = pm .Normal
198+ if pmversion < version .parse ("4.2.2" ):
199+ raise NotImplementedError ("PyMC versions <4.2.2 are no longer supported." )
239200
240- return pm .Deterministic (name , at .cumsum (rv_cls (** rv_kwargs )))
201+ if student_t :
202+ innov_dist = pm .StudentT .dist (mu = mu , sigma = sigma , nu = nu )
203+ else :
204+ innov_dist = pm .Normal .dist (mu = mu , sigma = sigma )
205+
206+ rw = pm .RandomWalk (
207+ name ,
208+ init_dist = init_dist ,
209+ innovation_dist = innov_dist ,
210+ steps = length - 1 ,
211+ initval = initval ,
212+ dims = dims ,
213+ )
214+ return rw
241215
242216
243217def _get_smoothed_mu (
@@ -348,8 +322,6 @@ def fit_mu_t(
348322 t_switchpoints_known = numpy .sort (list (switchpoints .keys ()))
349323 if student_t is None :
350324 student_t = len (switchpoints ) == 0
351- # build a dict of known switchpoint begin cycle indices so they can be ignored in autodetection
352- c_switchpoints_known = [0 ]
353325
354326 # Use a smoothed, diff-based growth rate on the backscatter to initialize the optimization.
355327 # These values are still everything but high-quality estimates of the growth rate,
@@ -361,23 +333,21 @@ def fit_mu_t(
361333 TD = len (t_data )
362334 TS = len (t_segments )
363335
364- # The mu_prior parameter is used to initialize the random walk at a more realistic growth rate.
365- # This can become necessary when there was no lag phase.
366- if mu_prior != 0 :
367- mu_prior = numpy .array ([mu_prior ] + [0 ] * (TS - 1 ))
368- # Override guess with user-provided mu_prior for nonzero starting points.
369- mu_guess [mu_prior != 0 ] = mu_prior [mu_prior != 0 ]
370-
371336 # build PyMC model
372337 coords = {
373338 "timepoint" : numpy .arange (TD ),
374339 "segment" : numpy .arange (TS ),
375340 }
376341 with pm .Model (coords = coords ) as pmodel :
377- pmData ("known_switchpoints" , t_switchpoints_known )
378- pmData ("t_data" , t_data , dims = "timepoint" )
379- pmData ("t_segments" , t_segments , dims = "segment" )
380- dt = pmData ("dt" , numpy .diff (t_data ), dims = "segment" )
342+ pm .ConstantData ("known_switchpoints" , t_switchpoints_known )
343+ pm .ConstantData ("t_data" , t_data , dims = "timepoint" )
344+ pm .ConstantData ("t_segments" , t_segments , dims = "segment" )
345+ dt = pm .ConstantData ("dt" , numpy .diff (t_data ), dims = "segment" )
346+
347+ # The init dist for the random walk is where each segment starts.
348+ # Here we center it on the user-provided mu_prior,
349+ # taking the absolute of it (+0.05 safety margin to avoid 0) as the scale.
350+ init_dist = pm .Normal .dist (mu = mu_prior , sigma = pt .abs (mu_prior ) + 0.05 )
381351
382352 if len (t_switchpoints_known ) > 0 :
383353 _log .info (
@@ -394,7 +364,8 @@ def fit_mu_t(
394364 mu_segments .append (
395365 _make_random_walk (
396366 name ,
397- mu = mu_prior [slc ],
367+ init_dist = init_dist ,
368+ mu = 0 ,
398369 sigma = drift_scale ,
399370 nu = nu ,
400371 length = i_len ,
@@ -403,49 +374,51 @@ def fit_mu_t(
403374 )
404375 )
405376 i_from += i_len
406- # remember the index to ignore it in potential autodetection
407- c_switchpoints_known .append (i_from )
408377 # the last segment until the end
409378 i_len = len (t [i_from :]) - 1
410379 name = f"mu_phase_{ len (mu_segments )} "
411380 slc = slice (i_from , None )
412381 mu_segments .append (
413382 _make_random_walk (
414383 name ,
415- mu_prior [slc ],
384+ init_dist = init_dist ,
385+ mu = 0 ,
416386 sigma = drift_scale ,
417387 nu = nu ,
418388 length = i_len ,
419389 student_t = student_t ,
420390 initval = mu_guess [slc ],
421391 )
422392 )
423- mu_t = pm .Deterministic ("mu_t" , at .concatenate (mu_segments ), dims = "segment" )
393+ mu_t = pm .Deterministic ("mu_t" , pt .concatenate (mu_segments ), dims = "segment" )
424394 else :
425395 _log .info (
426396 "Creating model without switchpoints. StudentT=%b" , len (t_switchpoints_known ), student_t
427397 )
428398 mu_t = _make_random_walk (
429399 "mu_t" ,
430- mu = mu_prior ,
400+ init_dist = init_dist ,
401+ mu = 0 ,
431402 sigma = drift_scale ,
432403 nu = nu ,
433404 length = TS ,
434405 student_t = student_t ,
435406 initval = mu_guess ,
407+ dims = "segment" ,
436408 )
437409
438410 X0 = pm .LogNormal ("X0" , mu = numpy .log (x0_prior ), sigma = 1 )
439411 Xt = pm .Deterministic (
440412 "X" ,
441- at .concatenate ([X0 [None ], X0 * pm .math .exp (at .extra_ops .cumsum (mu_t * dt ))]),
413+ pt .concatenate ([X0 [None ], X0 * pm .math .exp (pt .extra_ops .cumsum (mu_t * dt ))]),
442414 dims = "timepoint" ,
443415 )
444416 calibration_model .loglikelihood (
445417 x = Xt ,
446- y = pmData ("backscatter" , y , dims = ("timepoint" ,)),
418+ y = pm . ConstantData ("backscatter" , y , dims = ("timepoint" ,)),
447419 replicate_id = replicate_id ,
448420 dependent_key = calibration_model .dependent_key ,
421+ dims = "timepoint" ,
449422 )
450423
451424 # MAP fit
@@ -454,31 +427,14 @@ def fit_mu_t(
454427
455428 # with StudentT random walks, switchpoints can be autodetected
456429 if student_t :
457- # first CDF values at all mu_t elements
458- cdf_evals = []
459- for rvname in sorted (theta_map .keys ()):
460- if "__diff_" in rvname :
461- rv = pmodel [rvname ]
462- # for every µ, find out where it lies in the CDF of the StudentT prior distribution
463- cdf_evals += list (
464- scipy .stats .t .cdf (
465- x = theta_map [rvname ],
466- loc = rv .owner .inputs [3 ].eval (),
467- scale = rv .owner .inputs [4 ].eval (),
468- df = rv .owner .inputs [2 ].eval (),
469- )
470- )
471- cdf_evals = numpy .array (cdf_evals )
472- # filter for the elements that lie outside of the [0.005, 0.995] interval
473- significance_mask = numpy .logical_or (
474- cdf_evals < (switchpoint_prob / 2 ),
475- cdf_evals > (1 - switchpoint_prob / 2 ),
430+ switchpoints_detected = detect_switchpoints (
431+ switchpoint_prob ,
432+ t_data ,
433+ pmodel ,
434+ theta_map ,
476435 )
477- # add these autodetected timepoints to the switchpoints-dict
478- # (ignore the first timepoint)
479- for c_switch , (t_switch , is_switchpoint ) in enumerate (zip (t_data , significance_mask [1 :])):
480- if is_switchpoint and c_switch not in c_switchpoints_known :
481- switchpoints [t_switch ] = "detected"
436+ # Known switchpoints override detected ones 👇
437+ switchpoints = {** switchpoints_detected , ** switchpoints }
482438
483439 # bundle up all relevant variables into a result object
484440 result = GrowthRateResult (
@@ -500,3 +456,71 @@ def fit_mu_t(
500456 result .sample (draws = mcmc_samples )
501457
502458 return result
459+
460+
461+ def detect_switchpoints (
462+ switchpoint_prob : float ,
463+ t_data : typing .Sequence [float ],
464+ pmodel : pm .Model ,
465+ theta_map : typing .Dict [str , numpy .ndarray ],
466+ ) -> typing .Dict [float , str ]:
467+ """Helper function to detect switchpoints from a fitted random walk.
468+
469+ Parameters
470+ ----------
471+ switchpoint_prob
472+ Probability threshold for detecting switchpoints.
473+ Random walk innovations with a prior probability less than this
474+ will be classified as switchpoints.
475+ t_data
476+ Time values corresponding to the random walk steps.
477+ pmodel
478+ The PyMC model containing `"mu_t*"` random walks.
479+ theta_map
480+ MAP estimate of the model.
481+
482+ Returns
483+ -------
484+ switchpoints
485+ Dictionary of switchpoints with
486+ keys being the time point and
487+ values `"detected"`.
488+ """
489+ # first CDF values at all mu_t elements
490+ cdf_evals = []
491+ for rvname in sorted (theta_map .keys ()):
492+ if rvname not in pmodel .named_vars :
493+ continue
494+ # The random walk may be split in multiple segments.
495+ # We can identify a segment from the RVOp type that created it.
496+ rv = pmodel [rvname ]
497+ if rv .owner is None :
498+ continue
499+ if isinstance (rv .owner .op , pm .RandomWalk .rv_type ):
500+ # Get a handle on the innovation dist so we can evaluate prior CDFs.
501+ innov_dist = rv .owner .inputs [1 ]
502+ # Calculate the innovations from the MAP estimate of the points.
503+ # This gives only the deltas between the points, so the 0th element
504+ # in the new vector corresponds to the segment between the 0th and 1st point.
505+ innov = numpy .diff (theta_map [rvname ])
506+ # Now we can evaluate the CDFs of the innovations.
507+ logcdfs = pm .logcdf (innov_dist , innov ).eval ()
508+ # We define switchpoints based on the time of the point with an extreme CDF value.
509+ # To get our <number of segments> length vector to align with the <number of points>,
510+ # we prepend a 0.5 as a placeholder for the CDF of the initial point of the random walk.
511+ cdf_evals += [0.5 , * numpy .exp (logcdfs )]
512+ cdf_evals = numpy .array (cdf_evals )
513+ if len (cdf_evals ) != len (t_data ) - 1 :
514+ raise Exception (
515+ f"Failed to find all random walk segments. Found { len (cdf_evals )} , expected { len (t_data ) - 1 } ."
516+ )
517+ # Filter for the elements that lie outside of the [0.005, 0.995] interval (if switchpoint_prob=0.01).
518+ significance_mask = numpy .logical_or (
519+ cdf_evals < (switchpoint_prob / 2 ),
520+ cdf_evals > (1 - switchpoint_prob / 2 ),
521+ )
522+ # Collect switchpoint information from points with significant CDF values.
523+ # Here we don't need to filter known switchpoints, because these correspond to the first
524+ # point in each random walk, for which we assigned non-significant 0.5 CDF placeholders above.
525+ switchpoints = {t : "detected" for t , is_switchpoint in zip (t_data , significance_mask ) if is_switchpoint }
526+ return switchpoints
0 commit comments