The problem
PyMC already has some time series capabilities, but these need to be expanded to cover Bayesian Structural Time Series (STS). We do have some good time series example notebooks:
- One is on a Prophet-like model. This approach is not too daunting for beginner/intermediate users, as it basically comes down to creating Fourier features as predictor variables. That said, it would be convenient if PyMC provided a utility function to do this.
- Another is on structural AR time series. While this notebook is excellent, it also shows that such models are non-trivial to implement, which is a barrier to entry for some users.
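As a rough illustration of the Fourier-feature step, here is a minimal NumPy sketch. `fourier_features` is a hypothetical helper, not an existing PyMC function; it assumes time indices `t`, a seasonal `period` in the same units as `t`, and `n_order` sine/cosine pairs, interleaved sin/cos per harmonic in the style of Prophet.

```python
import numpy as np


def fourier_features(t, period, n_order):
    """Hypothetical helper (not part of PyMC): Prophet-style Fourier features.

    t: 1-D array of time indices (e.g. day number).
    period: length of the seasonal cycle, in the same units as t.
    n_order: number of sine/cosine pairs (harmonics) to include.

    Returns an array of shape (len(t), 2 * n_order) with columns
    [sin(2*pi*1*t/P), cos(2*pi*1*t/P), ..., sin(2*pi*K*t/P), cos(2*pi*K*t/P)].
    """
    t = np.asarray(t, dtype=float)
    k = np.arange(1, n_order + 1)                   # harmonics 1..K
    angles = 2.0 * np.pi * np.outer(t, k) / period  # shape (len(t), K)
    features = np.empty((t.shape[0], 2 * n_order))
    features[:, 0::2] = np.sin(angles)              # even columns: sines
    features[:, 1::2] = np.cos(angles)              # odd columns: cosines
    return features
```

The returned matrix can then serve as a design matrix in a PyMC model, for example `seasonality = pm.math.dot(X, beta)` with `beta` a vector of Normal coefficients of length `2 * n_order`.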
What we want
If we look at the Bayesian Structural Time Series section of the excellent Bayesian Modeling and Computation in Python book (by @aloctavodia, @canyon289, and @junpenglao), we can see that it is trivial to build an STS model using tfp.sts:
```python
import tensorflow as tf
import tensorflow_probability as tfp


def generate_bsts_model(observed=None):
    """
    Args:
    observed: Observed time series; tfp.sts uses it to set data-informed default priors.
    """
    # Trend
    trend = tfp.sts.LocalLinearTrend(observed_time_series=observed)
    # Seasonal
    seasonal = tfp.sts.Seasonal(num_seasons=12, observed_time_series=observed)
    # Full model
    return tfp.sts.Sum([trend, seasonal], observed_time_series=observed)


# us_monthly_birth: DataFrame of monthly US births, loaded earlier in the book (not shown here)
observed = tf.constant(us_monthly_birth["birth_in_thousands"], dtype=tf.float32)
birth_model = generate_bsts_model(observed=observed)
# Build the (unnormalized) log posterior density over the model parameters,
# conditioned on the observed series
target_log_prob_fn = birth_model.joint_log_prob(observed_time_series=observed)
```
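To make concrete what the two tfp.sts components above encode, here is a NumPy sketch of the generative process of a local linear trend plus a seasonal component, using the standard textbook state-space formulation. This is an assumption for illustration only: tfp.sts's internal parameterization (priors, initial-state handling, seasonal constraints) may differ, and `simulate_llt_seasonal` and its parameters are hypothetical. A native PyMC implementation would need to express these same recursions.

```python
import numpy as np


def simulate_llt_seasonal(n_steps, num_seasons=12, level0=0.0, slope0=0.0,
                          level_scale=0.1, slope_scale=0.01,
                          season_scale=0.1, obs_scale=0.5, seed=0):
    """Hypothetical helper: simulate y_t = level_t + season_t + noise.

    Local linear trend (textbook form):
        level_t = level_{t-1} + slope_{t-1} + N(0, level_scale^2)
        slope_t = slope_{t-1} + N(0, slope_scale^2)
    Dummy seasonal (S consecutive effects sum to approximately zero):
        season_t = -sum_{j=1}^{S-1} season_{t-j} + N(0, season_scale^2)
    """
    rng = np.random.default_rng(seed)
    level = np.empty(n_steps)
    slope = np.empty(n_steps)
    season = np.empty(n_steps)
    # Initial seasonal pattern: S-1 free effects, the last chosen so all S sum to zero
    free = rng.normal(0.0, 1.0, size=num_seasons - 1)
    init_pattern = np.append(free, -free.sum())
    level[0], slope[0] = level0, slope0
    for t in range(n_steps):
        if t > 0:
            # Trend: level drifts by the previous slope; the slope is itself a random walk
            level[t] = level[t - 1] + slope[t - 1] + rng.normal(0.0, level_scale)
            slope[t] = slope[t - 1] + rng.normal(0.0, slope_scale)
        if t < num_seasons:
            season[t] = init_pattern[t]
        else:
            # Seasonal: the new effect cancels the previous S-1 effects, plus noise
            season[t] = -season[t - num_seasons + 1 : t].sum() + rng.normal(0.0, season_scale)
    # Observation noise on top of trend + seasonality
    y = level + season + rng.normal(0.0, obs_scale, size=n_steps)
    return y, level, slope, season
```

In a PyMC model these random-walk recursions would typically be written without a Python loop, for example as cumulative sums of Normal innovations or with `pm.GaussianRandomWalk`, so that the sampler sees a vectorized graph.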
In short, it would be excellent if we could expand PyMC's native time series capabilities in general, and for Bayesian STS specifically.
Note: There is already an STS implementation in JAX here. One approach would be to call that JAX code using the approach outlined in How to wrap a JAX function for use in PyMC. However, @ricardoV94 and @lucianopaz point out that this would restrict the available backends, whereas a native PyMC/Aesara implementation would keep the C, Python, and Numba backends available.
A list of useful functionality that we may want to enable is given in the tfp.sts documentation.