
Continual regression tutorial question: Laplace #107

Open
@ran-weii

Description

Hi, thanks for creating this library!

I was going through the continual regression tutorial notebook and wanted to try out diag_fisher, expecting it to produce results similar to diag_vi. However, the predictive distribution did not look right except for the first episode. Here are my results for VI (sampling parameters) and Fisher (linearized forward):

[figure: diag_vi predictive distributions per episode]

[figure: diag_fisher predictive distributions per episode]

One thing I noticed is that for diag_fisher, the prec_diag of some parameters is close to zero; I clipped these to 1e-5 because otherwise they cause an error in posteriors.laplace.diag_fisher.sample. I also de-normalized the log posterior for the diag_fisher transform by multiplying the log likelihood by the number of data points, since the notebook's log_posterior averages over data points. I was wondering whether this is expected, or whether there's something wrong with my implementation?
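For context, my rough mental model (which may be where I'm going wrong) is that the diagonal empirical Fisher accumulates squared per-sample gradients of the supplied log posterior, so a parameter whose gradients are near zero at the MAP ends up with near-zero precision. A conceptual sketch of that understanding, not the actual posteriors implementation:

import torch
from torch.func import grad

# Sketch only: precision as summed squared per-sample gradients of the
# (unnormalized) log posterior; near-zero gradients at the MAP give
# near-zero precisions, which matches what I see in prec_diag.
def fisher_diag_sketch(log_post_fn, params, dataloader):
    prec_diag = {k: torch.zeros_like(v) for k, v in params.items()}
    for batch in dataloader:
        x, y = batch
        for xi, yi in zip(x, y):
            g = grad(lambda p: log_post_fn(p, (xi[None], yi[None]))[0])(params)
            prec_diag = {k: prec_diag[k] + g[k] ** 2 for k in prec_diag}
    return prec_diag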

Here's my code, which can be subbed in after the VI training block in the notebook:

def train_for_la(dataloader, prior_mean, prior_sd, n_epochs=100, init_log_sds=None):
    # init_log_sds is unused here; kept only for signature parity with the VI trainer
    seq_log_post = partial(log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)
    
    # compute map estimate
    opt = torch.optim.Adam(mlp.parameters())
    for _ in range(n_epochs):
        for batch in dataloader:
            opt.zero_grad()
            loss = -seq_log_post(dict(mlp.named_parameters()), batch)[0]
            loss.backward()
            opt.step()
    
    # update laplace state
    def _log_posterior(params, batch, prior_mean, prior_sd):
        """Log posterior without per-datapoint normalization (log likelihood scaled by episode size)."""
        x, y = batch
        y_pred = mlp_functional(params, x)
        log_post = log_likelihood(y_pred, y) * samps_per_episode + log_prior(params, prior_mean, prior_sd)
        return log_post, y_pred
    _seq_log_post = partial(_log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)

    transform = posteriors.laplace.diag_fisher.build(
        _seq_log_post, per_sample=False, init_prec_diag=0.,
    )
    state = transform.init({k: v.data for k, v in dict(mlp.named_parameters()).items()})
    for batch in dataloader:
        state = transform.update(state, batch, inplace=False)
    # clip near-zero precisions, otherwise posteriors.laplace.diag_fisher.sample errors
    state.prec_diag = tree_map(lambda x: torch.clip(x, 1e-5, 1e5), state.prec_diag)
    return state
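# Hypothetical diagnostic (not in the notebook): called on prec_diag before
# the clip in train_for_la, this reports what fraction of precision entries
# sit at or below the 1e-5 floor, i.e. how degenerate the fitted Fisher is.
def frac_near_zero(prec_diag, eps=1e-5):
    flat = torch.cat([v.flatten() for v in prec_diag.values()])
    return (flat <= eps).float().mean().item()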

# train laplace
mlp.load_state_dict(trained_params[0])
la_states = []
for i in range(n_episodes):
    # next episode's prior: previous MAP as mean, sd from var = 1/prec + transition_sd^2
    seq_prior_mean = prior_mean if i == 0 else tree_map(lambda x: x.clone(), la_states[i - 1].params)
    seq_prior_sd = prior_sd if i == 0 else tree_map(
        lambda prec: torch.sqrt(1 / prec.clone() + transition_sd ** 2), la_states[i - 1].prec_diag
    )
    
    state = train_for_la(
        dataloaders[i], seq_prior_mean, seq_prior_sd, n_epochs=100, init_log_sds=None
    )

    la_states += [copy.deepcopy(state)]
    mlp.load_state_dict(la_states[i].params)

# laplace forward
def to_sd_diag_la(state, temperature=1.0):
    # precision -> sd, with a small jitter for numerical stability
    return tree_map(lambda x: torch.sqrt(temperature / (x + 1e-3)), state.prec_diag)

def forward_linearized(model, state, batch, temperature=1.0):
    n_linearised_test_samples = 30
    x, _ = batch
    sd_diag = to_sd_diag_la(state, temperature)

    def model_func_with_aux(p, x):
        # linearized_forward_diag expects the model function to return (output, aux)
        return torch.func.functional_call(model, p, x), torch.tensor([])

    lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
        model_func_with_aux,
        state.params,
        x,
        sd_diag,
    )

    samps = torch.randn(
        lin_mean.shape[0],
        n_linearised_test_samples,
        lin_mean.shape[1],
        device=lin_mean.device,
    )
    # Gaussian samples in output space: mean + eps @ chol^T
    lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
    return lin_logits
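# Apples-to-apples check against the VI plots (which sample parameters rather
# than linearize): a sketch of a sampled-parameter forward, assuming
# posteriors.laplace.diag_fisher.sample(state) returns a single parameter
# tree shaped like state.params.
def forward_sampled(model, state, batch, n_samples=30):
    x, _ = batch
    preds = []
    for _ in range(n_samples):
        p = posteriors.laplace.diag_fisher.sample(state)
        preds.append(torch.func.functional_call(model, p, x))
    return torch.stack(preds, dim=1)  # (batch, n_samples, out_dim)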

# plot laplace
fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)
for i, ax in enumerate(axes):
    plot_data(ax, up_to_episode=i+1)
    with torch.no_grad():
        preds = forward_linearized(mlp, la_states[i], [plt_linsp.view(-1, 1), None]).squeeze(-1)
    
    # plot predictions
    sd = preds.std(1)
    preds = preds.mean(1)
    ax.plot(plt_linsp, preds, color='blue', alpha=1.)
    ax.fill_between(plt_linsp, preds - sd, preds + sd, color='blue', alpha=0.2)
    ax.set_title(f"After Episode {i+1}")
plt.suptitle("diag_fisher")


Labels: help wanted (Extra attention is needed)
