[Bug]: Improve scaling of posterior computation when Standardize is used #3159

@DonneyF

Description

What happened?

Calling model.posterior(X) for a large X can be slow because untransforming the posterior requires computing covar_tf = scale_mat @ lcv @ scale_mat. Speeding this up would benefit some posterior sampling schemes (e.g. Thompson sampling over a large set of candidate points). I see a wall-clock difference of roughly 10x on my local machine (14900K, RTX 4090), though in absolute terms it is only on the order of seconds for 8k test points.
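
For context, the slow step is the congruence of the lazy posterior covariance with a diagonal scale matrix. Below is a minimal sketch contrasting that path with a plain scalar rescale, assuming a single-outcome Standardize transform; the names scale_mat and lcv mirror the ones above rather than BoTorch's exact internals, and the sizes/values are made up for illustration:

import torch
import linear_operator
from linear_operator.operators import DiagLinearOperator

n, stdv = 2000, 2.5  # hypothetical number of test points and outcome stdv
A = torch.randn(n, n, dtype=torch.float64)
lcv = linear_operator.to_linear_operator(A @ A.mT)  # stand-in for the lazy posterior covariance

# Diagonal congruence, as in the current untransform path
scale_mat = DiagLinearOperator(torch.full((n,), stdv, dtype=torch.float64))
covar_tf = (scale_mat @ lcv @ scale_mat).to_dense()

# Scalar shortcut for the single-outcome case: rescale by stdv**2
covar_scalar = (lcv * stdv**2).to_dense()

assert torch.allclose(covar_tf, covar_scalar)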

Please provide a minimal, reproducible example of the unexpected behavior.

import time
import torch
from gpytorch.mlls import ExactMarginalLogLikelihood
import linear_operator
from botorch.models import SingleTaskGP
from botorch.utils.transforms import standardize
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float64

torch.manual_seed(0)
torch.set_default_dtype(dtype)

def sync_and_perf_count():
    # Synchronize the GPU (if present) so timings reflect completed work
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()

init1 = sync_and_perf_count()
d = 300
N = 1000
X = torch.rand(N, d, dtype=dtype, device=device)
Y = standardize(torch.randn(N, 1, dtype=dtype, device=device))
model = SingleTaskGP(train_X=X, train_Y=Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
model.eval()
init2 = sync_and_perf_count()

N_cand = 8000
X_cand = torch.randn(N_cand, d, dtype=dtype, device=device)

torch.cuda.empty_cache()

# Via SingleTaskGP API
t1 = sync_and_perf_count()
posterior = model.posterior(X_cand)
covar_mat = posterior.covariance_matrix # Calls to_dense()
t2 = sync_and_perf_count()

torch.cuda.empty_cache()

# Manual Computation
t3 = sync_and_perf_count()
sigma_sq = model.likelihood.noise_covar.noise.unsqueeze(-1)
eye = torch.eye(N, device=device, dtype=dtype)
K_test_test = model.covar_module(X_cand, X_cand)
K_obs_test = model.covar_module(X, X_cand).to_dense()
K_obs_obs = model.covar_module(X, X)

K_obs_obs_noise = linear_operator.to_linear_operator(
    K_obs_obs + sigma_sq * eye
)
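# Predictive covariance: K_** - K_*X (K_XX + sigma^2 I)^{-1} K_X*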
K_posterior = K_test_test - K_obs_test.mT @ K_obs_obs_noise.solve(K_obs_test)
if hasattr(model, 'outcome_transform'):
    K_posterior = K_posterior * model.outcome_transform.stdvs.pow(2)
K_posterior = K_posterior.to_dense()
t4 = sync_and_perf_count()

assert torch.allclose(K_posterior, covar_mat)
print(f"SingleTaskGP: {t2 - t1:.2f}", f"Manual: {t4 - t3:.2f}")

Please paste any relevant traceback/logs produced by the example provided.

SingleTaskGP: 2.33 Manual: 0.18

BoTorch Version

0.16.1

Python Version

3.13

Operating System

Ubuntu 22.04

(Optional) Describe any potential fixes you've considered to the issue outlined above.

I would propose adding a fast path in untransform_posterior that detects the scalar (single-outcome) case and avoids constructing a DiagLinearOperator.
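
As a rough sketch of what that fast path might look like (the helper name and structure here are hypothetical, not BoTorch's actual implementation), assuming lcv is the lazy posterior covariance in the standardized space and stdvs holds one entry per outcome:

import torch
from linear_operator.operators import DiagLinearOperator

def _rescale_posterior_covar(lcv, stdvs: torch.Tensor):
    # Hypothetical helper illustrating the proposed fast path.
    if stdvs.numel() == 1:
        # Single outcome: the congruence S @ Cov @ S collapses to a scalar multiple,
        # so no DiagLinearOperator has to be built or multiplied through.
        return lcv * stdvs.reshape(()).pow(2)
    # Multi-output case: keep the existing diagonal congruence (shown here in
    # simplified form; the real scale matrix is laid out per test point and outcome).
    scale_mat = DiagLinearOperator(stdvs.reshape(-1))
    return scale_mat @ lcv @ scale_mat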

Pull Request

None

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct
