
Commit bfc7488

mathDR and thomaspinder authored
Ability to return just diag on predict call (#567)
* updated linting
* linting
* reverted computations base
* fixed docstrings
* ran poe format
* changed Tuple to tuple
* changed return_cov_type to return_covariance_type
* rescoped prior lax cond and renamed functions
* revert typing Literal as beartype doesnt support it
* ran poe format
* fixed name of return function in nonconjugate posterior
* reformatted docstring args
* fixed linting error
* refactoring
* refactored and added tests
* Fix graph kernel shapes
* Fix graph kernel shapes
* Tidy typing
* Linting
* using the dense and diagonal attributes in the regression example

---------

Co-authored-by: Thomas Pinder <[email protected]>
1 parent b74be38 commit bfc7488

File tree

6 files changed, +451 -86 lines changed


examples/regression.py

Lines changed: 13 additions & 4 deletions
@@ -7,7 +7,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.11.2
+# jupytext_version: 1.17.3
 # kernelspec:
 # display_name: .venv
 # language: python
@@ -121,10 +121,15 @@
 # we have just defined can be represented by a
 # [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax)
 # multivariate Gaussian distribution. Such functionality enables trivial sampling, and
-# the evaluation of the GP's mean and covariance .
+# the evaluation of the GP's mean and covariance.
+#
+# Since we want to sample from the full posterior, we need to calculate the full covariance matrix.
+# We can enforce this by passing `return_covariance_type="dense"` to the `predict` call.
+# Note that this is the default when the argument is omitted.
 
 # %%
-prior_dist = prior.predict(xtest)
+# %% [markdown]
+prior_dist = prior.predict(xtest, return_covariance_type="dense")
 
 prior_mean = prior_dist.mean
 prior_std = prior_dist.variance
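
To make the dense-covariance path above concrete, here is a minimal sketch of how the updated call might be used. It assumes `prior` and `xtest` are the objects built in the earlier cells of examples/regression.py (their construction is not part of this diff), and that the returned distribution follows the TensorFlow Probability sampling interface described in the markdown above.

# Illustrative sketch only, not part of this commit: `prior` and `xtest` are
# assumed to have been defined earlier in examples/regression.py.
import jax.random as jr

key = jr.PRNGKey(123)  # hypothetical seed, used only for the sampling line below

# A dense (full) covariance is needed to draw samples that are correlated across test points.
prior_dist = prior.predict(xtest, return_covariance_type="dense")
prior_mean = prior_dist.mean
prior_std = prior_dist.variance

# Sampling assumes the TFP-style interface referenced in the markdown above.
samples = prior_dist.sample(seed=key, sample_shape=(20,))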
@@ -212,9 +217,13 @@
 # this, we use our defined `posterior` and `likelihood` at our test inputs to obtain
 # the predictive distribution as a `Distrax` multivariate Gaussian upon which `mean`
 # and `stddev` can be used to extract the predictive mean and standard deviation.
+#
+# Here we only need the variance of each test point with itself, so we can compute
+# just the diagonal of the covariance. We enforce this by passing
+# `return_covariance_type="diagonal"` to the `predict` call.
 
 # %%
-latent_dist = opt_posterior.predict(xtest, train_data=D)
+latent_dist = opt_posterior.predict(xtest, train_data=D, return_covariance_type="diagonal")
 predictive_dist = opt_posterior.likelihood(latent_dist)
 
 predictive_mean = predictive_dist.mean
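
Similarly, a hedged sketch of the diagonal path shown in this hunk, again assuming `opt_posterior`, `D`, and `xtest` come from the earlier cells of the example and are not shown in this diff.

# Illustrative sketch only: with the diagonal option, only the marginal variances at the
# test points are computed (a length-N vector rather than an N x N matrix), which is
# sufficient for pointwise predictive means and standard deviations.
latent_dist = opt_posterior.predict(xtest, train_data=D, return_covariance_type="diagonal")
predictive_dist = opt_posterior.likelihood(latent_dist)

predictive_mean = predictive_dist.mean
predictive_std = predictive_dist.stddev  # per-point uncertainty, no dense matrix required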

0 commit comments
