Skip to content

Commit 27ef1ea

Browse files
Merge pull request #32 from JuBiotech/pymc5
Drop PyMC3 compatibility in favor of PyMC v5
2 parents 5fc74b3 + cf1ca51 commit 27ef1ea

File tree

8 files changed

+386
-214
lines changed

8 files changed

+386
-214
lines changed

.github/workflows/pipeline.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ jobs:
1010
runs-on: ubuntu-latest
1111
strategy:
1212
matrix:
13-
python-version: [3.7, 3.8, 3.9]
14-
pymc-version: ["without", "pymc>=4.0.0", '"pymc3>=3.11.5" "numpy<1.22"']
13+
python-version: ["3.8", "3.9"]
14+
pymc-version: ["without", "'pymc>=4.2.2,<5'", "'pymc>=5.0.0'"]
1515
steps:
16-
- uses: actions/checkout@v2
16+
- uses: actions/checkout@v3
1717
- name: Set up Python ${{ matrix.python-version }}
18-
uses: actions/setup-python@v1
18+
uses: actions/setup-python@v4.3.1
1919
with:
2020
python-version: ${{ matrix.python-version }}
2121
- name: Install dependencies

.github/workflows/release.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ jobs:
1212
env:
1313
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
1414
steps:
15-
- uses: actions/checkout@v2
15+
- uses: actions/checkout@v3
1616
- name: Set up Python
17-
uses: actions/setup-python@v1
17+
uses: actions/setup-python@v4.3.1
1818
with:
19-
python-version: 3.7
19+
python-version: 3.9
2020
- name: Install dependencies
2121
run: |
2222
pip install -e .

bletl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
NoMeasurementData,
2020
)
2121

22-
__version__ = "1.1.3"
22+
__version__ = "1.2.0"

bletl/growth.py

Lines changed: 121 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,13 @@
44
import arviz
55
import calibr8
66
import numpy
7-
import scipy.stats
8-
from calibr8.utils import pm
7+
import pymc as pm
98
from packaging import version
109

11-
# Use the new ConstantData container if available,
12-
# because it gives superior computational performance.
13-
if hasattr(pm, "ConstantData"):
14-
pmData = pm.ConstantData
15-
else:
16-
pmData = pm.Data
17-
18-
1910
try:
20-
import aesara.tensor as at
11+
import pytensor.tensor as pt
2112
except ModuleNotFoundError:
22-
import theano.tensor as at
13+
import aesara.tensor as pt
2314

2415

2516
_log = logging.getLogger(__file__)
@@ -147,7 +138,7 @@ def sample(self, **kwargs) -> None:
147138
return_inferencedata=True,
148139
target_accept=0.95,
149140
init="adapt_diag",
150-
start=self.theta_map,
141+
initvals=self.theta_map,
151142
tune=500,
152143
draws=500,
153144
)
@@ -160,12 +151,14 @@ def sample(self, **kwargs) -> None:
160151
def _make_random_walk(
161152
name: str,
162153
*,
154+
init_dist: pt.TensorVariable,
163155
mu: float = 0,
164156
sigma: float,
165157
nu: float = 1,
166158
length: int,
167159
student_t: bool,
168160
initval: numpy.ndarray = None,
161+
dims: typing.Optional[str] = None,
169162
):
170163
"""Create a random walk with either a Normal or Student-t distribution.
171164
@@ -176,13 +169,12 @@ def _make_random_walk(
176169
----------
177170
name : str
178171
Name of the random walk variable.
172+
init_dist
173+
An unnamed distribution to use as the prior for the first value of the random walk.
179174
mu : float, array-like
180175
Mean of the random walk innovations.
181-
If a vector is passed, only the first element should be nonzero,
182-
otherwise the random walk will drift systematically.
183176
sigma : float, array-like
184177
Standard deviation (Normal) or scale (StudentT) parameter.
185-
A vector may be passed to customize, for example the prior at the start.
186178
nu : float, array-like
187179
Degrees of freedom for the StudentT distribution - only used when `student_t == True`.
188180
length : int
@@ -193,6 +185,8 @@ def _make_random_walk(
193185
initval : numpy.ndarray
194186
Initial values for the RandomWalk variable.
195187
If set, PyMC uses these values as start points for MAP optimization and MCMC sampling.
188+
dims
189+
Optional dims to be forwarded to the `RandomWalk`.
196190
197191
Returns
198192
-------
@@ -201,43 +195,23 @@ def _make_random_walk(
201195
"""
202196
pmversion = version.parse(pm.__version__)
203197

204-
# Adapt to rename of the testval→initval kwarg
205-
if pmversion <= version.parse("3.11.4"):
206-
initval_kwarg = "testval"
207-
else:
208-
initval_kwarg = "initval"
209-
210-
if pmversion < version.parse("4.0.0b1") and not student_t:
211-
# Use the gaussian random walk distribution directly.
212-
return pm.GaussianRandomWalk(
213-
**{
214-
"name": name,
215-
"mu": mu,
216-
"sigma": sigma,
217-
"shape": (length,),
218-
initval_kwarg: initval,
219-
}
220-
)
221-
else:
222-
# Create the random walk manually.
223-
rv_kwargs = {
224-
"name": f"{name}__diff_",
225-
"mu": mu,
226-
"sigma": sigma,
227-
"shape": (length,),
228-
# Since the initval refers to the random walk, but we're creating it
229-
# using the cumsum of an RV, we need to do numpy.diff to get an initial
230-
# value for the RV from the initial value of the random walk.
231-
initval_kwarg: numpy.diff(initval, prepend=0) if initval is not None else None,
232-
}
233-
234-
if student_t:
235-
rv_cls = pm.StudentT
236-
rv_kwargs["nu"] = nu
237-
else:
238-
rv_cls = pm.Normal
198+
if pmversion < version.parse("4.2.2"):
199+
raise NotImplementedError("PyMC versions <4.2.2 are no longer supported.")
239200

240-
return pm.Deterministic(name, at.cumsum(rv_cls(**rv_kwargs)))
201+
if student_t:
202+
innov_dist = pm.StudentT.dist(mu=mu, sigma=sigma, nu=nu)
203+
else:
204+
innov_dist = pm.Normal.dist(mu=mu, sigma=sigma)
205+
206+
rw = pm.RandomWalk(
207+
name,
208+
init_dist=init_dist,
209+
innovation_dist=innov_dist,
210+
steps=length - 1,
211+
initval=initval,
212+
dims=dims,
213+
)
214+
return rw
241215

242216

243217
def _get_smoothed_mu(
@@ -348,8 +322,6 @@ def fit_mu_t(
348322
t_switchpoints_known = numpy.sort(list(switchpoints.keys()))
349323
if student_t is None:
350324
student_t = len(switchpoints) == 0
351-
# build a dict of known switchpoint begin cycle indices so they can be ignored in autodetection
352-
c_switchpoints_known = [0]
353325

354326
# Use a smoothed, diff-based growth rate on the backscatter to initialize the optimization.
355327
# These values are still anything but high-quality estimates of the growth rate,
@@ -361,23 +333,21 @@ def fit_mu_t(
361333
TD = len(t_data)
362334
TS = len(t_segments)
363335

364-
# The mu_prior parameter is used to initialize the random walk at a more realistic growth rate.
365-
# This can become necessary when there was no lag phase.
366-
if mu_prior != 0:
367-
mu_prior = numpy.array([mu_prior] + [0] * (TS - 1))
368-
# Override guess with user-provided mu_prior for nonzero starting points.
369-
mu_guess[mu_prior != 0] = mu_prior[mu_prior != 0]
370-
371336
# build PyMC model
372337
coords = {
373338
"timepoint": numpy.arange(TD),
374339
"segment": numpy.arange(TS),
375340
}
376341
with pm.Model(coords=coords) as pmodel:
377-
pmData("known_switchpoints", t_switchpoints_known)
378-
pmData("t_data", t_data, dims="timepoint")
379-
pmData("t_segments", t_segments, dims="segment")
380-
dt = pmData("dt", numpy.diff(t_data), dims="segment")
342+
pm.ConstantData("known_switchpoints", t_switchpoints_known)
343+
pm.ConstantData("t_data", t_data, dims="timepoint")
344+
pm.ConstantData("t_segments", t_segments, dims="segment")
345+
dt = pm.ConstantData("dt", numpy.diff(t_data), dims="segment")
346+
347+
# The init dist for the random walk is where each segment starts.
348+
# Here we center it on the user-provided mu_prior,
349+
# taking its absolute value (+0.05 safety margin to avoid 0) as the scale.
350+
init_dist = pm.Normal.dist(mu=mu_prior, sigma=pt.abs(mu_prior) + 0.05)
381351

382352
if len(t_switchpoints_known) > 0:
383353
_log.info(
@@ -394,7 +364,8 @@ def fit_mu_t(
394364
mu_segments.append(
395365
_make_random_walk(
396366
name,
397-
mu=mu_prior[slc],
367+
init_dist=init_dist,
368+
mu=0,
398369
sigma=drift_scale,
399370
nu=nu,
400371
length=i_len,
@@ -403,49 +374,51 @@ def fit_mu_t(
403374
)
404375
)
405376
i_from += i_len
406-
# remember the index to ignore it in potential autodetection
407-
c_switchpoints_known.append(i_from)
408377
# the last segment until the end
409378
i_len = len(t[i_from:]) - 1
410379
name = f"mu_phase_{len(mu_segments)}"
411380
slc = slice(i_from, None)
412381
mu_segments.append(
413382
_make_random_walk(
414383
name,
415-
mu_prior[slc],
384+
init_dist=init_dist,
385+
mu=0,
416386
sigma=drift_scale,
417387
nu=nu,
418388
length=i_len,
419389
student_t=student_t,
420390
initval=mu_guess[slc],
421391
)
422392
)
423-
mu_t = pm.Deterministic("mu_t", at.concatenate(mu_segments), dims="segment")
393+
mu_t = pm.Deterministic("mu_t", pt.concatenate(mu_segments), dims="segment")
424394
else:
425395
_log.info(
426396
"Creating model without switchpoints. StudentT=%b", len(t_switchpoints_known), student_t
427397
)
428398
mu_t = _make_random_walk(
429399
"mu_t",
430-
mu=mu_prior,
400+
init_dist=init_dist,
401+
mu=0,
431402
sigma=drift_scale,
432403
nu=nu,
433404
length=TS,
434405
student_t=student_t,
435406
initval=mu_guess,
407+
dims="segment",
436408
)
437409

438410
X0 = pm.LogNormal("X0", mu=numpy.log(x0_prior), sigma=1)
439411
Xt = pm.Deterministic(
440412
"X",
441-
at.concatenate([X0[None], X0 * pm.math.exp(at.extra_ops.cumsum(mu_t * dt))]),
413+
pt.concatenate([X0[None], X0 * pm.math.exp(pt.extra_ops.cumsum(mu_t * dt))]),
442414
dims="timepoint",
443415
)
444416
calibration_model.loglikelihood(
445417
x=Xt,
446-
y=pmData("backscatter", y, dims=("timepoint",)),
418+
y=pm.ConstantData("backscatter", y, dims=("timepoint",)),
447419
replicate_id=replicate_id,
448420
dependent_key=calibration_model.dependent_key,
421+
dims="timepoint",
449422
)
450423

451424
# MAP fit
@@ -454,31 +427,14 @@ def fit_mu_t(
454427

455428
# with StudentT random walks, switchpoints can be autodetected
456429
if student_t:
457-
# first CDF values at all mu_t elements
458-
cdf_evals = []
459-
for rvname in sorted(theta_map.keys()):
460-
if "__diff_" in rvname:
461-
rv = pmodel[rvname]
462-
# for every µ, find out where it lies in the CDF of the StudentT prior distribution
463-
cdf_evals += list(
464-
scipy.stats.t.cdf(
465-
x=theta_map[rvname],
466-
loc=rv.owner.inputs[3].eval(),
467-
scale=rv.owner.inputs[4].eval(),
468-
df=rv.owner.inputs[2].eval(),
469-
)
470-
)
471-
cdf_evals = numpy.array(cdf_evals)
472-
# filter for the elements that lie outside of the [0.005, 0.995] interval
473-
significance_mask = numpy.logical_or(
474-
cdf_evals < (switchpoint_prob / 2),
475-
cdf_evals > (1 - switchpoint_prob / 2),
430+
switchpoints_detected = detect_switchpoints(
431+
switchpoint_prob,
432+
t_data,
433+
pmodel,
434+
theta_map,
476435
)
477-
# add these autodetected timepoints to the switchpoints-dict
478-
# (ignore the first timepoint)
479-
for c_switch, (t_switch, is_switchpoint) in enumerate(zip(t_data, significance_mask[1:])):
480-
if is_switchpoint and c_switch not in c_switchpoints_known:
481-
switchpoints[t_switch] = "detected"
436+
# Known switchpoints override detected ones 👇
437+
switchpoints = {**switchpoints_detected, **switchpoints}
482438

483439
# bundle up all relevant variables into a result object
484440
result = GrowthRateResult(
@@ -500,3 +456,71 @@ def fit_mu_t(
500456
result.sample(draws=mcmc_samples)
501457

502458
return result
459+
460+
461+
def _natural_sort_key(name: str) -> typing.Tuple[str, int]:
    """Sort key that orders names with a trailing integer numerically.

    Plain string sorting puts ``"mu_phase_10"`` before ``"mu_phase_2"``.
    This key compares the non-numeric stem first and the trailing integer second.
    Names without a trailing integer sort before numbered names with the same stem.
    """
    stem = name.rstrip("0123456789")
    suffix = name[len(stem) :]
    return stem, int(suffix) if suffix else -1


def detect_switchpoints(
    switchpoint_prob: float,
    t_data: typing.Sequence[float],
    pmodel: pm.Model,
    theta_map: typing.Dict[str, numpy.ndarray],
) -> typing.Dict[float, str]:
    """Helper function to detect switchpoints from a fitted random walk.

    Parameters
    ----------
    switchpoint_prob
        Probability threshold for detecting switchpoints.
        Random walk innovations whose prior CDF value lies outside of the
        ``[switchpoint_prob / 2, 1 - switchpoint_prob / 2]`` interval
        will be classified as switchpoints.
    t_data
        Time values corresponding to the random walk steps.
    pmodel
        The PyMC model containing `"mu_t*"` random walks.
    theta_map
        MAP estimate of the model.

    Returns
    -------
    switchpoints
        Dictionary of switchpoints with
        keys being the time point and
        values `"detected"`.

    Raises
    ------
    Exception
        If the number of CDF values collected from the random walk variables
        does not equal ``len(t_data) - 1``.
    """
    # First collect CDF values at all mu_t elements.
    cdf_evals = []
    # The random walk may be split into multiple, consecutively numbered segments.
    # They must be visited in numeric order, otherwise (with more than 10 segments)
    # "mu_phase_10" would be processed before "mu_phase_2" and the CDF values
    # would no longer line up with the time points.
    for rvname in sorted(theta_map, key=_natural_sort_key):
        if rvname not in pmodel.named_vars:
            continue
        # We can identify a random walk segment from the RVOp type that created it.
        rv = pmodel[rvname]
        if rv.owner is None:
            continue
        if not isinstance(rv.owner.op, pm.RandomWalk.rv_type):
            continue
        # Get a handle on the innovation dist so we can evaluate prior CDFs.
        innov_dist = rv.owner.inputs[1]
        # Calculate the innovations from the MAP estimate of the points.
        # This gives only the deltas between the points, so the 0th element
        # in the new vector corresponds to the segment between the 0th and 1st point.
        innov = numpy.diff(theta_map[rvname])
        # Now we can evaluate the CDFs of the innovations.
        logcdfs = pm.logcdf(innov_dist, innov).eval()
        # We define switchpoints based on the time of the point with an extreme CDF value.
        # To get our <number of segments> length vector to align with the <number of points>,
        # we prepend a 0.5 as a placeholder for the CDF of the initial point of the random walk.
        cdf_evals.extend([0.5, *numpy.exp(logcdfs)])
    cdf_evals = numpy.array(cdf_evals)
    if len(cdf_evals) != len(t_data) - 1:
        raise Exception(
            f"Failed to find all random walk segments. Found {len(cdf_evals)}, expected {len(t_data) - 1}."
        )
    # Filter for the elements that lie outside of the [0.005, 0.995] interval (if switchpoint_prob=0.01).
    significance_mask = numpy.logical_or(
        cdf_evals < (switchpoint_prob / 2),
        cdf_evals > (1 - switchpoint_prob / 2),
    )
    # Collect switchpoint information from points with significant CDF values.
    # Here we don't need to filter known switchpoints, because these correspond to the first
    # point in each random walk, for which we assigned non-significant 0.5 CDF placeholders above.
    switchpoints = {t: "detected" for t, is_switchpoint in zip(t_data, significance_mask) if is_switchpoint}
    return switchpoints

0 commit comments

Comments
 (0)