Skip to content

Conversation

@mathDR
Copy link
Contributor

@mathDR mathDR commented Nov 1, 2025

Checklist

  • I've formatted the new code by running uv run poe format before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

This PR does one thing: allows for the user to request to only return the diagonal elements of the covariance matrix of the test points.

This is done through the addition of a typing.Literal named return_cov_type that is passed into the respective predict functions for Prior, ConjugatePosterior and NonConjugatePrior. When the literal is "dense" then the full covariance is computed and returned. If the literal is "diagonal" then only the diagonal elements are computed and returned.

The implementation of this is through a jax.lax.cond which is the jax equivalent of an if-else capability. One drawback of this architecture is that both branches of the cond must return the same type of PyTree. Therefore in returning the covariance on the diagonal branch, we must cast the diagonal array to a Dense type LinearOperator.

In implementing this PR, some issues with the Diagonal functionality in kernels.computations were discovered and fixed.

This PR closes #389 and #566

@github-actions
Copy link

github-actions bot commented Nov 1, 2025

Thank you for opening your first PR into GPJax!

If you have not heard from us in a while, please feel free to ping
@gpjax/developers or anyone who has commented on the PR.
Most of our reviewers are volunteers and sometimes things fall
through the cracks.

You can also join us on
Slack
for real-time
discussion.

For details on testing, writing docs, and our review process,
please see the developer
guide

We strive to be a welcoming and open project. Please follow our
Code of
Conduct
.

@mathDR
Copy link
Contributor Author

mathDR commented Nov 2, 2025

Oh I suppose I should add a test checking that things error out appropriately when calling with an incorrect literal...

thoughts @thomaspinder ?

Copy link
Owner

@thomaspinder thomaspinder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some open comments that we should resolve before merging around scoping and naming. If any of my comments are unclear, then please just ask :)

@mathDR mathDR requested a review from thomaspinder November 2, 2025 18:19
@mathDR
Copy link
Contributor Author

mathDR commented Nov 3, 2025

Okay I think I fixed all of the docstring formatting issues.

@mathDR
Copy link
Contributor Author

mathDR commented Nov 4, 2025

Generically I am refactoring this and adding more tests. i see that in the scipy version, they pre-compute some Dense matrices to enable mean computations outside of the covariance computations.

In our case, can take advantage of computing the Cholesky of Kxx and the respective solves we need for those, so I hesitate to pull that piece out of each branch.

Also, we must be calling vmap on this predict function, because my tests are showing that both branches of the jax.lax.cond are being computed (which defeats the whole purpose! 🤣 ) So I will try to run that down...

@mathDR mathDR requested a review from thomaspinder November 5, 2025 15:02
@mathDR
Copy link
Contributor Author

mathDR commented Nov 7, 2025

@thomaspinder I think this is good to go. LMK if you have any more comments.

@thomaspinder
Copy link
Owner

Agreed. The changes are looking good to me now. Thanks for your effort here! Will merge once the tests have finished.

As an aside, it may be nice to publicly demonstrate this functionality in a notebook. Let me know if you want to do this, otherwise I can/will.

@mathDR
Copy link
Contributor Author

mathDR commented Nov 7, 2025

I can do this. Perhaps extending the regression notebook?

@thomaspinder
Copy link
Owner

I just opened a PR to your branch that fixes the graph kernel issue. If we first merge that, then the tests here should pass and we can merge into main.

@mathDR
Copy link
Contributor Author

mathDR commented Nov 9, 2025

I just opened a PR to your branch that fixes the graph kernel issue. If we first merge that, then the tests here should pass and we can merge into main.

weird that my main branch has all of the add_diag_cov branch changes. Let me rebase then I can review your PR.

disregard this is how forks work. 🤦

@thomaspinder thomaspinder merged commit bfc7488 into thomaspinder:main Nov 11, 2025
19 checks passed
@thomaspinder
Copy link
Owner

Thanks for the work here - pleased to get this one merged! I opened up a new issue for demonstrating this PR's usage into the notebooks. Let me know if you're still keen to pick it up, otherwise I will later this week.

#571

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

dev: Improve GP Prediction Efficiency

2 participants