Description
Hi, thanks for creating this library!
I was going through the continual regression tutorial notebook and wanted to try out `diag_fisher`, expecting it to produce results similar to `diag_vi`. However, the predictive distribution did not look right except for the first episode. Here are my results for VI (by sampling parameters) and Fisher (linearized forward):

One thing I noticed is that for `diag_fisher`, the `prec_diag` of some parameters is close to zero, which I clipped to `1e-5` because otherwise it causes an error in `posteriors.laplace.diag_fisher.sample`. I also de-normalized `log_posterior` by multiplying it by the number of data points for the `diag_fisher` transform. Is this expected, or is there something wrong with my implementation?
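(As an aside, an alternative to the post-hoc clipping might be to seed the diagonal with the prior precision via `init_prec_diag` instead of `0.` in `train_for_la` below, assuming that argument is simply added to the Fisher accumulated by `transform.update`; I haven't verified that this is how it behaves. A rough sketch:

```python
# Rough sketch, not what I ran: seed the running diagonal with the prior precision
# so the accumulated Fisher never contains exact zeros. Assumes init_prec_diag is
# added to (not overwritten by) the accumulated Fisher, and that prior_sd is a
# scalar here; a tensor-tree prior_sd would need a tree_map instead.
transform = posteriors.laplace.diag_fisher.build(
    _seq_log_post,
    per_sample=False,
    init_prec_diag=1 / prior_sd**2,
)
```
)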
Here's my code which can be subbed in after the VI training block in the notebook:
```python
def train_for_la(dataloader, prior_mean, prior_sd, n_epochs=100, init_log_sds=None):
    seq_log_post = partial(log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)

    # compute map estimate
    opt = torch.optim.Adam(mlp.parameters())
    for _ in range(n_epochs):
        for batch in dataloader:
            opt.zero_grad()
            loss = -seq_log_post(dict(mlp.named_parameters()), batch)[0]
            loss.backward()
            opt.step()

    # update laplace state
    def _log_posterior(params, batch, prior_mean, prior_sd):
        """Non data normalized log post"""
        x, y = batch
        y_pred = mlp_functional(params, x)
        log_post = log_likelihood(y_pred, y) * samps_per_episode + log_prior(params, prior_mean, prior_sd)
        return log_post, y_pred

    _seq_log_post = partial(_log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)

    transform = posteriors.laplace.diag_fisher.build(
        _seq_log_post, per_sample=False, init_prec_diag=0.,
    )
    state = transform.init({k: v.data for k, v in dict(mlp.named_parameters()).items()})

    # accumulate the diagonal Fisher over this episode's data
    for batch in dataloader:
        state = transform.update(state, batch, inplace=False)

    # clip near-zero precisions, otherwise sampling errors out
    state.prec_diag = tree_map(lambda x: torch.clip(x, 1e-5, 1e5), state.prec_diag)
    return state


# train laplace
mlp.load_state_dict(trained_params[0])
la_states = []
for i in range(n_episodes):
    seq_prior_mean = prior_mean if i == 0 else tree_map(lambda x: x.clone(), la_states[i - 1].params)
    seq_prior_sd = prior_sd if i == 0 else tree_map(
        lambda prec: torch.sqrt(1 / prec.clone() + transition_sd ** 2), la_states[i - 1].prec_diag
    )
    state = train_for_la(
        dataloaders[i], seq_prior_mean, seq_prior_sd, n_epochs=100, init_log_sds=None
    )
    la_states += [copy.deepcopy(state)]
    mlp.load_state_dict(la_states[i].params)


# laplace forward
def to_sd_diag_la(state, temperature=1.0):
    return tree_map(lambda x: torch.sqrt(temperature / (x + 1e-3)), state.prec_diag)


def forward_linearized(model, state, batch, temperature=1.0):
    n_linearised_test_samples = 30
    x, _ = batch
    sd_diag = to_sd_diag_la(state, temperature)

    def model_func_with_aux(p, x):
        return torch.func.functional_call(model, p, x), torch.tensor([])

    lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
        model_func_with_aux,
        state.params,
        x,
        sd_diag,
    )
    samps = torch.randn(
        lin_mean.shape[0],
        n_linearised_test_samples,
        lin_mean.shape[1],
        device=lin_mean.device,
    )
    lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
    return lin_logits


# plot laplace
fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)
for i, ax in enumerate(axes):
    plot_data(ax, up_to_episode=i+1)

    with torch.no_grad():
        preds = forward_linearized(mlp, la_states[i], [plt_linsp.view(-1, 1), None]).squeeze(-1)

    # plot predictions
    sd = preds.std(1)
    preds = preds.mean(1)
    ax.plot(plt_linsp, preds, color='blue', alpha=1.)
    ax.fill_between(plt_linsp, preds - sd, preds + sd, color='blue', alpha=0.2)
    ax.set_title(f"After Episode {i+1}")

plt.suptitle("diag_fisher")
```
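For reference, a rough sketch of how the same state could be used for sampled-parameter predictions (analogous to the VI plots), assuming `posteriors.laplace.diag_fisher.sample(state)` draws a single parameter tree matching `state.params`:

```python
def forward_sampled(model, state, batch, n_samples=30):
    # Sketch only: draw parameter samples from the fitted diagonal Fisher state and
    # run the model on each draw, mirroring the VI "sampling parameters" predictions.
    x, _ = batch
    preds = []
    for _ in range(n_samples):
        sampled_params = posteriors.laplace.diag_fisher.sample(state)
        preds.append(torch.func.functional_call(model, sampled_params, x))
    return torch.stack(preds, dim=1)  # (batch, n_samples, out_dim)
```

Swapping this in for `forward_linearized` in the plotting loop would give the sampled-parameter version of the same figure.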