Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 130 additions & 91 deletions iup/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,45 +100,41 @@ def __init__(self, seed: int):
"""
self.rng_key = random.key(seed)
self.fit_key, self.pred_key = random.split(self.rng_key, 2)
self.model = LPLModel._logistic_plus_linear

@staticmethod
def _logistic_plus_linear(
def _logistic_plus_linear_groups(
elapsed,
N_vax=None,
N_tot=None,
groups=None,
num_group_factors=0,
num_group_levels=[0],
A_shape1=100.0,
A_shape2=180.0,
A_sig=40.0,
H_shape1=100.0,
H_shape2=225.0,
n_shape=25.0,
n_rate=1.0,
M_shape=1.0,
M_rate=10.0,
M_sig=40.0,
d_shape=350.0,
d_rate=1.0,
N_vax,
N_tot,
data_level_matrix: np.ndarray,
level_factor_matrix: np.ndarray,
A_shape1,
A_shape2,
A_sig,
H_shape1,
H_shape2,
n_shape,
n_rate,
M_shape,
M_rate,
M_sig,
d_shape,
d_rate,
):
"""
Fit a mixed Logistic Plus Linear model on training data.

Parameters
elapsed: np.array
fraction of a year elapsed since the start of season at each data point
N_vax: np.array | None
N_vax: np.array
number of people vaccinated at each data point
N_tot: np.array | None
N_tot: np.array
number of people contacted at each data point
groups: np.array | None
numeric codes for groups: row = data point, col = grouping factor
num_group_factors: Int
number of grouping factors
num_group_levels: List[Int,]
number of unique levels of each grouping factor
data_level_matrix:
see iup.utils.get_design_matrices()
level_factor_matrix:
see iup.utils.get_design_matrices()
other parameters: float
parameters to specify the prior distributions

Expand All @@ -154,31 +150,73 @@ def _logistic_plus_linear(
n = numpyro.sample("n", dist.Gamma(n_shape, n_rate))
M = numpyro.sample("M", dist.Gamma(M_shape, M_rate))
d = numpyro.sample("d", dist.Gamma(d_shape, d_rate))
# If grouping factors are given, find the group-specific deviations for each datum
if groups is not None:
A_sigs = numpyro.sample(
"A_sigs", dist.Exponential(A_sig), sample_shape=(num_group_factors,)
)
M_sigs = numpyro.sample(
"M_sigs", dist.Exponential(M_sig), sample_shape=(num_group_factors,)
)
A_devs = numpyro.sample(
"A_devs", dist.Normal(0, 1), sample_shape=(sum(num_group_levels),)
) * np.repeat(A_sigs, np.array(num_group_levels))
M_devs = numpyro.sample(
"M_devs", dist.Normal(0, 1), sample_shape=(sum(num_group_levels),)
) * np.repeat(M_sigs, np.array(num_group_levels))
A_tot = np.sum(A_devs[groups], axis=1) + A
M_tot = np.sum(M_devs[groups], axis=1) + M
# Calculate latent true uptake at each datum
mu = A_tot / (1 + jnp.exp(0 - n * (elapsed - H))) + (M_tot * elapsed)
else:
# Calculate latent true uptake at each datum if no grouping factors
mu = A / (1 + jnp.exp(0 - n * (elapsed - H))) + (M * elapsed)

_, n_levels = data_level_matrix.shape
_, n_factors = level_factor_matrix.shape

A_sigs = numpyro.sample(
"A_sigs", dist.Exponential(A_sig), sample_shape=(n_factors,)
)
M_sigs = numpyro.sample(
"M_sigs", dist.Exponential(M_sig), sample_shape=(n_factors,)
)

A_zs = numpyro.sample("A_zs", dist.Normal(0, 1), sample_shape=(n_levels,))
M_zs = numpyro.sample("M_zs", dist.Normal(0, 1), sample_shape=(n_levels,))

A_devs = numpyro.deterministic(
"A_devs",
jnp.matmul(level_factor_matrix, A_sigs) * A_zs, # type: ignore
)
M_devs = numpyro.deterministic(
"M_devs",
jnp.matmul(level_factor_matrix, M_sigs) * M_zs, # type: ignore
)

A_tot = A + jnp.matmul(data_level_matrix, A_devs)
M_tot = M + jnp.matmul(data_level_matrix, M_devs)

# Calculate latent true uptake at each datum
mu = numpyro.deterministic(
"mu", A_tot / (1 + jnp.exp(0 - n * (elapsed - H))) + (M_tot * elapsed)
)

# Calculate the shape parameters for the beta-binomial likelihood
S1 = mu * d
S2 = (1 - mu) * d
numpyro.sample("obs", dist.BetaBinomial(S1, S2, N_tot), obs=N_vax) # type: ignore

@staticmethod
def _logistic_plus_linear_no_groups(
    elapsed,
    N_vax,
    N_tot,
    A_shape1,
    A_shape2,
    H_shape1,
    H_shape2,
    n_shape,
    n_rate,
    M_shape,
    M_rate,
    d_shape,
    d_rate,
):
    """
    Logistic Plus Linear model with no grouping factors.

    Parameters
    elapsed: np.array
        fraction of a year elapsed since the start of season at each data point
    N_vax: np.array
        number of people vaccinated at each data point (used as the
        observed values of the "obs" site)
    N_tot: np.array
        number of people contacted at each data point
    other parameters: float
        parameters to specify the prior distributions:
        Beta(A_shape1, A_shape2) for A, Beta(H_shape1, H_shape2) for H,
        Gamma(n_shape, n_rate) for n, Gamma(M_shape, M_rate) for M,
        and Gamma(d_shape, d_rate) for d

    Sample sites registered: "A", "H", "n", "M", "d", the deterministic
    "mu", and the observed "obs".
    """
    # Sample the overall average value for each parameter
    A = numpyro.sample("A", dist.Beta(A_shape1, A_shape2))
    H = numpyro.sample("H", dist.Beta(H_shape1, H_shape2))
    n = numpyro.sample("n", dist.Gamma(n_shape, n_rate))
    M = numpyro.sample("M", dist.Gamma(M_shape, M_rate))
    d = numpyro.sample("d", dist.Gamma(d_shape, d_rate))

    # Latent uptake at each datum: a logistic term plus a linear term in
    # elapsed time. There are no group-specific deviations in this model.
    mu = numpyro.deterministic(
        "mu", A / (1 + jnp.exp(0 - n * (elapsed - H))) + (M * elapsed)
    )

    # Calculate the shape parameters for the beta-binomial likelihood
    S1 = mu * d
    S2 = (1 - mu) * d

    # Register the likelihood exactly once: a second site named "obs" in the
    # same model would collide with this one.
    numpyro.sample("obs", dist.BetaBinomial(S1, S2, N_tot), obs=N_vax)  # type: ignore

@staticmethod
def augment_data(
Expand Down Expand Up @@ -251,54 +289,57 @@ def fit(
"""
self.group_combos = extract_group_combos(data, groups)

# prepare common run arguments for grouped and ungrouped models
run_kwargs = {
"elapsed": data["elapsed"].to_numpy(),
"N_vax": data["N_vax"].to_numpy(),
"N_tot": data["N_tot"].to_numpy(),
"A_shape1": params["A_shape1"],
"A_shape2": params["A_shape2"],
"A_sig": params["A_sig"],
"H_shape1": params["H_shape1"],
"H_shape2": params["H_shape2"],
"n_shape": params["n_shape"],
"n_rate": params["n_rate"],
"M_shape": params["M_shape"],
"M_rate": params["M_rate"],
"M_sig": params["M_sig"],
"d_shape": params["d_shape"],
"d_rate": params["d_rate"],
}

# Transform the levels of the grouping factors into numeric codes
if groups is not None:
self.num_group_factors = len(groups)
self.num_group_levels = iup.utils.count_unique_values(self.group_combos)
self.value_to_index = iup.utils.map_value_to_index(data.select(groups))
group_codes = iup.utils.value_to_index(
data.select(groups), self.value_to_index, self.num_group_levels
self.level_to_index = iup.utils.map_level_to_index(data.select(groups))
data_level_matrix, level_factor_matrix = iup.utils.get_design_matrices(
data.select(groups), self.level_to_index
)

run_kwargs |= {
"data_level_matrix": data_level_matrix,
"level_factor_matrix": level_factor_matrix,
"A_sig": params["A_sig"],
"M_sig": params["M_sig"],
}

model = self._logistic_plus_linear_groups

self.kernel = NUTS(
self._logistic_plus_linear_groups, init_strategy=init_to_sample
)
else:
group_codes = None
self.num_group_factors = 0
self.num_group_levels = [0]
self.value_to_index = None
model = self._logistic_plus_linear_no_groups

# Prepare the data to be fed to the model. Must be numpy arrays.
elapsed = data["elapsed"].to_numpy()
N_vax = data["N_vax"].to_numpy()
N_tot = data["N_tot"].to_numpy()
self.kernel = NUTS(model, init_strategy=init_to_sample)

self.kernel = NUTS(self.model, init_strategy=init_to_sample)
self.mcmc = MCMC(
self.kernel,
num_warmup=mcmc["num_warmup"],
num_samples=mcmc["num_samples"],
num_chains=mcmc["num_chains"],
)

self.mcmc.run(
self.fit_key,
elapsed=elapsed,
N_vax=N_vax,
N_tot=N_tot,
groups=group_codes,
num_group_factors=self.num_group_factors,
num_group_levels=self.num_group_levels,
A_shape1=params["A_shape1"],
A_shape2=params["A_shape2"],
A_sig=params["A_sig"],
H_shape1=params["H_shape1"],
H_shape2=params["H_shape2"],
n_shape=params["n_shape"],
n_rate=params["n_rate"],
M_shape=params["M_shape"],
M_rate=params["M_rate"],
M_sig=params["M_sig"],
d_shape=params["d_shape"],
d_rate=params["d_rate"],
)
self.mcmc.run(self.fit_key, **run_kwargs)

print(self.mcmc.print_summary())

Expand Down Expand Up @@ -393,10 +434,9 @@ def predict(
predictive = Predictive(self.model, self.mcmc.get_samples())

if groups is not None:
# Make a numpy array of numeric codes for grouping factor levels
# that matches the same codes used when fitting the model
group_codes = iup.utils.value_to_index(
scaffold.select(groups), self.value_to_index, self.num_group_levels
assert self.level_to_index is not None
data_level_matrix, level_factor_matrix = iup.utils.get_design_matrices(
scaffold.select(groups), self.level_to_index
)

# Make a prediction-machine from the fit model
Expand All @@ -405,9 +445,8 @@ def predict(
self.pred_key,
elapsed=scaffold["elapsed"].to_numpy(),
N_tot=scaffold["N_tot"].to_numpy(),
groups=group_codes,
num_group_factors=self.num_group_factors,
num_group_levels=self.num_group_levels,
data_level_matrix=data_level_matrix,
level_factor_matrix=level_factor_matrix,
)["obs"]
).transpose()
else:
Expand Down
Loading
Loading