Separate covariance computations for customizable posterior inference#2715
Conversation
8723467 to
3ce2bdf
Compare
Balandat
left a comment
There was a problem hiding this comment.
Thanks, this looks great to me at a high level. cc @kayween, @gpleiss, @jacobrgardner in case they have any comments since this is a pretty deep change
The docs build failure should be resolved once rebased on #2716
| # Concatenate the input to the training input | ||
| full_inputs = [] | ||
| batch_shape = train_inputs[0].shape[:-2] | ||
| for train_input, input in length_safe_zip(train_inputs, test_inputs): |
There was a problem hiding this comment.
Seems like we can replace length_safe_zip with zip(*, strict=True)?
There was a problem hiding this comment.
zip(..., strict=True) was added in Python 3.10 (PEP 618), which seems to be GPyTorch's Python version floor, so we can make this change. Since length_safe_zip is used throughout the code base, I'll put up a separate commit for this.
| if sum(test_train_covar.shape[-2:]) <= settings.max_eager_kernel_size.value(): | ||
| # If we are calling it here, it requires two calls to "to_dense", but if we | ||
| # called it in the test covariance getter, then it would break the other | ||
| # prediction strategies (actually it already, does because we are not |
There was a problem hiding this comment.
| # prediction strategies (actually it already, does because we are not | |
| # prediction strategies (actually it already does because we are not |
There was a problem hiding this comment.
should we be overwriting them?
There was a problem hiding this comment.
I think a better design for the prediction strategies is based on multiple-dispatch on the types of the covariance matrices. That way there wouldn't need to be a fall back if the types don't match, it would just call the correct implementation to begin with. With this in mind, I'd prefer to defer further changes to a future commit.
5ece9e9 to
c4ee5f5
Compare
|
@Balandat Thanks for the review, incorporated your suggestions. |
|
This refactor looks good to me. Besides unblocking what's in the PR description, this PR also does quite a bit of cleanup, which is great. cc @jacobrgardner @gpleiss in case they want to weigh in here. |
7b18a28 to
4de99c9
Compare
gpleiss
left a comment
There was a problem hiding this comment.
LGTM once the comments are addressed!
4de99c9 to
502156b
Compare
502156b to
c1c1a7b
Compare
This commit separates out the computation of the test-set mean and covariances
into
_get_test_prior_mean_and_covariancesand rewritesexact_predictiontoaccept
test_mean,test_test_covar,test_train_covaralready separated.This enables: