-
-
Notifications
You must be signed in to change notification settings - Fork 72
Ability to return just diag on predict call
#567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thank you for opening your first PR into GPJax! If you have not heard from us in a while, please feel free to ping You can also join us on For details on testing, writing docs, and our review process, We strive to be a welcoming and open project. Please follow our |
|
Oh I suppose I should add a test checking that things error out appropriately when calling with an incorrect literal... thoughts @thomaspinder ? |
thomaspinder
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some open comments that we should resolve before merging around scoping and naming. If any of my comments are unclear, then please just ask :)
|
Okay I think I fixed all of the docstring formatting issues. |
|
Generically I am refactoring this and adding more tests. i see that in the scipy version, they pre-compute some In our case, can take advantage of computing the Cholesky of Also, we must be calling |
|
@thomaspinder I think this is good to go. LMK if you have any more comments. |
|
Agreed. The changes are looking good to me now. Thanks for your effort here! Will merge once the tests have finished. As an aside, it may be nice to publicly demonstrate this functionality in a notebook. Let me know if you want to do this, otherwise I can/will. |
|
I can do this. Perhaps extending the regression notebook? |
|
I just opened a PR to your branch that fixes the graph kernel issue. If we first merge that, then the tests here should pass and we can merge into main. |
disregard this is how forks work. 🤦 |
Thomaspinder/fix graph kernel
|
Thanks for the work here - pleased to get this one merged! I opened up a new issue for demonstrating this PR's usage into the notebooks. Let me know if you're still keen to pick it up, otherwise I will later this week. |
Checklist
uv run poe formatbefore committing.Description
This PR does one thing: allows for the user to request to only return the diagonal elements of the covariance matrix of the test points.
This is done through the addition of a
typing.Literalnamedreturn_cov_typethat is passed into the respectivepredictfunctions forPrior,ConjugatePosteriorandNonConjugatePrior. When the literal is"dense"then the full covariance is computed and returned. If the literal is"diagonal"then only the diagonal elements are computed and returned.The implementation of this is through a
jax.lax.condwhich is the jax equivalent of anif-elsecapability. One drawback of this architecture is that both branches of thecondmust return the same type ofPyTree. Therefore in returning the covariance on thediagonalbranch, we must cast the diagonal array to aDensetypeLinearOperator.In implementing this PR, some issues with the
Diagonalfunctionality inkernels.computationswere discovered and fixed.This PR closes #389 and #566