Hierarchical GAM demo #272

Closed
Fuhan-Yang wants to merge 9 commits into main from fy_gam_hi2

Contributor

@Fuhan-Yang Fuhan-Yang commented Feb 12, 2026

For vaccine uptake in a single season, the coefficients controlling each basis function are estimated given the design matrix and penalty matrix. The result is a vector of estimated coefficients $\beta$, along with the other estimates $\lambda$ and $\sigma$.

When vaccine uptake data come from multiple seasons and states, instead of estimating $\beta$ directly, the model estimates a deviation vector $\delta$ for a given state and season from the population mean of $\beta$. Each element of $\delta$ is the deviation from the population coefficient for the corresponding basis function.

scipy.interpolate only accepts data from a single season. Thus, when adding group factors, the design matrix and penalty matrix are calculated separately for each combination of season ($i$) and state ($j$), and are then used to estimate $\delta_{season=i, state=j}$.

There is an error caused by an incompatibility between the np.array from make_lsq_spline and the jnp.array used in numpyro. Fixing it requires refactoring the code so that the design matrix and penalty matrix are computed outside the numpyro model. I want to make sure we are on the same page before doing that!

@Fuhan-Yang Fuhan-Yang requested a review from swo February 16, 2026 19:38
Collaborator

@swo swo left a comment

I finally started to dig into the math and the code and have a lot of questions!

Comment thread docs/gam.md Outdated
k = m + p - 1
$$

The element in $X$ is the value of the basis function evaluated at the predictor elapsed, with rows are data point $x_i$ , and columns are basis function $B_k$.
Collaborator

I'm still unclear on what the basis functions are. Presumably there are many that you can pick from. What are their functional forms? Do you optimize over this selection?

Collaborator

It's also confusing that $k$ is used as a fixed constant (the length of $\beta$) but also as an index.

Collaborator

After some more reading, it seems like these are https://en.wikipedia.org/wiki/B-spline ?

@afmagee42

Yes, basis splines are the way to go when doing something like this. Bases are pre-computable, limiting the work you need to do in NumPyro
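
To make "pre-computable" concrete, here is a minimal NumPy sketch of the Cox-de Boor recursion that defines B-spline basis functions. The function name `bspline_design_matrix`, its arguments, and the clamped knot layout are assumptions for illustration, not code from this PR. Under this convention the number of columns is `len(interior_knots) + degree + 1`, which agrees with the doc's `k = m + p - 1` only if `m` also counts the two boundary knots.

```python
import numpy as np


def bspline_design_matrix(x, interior_knots, degree, lower, upper):
    """Evaluate every B-spline basis function at every x (Cox-de Boor recursion)."""
    x = np.asarray(x, dtype=float)
    # Clamped knot vector: each boundary knot is repeated degree + 1 times
    t = np.concatenate([
        np.repeat(float(lower), degree + 1),
        np.asarray(interior_knots, dtype=float),
        np.repeat(float(upper), degree + 1),
    ])
    # Degree-0 basis functions are indicators of the half-open spans [t_j, t_{j+1})
    B = np.zeros((x.size, t.size - 1))
    for j in range(t.size - 1):
        B[:, j] = (t[j] <= x) & (x < t[j + 1])
    # Assign x == upper to the last non-empty span so the right edge is covered
    B[x == upper, t.size - degree - 2] = 1.0
    # Raise the degree one step at a time; skip terms with a zero-width denominator
    for d in range(1, degree + 1):
        B_next = np.zeros((x.size, t.size - 1 - d))
        for j in range(t.size - 1 - d):
            left_den = t[j + d] - t[j]
            right_den = t[j + d + 1] - t[j + 1]
            if left_den > 0:
                B_next[:, j] += (x - t[j]) / left_den * B[:, j]
            if right_den > 0:
                B_next[:, j] += (t[j + d + 1] - x) / right_den * B[:, j + 1]
        B = B_next
    return B


# Rows are data points, columns are basis functions
elapsed = np.linspace(0.0, 1.0, 50)
X = bspline_design_matrix(elapsed, interior_knots=[0.25, 0.5, 0.75], degree=3,
                          lower=0.0, upper=1.0)
```

Because `X` depends only on the predictor and the knots, it can be built once per group with plain NumPy and handed to the model as a fixed array.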

Comment thread docs/gam.md
```math
\begin{align*}
p(\beta,\lambda,\sigma |y) & ∝ p(y |\beta,\sigma)p(\beta|\lambda, \sigma)p(\lambda)p(\sigma) \\
p(y|\beta, \sigma) & \sim MultiNormal(X\beta, \sigma I) \\
Collaborator

I think this is more easily written as

$$ p(y_i | \beta, \sigma) \sim \mathrm{Norm}((X \beta)_i, \sigma) $$
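
As a quick numerical check that the two ways of writing the likelihood agree, here is a small NumPy sketch. It assumes $\sigma$ is a standard deviation, so the multivariate form has covariance $\sigma^2 I$; the doc switches between $\sigma^2 I$ and $\sigma I$, so that reading is an assumption. The helper names and the toy data are hypothetical.

```python
import numpy as np


def mvn_logpdf_isotropic(y, mean, sigma):
    """log MultiNormal(y | mean, sigma^2 I), written out in matrix form."""
    n = y.size
    cov = sigma**2 * np.eye(n)
    resid = y - mean
    _, logdet = np.linalg.slogdet(cov)
    quad = resid @ np.linalg.solve(cov, resid)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)


def normal_logpdf_sum(y, mean, sigma):
    """Sum over i of log Norm(y_i | mean_i, sigma)."""
    resid = y - mean
    return np.sum(-0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * (resid / sigma) ** 2)


# Toy design matrix and coefficients (made up for this check)
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
beta = np.array([0.5, -1.0, 2.0])
sigma = 0.3
y = X @ beta + rng.normal(scale=sigma, size=8)

joint = mvn_logpdf_isotropic(y, X @ beta, sigma)
pointwise = normal_logpdf_sum(y, X @ beta, sigma)
```

The two values match up to floating-point error, so the element-wise form loses nothing.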

Comment thread docs/gam.md Outdated
\begin{align*}
p(\beta,\lambda,\sigma |y) & ∝ p(y |\beta,\sigma)p(\beta|\lambda, \sigma)p(\lambda)p(\sigma) \\
p(y|\beta, \sigma) & \sim MultiNormal(X\beta, \sigma I) \\
p(\beta|\lambda, \sigma) & \sim MultiNormal(0, (\sigma/\lambda)S^{-}) \\
Collaborator

Should this be $S^{-1}$?

Contributor Author

Yes, it should be $S^{-1}$. There is a paper that uses $S^-$ to denote the pseudoinverse when $S$ is singular (?), but generally it should be $S^{-1}$.

Contributor Author

Yes, the pseudoinverse is used when $S$ is not full rank or not square, which may not be an issue here, so we can use the inverse directly.
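
Whether $S$ is invertible is cheap to check numerically. The sketch below assumes $S$ is a second-order difference penalty $D^T D$, which is an assumption: the PR does not show how $S$ is built. Under that assumption $S$ has rank $K - 2$, because constant and linear coefficient vectors are not penalized, so the plain inverse does not exist and either a pseudoinverse or a small ridge is needed.

```python
import numpy as np

K = 7  # number of basis functions (illustrative)

# Second-difference operator: row i is e_i - 2 e_{i+1} + e_{i+2}
D = np.diff(np.eye(K), n=2, axis=0)  # shape (K - 2, K)
S = D.T @ D                          # hypothetical penalty matrix

rank_S = np.linalg.matrix_rank(S)

# Constant and linear coefficient vectors lie in the null space of S
const = np.ones(K)
linear = np.arange(K, dtype=float)

# The Moore-Penrose pseudoinverse exists even though S is singular
S_pinv = np.linalg.pinv(S)

# A small ridge makes the matrix full rank
rank_ridged = np.linalg.matrix_rank(S + 1e-6 * np.eye(K))
```

The ridge variant mirrors the `1e-6 * jnp.eye(p)` term in the script's precision matrix quoted later in this thread.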

Comment thread docs/gam.md Outdated
#### Deriving prior of $\beta$

As we assume the link function is identity, that indicates the data $y$ follows normal distribution with covariance matrix $\sigma^2I$.
As we assume the link function is identity, that indicates the data $y$ follows normal distribution with covariance matrix $\sigma I$.
Collaborator

I'm not sure I follow; you could have a different link function and still have the data be normally distributed.

Contributor Author

You are right. The identity function is the canonical link function for a normally distributed response variable. People typically use it to indicate a linear regression, where the error is normally distributed:

$$\begin{align*} y_i &\sim N(\beta_1 x_i + \beta_0, \sigma) \\ g^{-1}(E(y_i)) &= \beta_1 x_i + \beta_0 \\ \epsilon_i &\sim N(0, \sigma) \end{align*}$$

but other link functions can be used with the error still normally distributed, for example when $g^{-1} = \log(\cdot)$:

$$\begin{align*} \log(y_i) &\sim N(\beta_1 x_i + \beta_0, \sigma) \\ g^{-1}(E(y_i)) &= \beta_1 x_i + \beta_0 \\ \log(E(y_i)) &= \beta_1 x_i + \beta_0 \\ \epsilon_i &\sim N(0, \sigma) \end{align*}$$

Let me rephrase!

Comment thread docs/gam.md Outdated

The group factors are season ($s$) and geography ($g$). Their effects are introduced in $\beta$, to allow varying shape of the spline function adjusted by each $\beta_k$ to control the corresponding basis function $B_k$.

Given the population mean of $\beta$, denoted as $\bar \beta$, the $\beta$ specific to a certain season ($s=i$) and certain geography ($g=j$) is $\bar \beta$ plus vector $\delta_{s=i}$ and $\delta_{g=j}$. $\delta_{s=i}$ defines the deviation of the certain season from $\bar \beta$ and the certain geography from $\bar \beta$.
Collaborator

Usually when people use a single symbol like $\delta$ with a single subscript, it means that $\delta$ is a vector, and the value of the subscript determines which index of $\delta$ to look at.

So $\delta_s$ means "give me the $s$-th value of the vector $\delta$" and $\delta_g$ means "give me the $g$-th value of the vector $\delta$".

If you want to have different kinds of $\delta$'s, then you'll need a different kind of indexing. The simplest thing is to have two indices, and say $\delta_{0s}$ refers to the $s$-th seasonal deviation and $\delta_{1g}$ is the $g$-th geographical deviation.

Collaborator

I know what you mean, but it's very non-standard notation

Collaborator

I'm realizing that part of this is because people use lowercase letters to mean indices. It wouldn't be crazy to use uppercase, so that $\delta_{Ss} \sim \mathcal{N}(0, \sigma_S)$ and $\sigma_S \sim \mathrm{Exp}(40)$. It's not great, but neither is $\delta_{0s}$.

Contributor Author

This part has been deleted as it doesn't follow a hierarchical structure, per discussion.

Comment thread scripts/gam/gam_scipy.py Outdated
X: ArrayLike,
estimate: ArrayLike,
data: pl.DataFrame,
p: int = 2,
Collaborator

I'm realizing that it's confusing that "p" is probability, likelihood, and also degree of the spline

Comment thread scripts/gam/gam_scipy.py Outdated

# Penalized precision matrix, add 1e-6 to make sure stability
precision = (lam * S) + 1e-6 * jnp.eye(p)
if groups is None:
Collaborator

I think you can drop this. We'll always have groups, even if it's just season.

Contributor Author

The script has been deleted, per discussion

Comment thread scripts/gam/gam_scipy.py Outdated

z = numpyro.sample(
f"z_{idx}",
dist.MultivariateNormal(0, jnp.eye(k)),
Collaborator

You don't need multivariate normal here. This is a vector of values, so you can just use dist.Normal
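
A NumPy sketch of why independent normals are enough: `k` independent standard normal draws have the same distribution as one draw from MultivariateNormal(0, I_k), and a linear map then gives them whatever covariance is wanted. In NumPyro the analogous spelling would presumably be `dist.Normal(0, 1).expand([k]).to_event(1)`. That `z` is later mapped through a Cholesky factor of the precision is an assumption about the deleted script; the penalty `S`, `lam`, and `ridge` values below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
k = 5

# Hypothetical second-difference penalty; the PR does not show how S is built
D = np.diff(np.eye(k), n=2, axis=0)
S = D.T @ D
lam = 2.0
ridge = 1e-3  # illustrative; larger than the script's 1e-6 to keep this toy well conditioned
precision = lam * S + ridge * np.eye(k)

# k independent N(0, 1) draws: same distribution as MultivariateNormal(0, I_k)
z = rng.standard_normal(k)

# Non-centered transform: beta = L_p^{-T} z has covariance precision^{-1}
L_p = np.linalg.cholesky(precision)  # precision = L_p @ L_p.T
beta = np.linalg.solve(L_p.T, z)

# Covariance implied by the transform, for checking against the precision
A = np.linalg.inv(L_p.T)
implied_cov = A @ A.T
```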

Comment thread scripts/gam/gam_scipy.py Outdated

if __name__ == "__main__":
## model fitting ##
for idx in data["group_combo_idx"].unique():
Collaborator

I'm trying to follow the logic here:

  1. Pick common $\lambda$, $\sigma_\mathrm{season}$, $\sigma_\mathrm{geo}$, and $\sigma_Z$ to be used across all states/seasons
  2. For each state/season, fit a spline
  3. Adjust that spline based on the common values

This means that the only way the fit for one state/season "sees" the data from other seasons is via the shared parameters ($\lambda$, $\sigma_\mathrm{season}$, $\sigma_\mathrm{geo}$, and $\sigma_Z$). The deviations, design matrix, etc. are all within each state/season.

So I'm confused first about the merit of having a prior on $\lambda$. Are we asking the data to tell us what kind of penalty we should put on wiggliness?

Second, I'm confused how this is a hierarchical model, since we're not sharing information about the states or the seasons. I expect this means that the forecasts don't look very good?

Contributor Author

The script has been deleted, per discussion

Co-authored-by: Scott Olesen <ulp7@cdc.gov>

@afmagee42 afmagee42 left a comment

It could just be my relative inexperience with splines, or being out of date on the project at hand, but I am not sure I see how this model does what I thought we wanted it to do?

Comment thread docs/gam.md Outdated
```math

g^{-1}(E(y)) = X\beta + \beta_0
g^{-1}(E(y)) = X\beta

Link function framing feels to me like an easy way for us to either get trapped in frequentist thinking or backed into corners we don't want to be in.

Comment thread docs/gam.md Outdated
```

$y$ is a vector of observed vaccination coverage, $X$ is the design matrix of basis function with $N \times k$ dimension, where $N$ is the number of data points and $k$ is the number of basis functions used. The element in $X$ is the value of the basis function evaluated at the predictor elapsed, with rows are data point $x_i$ , and columns are basis function $B_k$.
$y$ is a vector of observed vaccination coverage, $X$ is the design matrix of basis function with $N \times k$ dimension, where $N$ is the number of data points and $k$ is the number of basis functions used. $k$ is defined by the order degree of spline function $p$ and the number of internal knots $m$ by:

Suggested change
$y$ is a vector of observed vaccination coverage, $X$ is the design matrix of basis function with $N \times k$ dimension, where $N$ is the number of data points and $k$ is the number of basis functions used. $k$ is defined by the order degree of spline function $p$ and the number of internal knots $m$ by:
$y$ is a vector of observed vaccination coverage, $X$ is the design matrix of basis function with $N \times k$ dimension, where $N$ is the number of data points and $k$ is the number of basis functions used. $K$ is defined by the order degree of spline function $p$ and the number of internal knots $m$ by:

I think this was just a typo and big K was intended, rather than little k?

I would also not call this a definition, it's a relationship. The definition comes when you choose which of $K$, $m$, and $p$ that you fix, and which is free.

Contributor Author

See updated!

Comment thread docs/gam.md Outdated
$y$ is a vector of observed vaccination coverage, $X$ is the design matrix of basis function with $N \times k$ dimension, where $N$ is the number of data points and $k$ is the number of basis functions used. $k$ is defined by the order degree of spline function $p$ and the number of internal knots $m$ by:

$$
k = m + p - 1

Suggested change
k = m + p - 1
K = m + p - 1

Same typo, I think?

Comment thread docs/gam.md
L(y|\beta)\cdot exp(-\lambda\beta^TS\beta/(2\sigma))
```

Using empirical Bayes approach, we can derive:

Empirical Bayes gets my hackles up

Contributor Author

It seems to me that the prior of $\beta$ is derived because the term looks like a prior in the equation, i.e., it sits in the position where a prior should be. That's my understanding of "empirical".
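
For reference, a sketch of the algebra behind reading the penalty term as a prior, in the doc's current notation (with $\sigma$ rather than $\sigma^2$ in the denominator, as in the quoted line). Up to a factor that does not depend on $\beta$:

$$
\exp\left(-\frac{\lambda \beta^T S \beta}{2\sigma}\right) = \exp\left(-\frac{1}{2} \beta^T \left[\frac{\lambda}{\sigma} S\right] \beta\right)
$$

This is the kernel of a zero-mean multivariate normal with precision $(\lambda/\sigma) S$, i.e. covariance $(\sigma/\lambda) S^{-}$, which matches the $p(\beta|\lambda, \sigma)$ line in the doc. If $S$ is rank-deficient, this density is improper along the null space of $S$, which is one reason the pseudoinverse notation appears.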

Comment thread docs/gam.md Outdated

```math
L(y|\beta)\cdot exp(-\lambda\beta^TS\beta/(2\sigma^2))
L(y|\beta)\cdot exp(-\lambda\beta^TS\beta/(2\sigma))

FWIW, in my estimation in so far as MCMV has coalesced on a style, we don't use L(params) for the likelihood, we write p(data | params)

Comment thread docs/gam.md
```


#### Deriving prior of $\beta$

Why are we doing this? This seems like an attempt to make a prior out of a frequentist penalty function, which does not work well. When doing maximum likelihood, all you need to penalize is a point. We need to penalize a distribution. If you could just apply a Bayesian prior matching a frequentist penalty function, "the Bayesian lasso" (aka regression with exponential priors) would provide good sparsity, but it doesn't.

Contributor Author

This is what Wood et al. did to derive the prior from the penalty function in Section 2.4. I'm new to this topic; can you explain why it doesn't work well? Is it that the frequentist penalty function penalizes a point while the Bayesian approach penalizes a distribution, so the connection between the two is questionable?

Comment thread docs/gam.md Outdated

```math
\begin{align*}
\beta_{total} &= \bar \beta + \delta_{s=i} + \delta_{g=j} \\

Are these scalars? Vectors?

What does this partial pooling structure imply about the functional forms per state-season? It is not clear to me that

  • There is any guarantee they have to look remotely similar
  • This looks anything like our general idea that functional forms are consistent give or take shifts to the start of uptake and peak uptake

Contributor Author

Per discussion, the functional forms are not shared, so this is effectively no pooling. This part has been deleted.

@afmagee42

If $f(t)$ is a spline function, and $u(t, s, y)$ is the uptake in state $s$ in year $y$, I would have expected a partially-pooled spline model to look, downstream of the spline-y bits and the priors on their coefficients, more like

$$u(t, s, y) := \theta_{s, y} \times f((t - \xi_{s, y}) / \zeta_{s, y})$$
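
A minimal NumPy sketch of that structure, with one shared curve and per-state-season scale, shift, and stretch. The piecewise-linear `f`, its knots, and the parameter values are made up for illustration; a real model would put priors on $\theta$, $\xi$, and $\zeta$ and on the spline coefficients.

```python
import numpy as np


def f(t):
    """Shared shape: a linear spline rising from 0 to 1 (hypothetical knots and values)."""
    knots = np.array([0.0, 0.2, 0.5, 1.0])
    values = np.array([0.0, 0.1, 0.7, 1.0])
    # np.interp holds the end values outside [0, 1]: 0 before the start, 1 after the end
    return np.interp(t, knots, values)


def uptake(t, theta, xi, zeta):
    """u(t, s, y) = theta * f((t - xi) / zeta) for one state-season."""
    return theta * f((np.asarray(t, dtype=float) - xi) / zeta)


t = np.linspace(0.0, 1.0, 11)
u_a = uptake(t, theta=0.6, xi=0.0, zeta=1.0)  # reference state-season
u_b = uptake(t, theta=0.4, xi=0.1, zeta=0.8)  # later start, faster rise, lower peak
```

Every state-season curve here is the same shape up to three numbers, which is the kind of similarity the $\bar \beta + \delta$ structure does not guarantee.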

@Fuhan-Yang
Contributor Author

Fuhan-Yang commented Feb 20, 2026

Per discussion, the scripts are not truly hierarchical, and it would take a lot of effort to build an HGAM in numpyro. The scripts have been deleted. The doc has been rephrased per the comments, and the part of the document about hierarchy has been deleted.

@Fuhan-Yang Fuhan-Yang requested review from afmagee42 and swo February 20, 2026 04:14
@Fuhan-Yang Fuhan-Yang mentioned this pull request Feb 20, 2026
@Fuhan-Yang
Contributor Author

Close per discussion

@Fuhan-Yang Fuhan-Yang closed this Feb 20, 2026
@swo swo deleted the fy_gam_hi2 branch March 23, 2026 15:38

4 participants