Implement unconstraining transform for LKJCorr #7380

Conversation
Hi, it's unlikely I'm going to have any time to work on this for the next 6 months. The hardest part is coming up with a closed-form solution for `log_det_jac`, which I don't think I'm very close to doing.
Thanks for the update @johncant, and for pushing this as far as you did.
This is updated to a fully working implementation of the transformer, with non-trivial tests. It is currently blocked by #7910, because it uses

There's some jank -- I'm not sure we need to pass

If we got rid of the

On that note, those two triangular pack helpers could themselves be a separate transformation, since they define a bijection with forward and backward methods. That might be conceptually better, but on the other hand I don't think these would be used anywhere else, so it also makes sense to keep them as part of this class.
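For readers unfamiliar with what those pack helpers do, here is a minimal NumPy sketch of the forward/backward pair being described (hypothetical names, not the PR's actual code): packing the strictly lower triangle of a matrix into a flat vector, and unpacking it back.

```python
import numpy as np

def pack_lower_triangle(x: np.ndarray) -> np.ndarray:
    """Flatten the strictly lower-triangular entries of x (row-major order)."""
    n = x.shape[-1]
    return x[np.tril_indices(n, k=-1)]

def unpack_lower_triangle(v: np.ndarray, n: int) -> np.ndarray:
    """Inverse of pack_lower_triangle: rebuild an (n, n) matrix, zeros elsewhere."""
    out = np.zeros((n, n), dtype=v.dtype)
    out[np.tril_indices(n, k=-1)] = v
    return out

# Round-trip check: unpack(pack(x)) recovers a strictly lower-triangular x.
x = np.tril(np.arange(16.0).reshape(4, 4), k=-1)
assert np.array_equal(unpack_lower_triangle(pack_lower_triangle(x), 4), x)
```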
BTW, once this is merged we should still explore the new transform by that Stan dev; it should have better sampling properties, IIRC.
Do you have a link to the actual implementation? I was rooting around in the stan/stan-math repo and couldn't find it.
Wow, congratulations @jessegrabowski!
You had it 90% of the way there; we were just missing this weird spiral triangular construction that tfp was doing. I have no idea why they do it this way though; I just blindly copied xD
I switched the LKJCorr implementation to use
We could potentially try to check

Here is the benchmark I ran:

Before (unrolled loop):

After (scan):
We can support estimation of
Those timings look good enough that I wouldn't worry. It will hopefully fall out from our goal of speeding up Scan. Can you try with Numba mode, out of curiosity?
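For reference, switching PyTensor's compilation backend is just a `mode` argument to `pytensor.function`. A minimal sketch, using a stand-in graph since the PR's actual graph isn't reproduced here:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
y = pt.cumsum(x)  # stand-in for the actual transform graph

fn_default = pytensor.function([x], y)              # default backend
fn_numba = pytensor.function([x], y, mode="NUMBA")  # Numba backend
# mode="JAX" selects the JAX backend discussed below

data = np.random.randn(100)
np.testing.assert_allclose(fn_default(data), fn_numba(data))
```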
Numba timings.

With scan:

Unrolled loop (old implementation):

Pretty bad!
JAX also doesn't work with the new scan implementation, I guess because
Length shouldn't matter if only the last state is used? Can you write the index as a recurring output? I don't recall where `arange` shows up.
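A sketch of the suggestion above (stand-in computation, not the PR's graph): instead of scanning over a `pt.arange` sequence, carry the index itself as a recurring output, so the scan length comes only from `n_steps`.

```python
import pytensor
import pytensor.tensor as pt

def step(i, acc):
    # i is the carried index; acc is a stand-in accumulator that uses it
    return i + 1, acc + i

n = pt.iscalar("n")
(idx_trace, acc_trace), _ = pytensor.scan(
    step,
    outputs_info=[pt.constant(0, dtype="int64"), pt.constant(0, dtype="int64")],
    n_steps=n,
)
fn = pytensor.function([n], acc_trace[-1])
print(fn(5))  # 0 + 1 + 2 + 3 + 4 = 10
```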
Timings are pretty nice; not worried about the small case.
No, it's bad. I updated the original timings with a compiled function instead of
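The timing pitfall being corrected here, sketched with a stand-in graph: `.eval()` has to build and compile a function before running it, so naive `.eval()` timings can be dominated by compile time, while timing a precompiled function measures only runtime.

```python
import timeit
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
y = (x**2).sum()  # stand-in graph
data = np.random.randn(1_000)

fn = pytensor.function([x], y)  # compile once, outside the timed region
runtime = timeit.timeit(lambda: fn(data), number=1_000)

# By contrast, y.eval({x: data}) compiles under the hood, so timing it
# conflates compilation overhead with the runtime being measured.
```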
I used
Ah, so you were seeing longer compile times that made it look like scan was doing better? Unrolling 100 things or more is gonna be pretty ugly though.
Can you show the final Numba graph?
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #7380      +/-   ##
==========================================
+ Coverage   90.89%   91.05%   +0.16%
==========================================
  Files         123      123
  Lines       19489    19503      +14
==========================================
+ Hits        17714    17758      +44
+ Misses       1775     1745      -30
```
I have some asks, test-wise:

And as follow-up, we should open issues for:
Right now there's some confusion in the code about what

Tests are failing as a result. Working with
Aren't you mixing
I'm not sure we need `_LKJCorr` at all anymore.
Implement LKJCholeskyCorr transformation
Co-authored-by: John Cant <[email protected]>
What changed? I didn't see us changing anything on the RV side yet.
I added the other requested test. I am hitting a problem now where the RNG returned from my scan in
Now everything works on matrices natively, so there's no need for helpers to do all the packing/unpacking. The transform returns a matrix, the logp evaluates matrices, and the `rv_op` returns random matrices.
So is the underlying distribution no longer working on the non-zero triangular entries now?
I don't know if it was worth it, but I understand that was done for memory efficiency, to avoid storing all those useless entries. Also, could the logp be computationally more expensive now? And what are we working on: the correlation matrix, or still the Cholesky factor of it?
We're working on the Cholesky decomposition of the correlation matrix. This should be computationally cheaper across the board, but yes, it will require more memory for random draws. Users can't do anything with the low-memory draws anyway, though; if you actually want to compute with a correlation matrix, you are going to have to materialize it at some point.

The logp in this PR should be cheaper than what's on main. Here's the current code:

```python
shape = n * (n - 1) // 2
tri_index = np.zeros((n, n), dtype="int32")
tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)

value = pt.take(value, tri_index)
value = pt.fill_diagonal(value, 1)

# TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
try:
    eta = float(get_underlying_scalar_constant_value(eta))
except NotScalarConstantError:
    raise NotImplementedError("logp only implemented for constant `eta`")

result = _lkj_normalizing_constant(eta, n)
result += (eta - 1.0) * pt.log(det(value))
return check_parameters(
    result,
    value >= -1,
    value <= 1,
    matrix_pos_def(value),
    eta > 0,
)
```

And here's the new logp:

```python
result = _lkj_normalizing_constant(eta, n) + (eta - 1.0) * 2 * pt.diagonal(
    value, axis1=-2, axis2=-1
).log().sum(axis=-1)
row_norms = pt.sum(value**2, axis=-1)
return check_parameters(
    result,
    eta > 0,
    pt.isclose(row_norms, 1.0),
    msg="Invalid values passed to LKJCorr logp: value is not a valid correlation matrix or eta <= 0.",
)
```

You can see that previously they did all this reshaping business to get the packed vector into a full correlation matrix, then computed the logp. Now we just exploit the Cholesky structure to get a cheap logdet. Our rewrites definitely wouldn't have been able to infer any structure from what we were doing previously.
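The reason the diagonal term makes the new logp cheap is the standard identity it relies on (a quick numerical check, not PR code): for a Cholesky factor L of a correlation matrix R, det(R) = det(L) * det(L^T) = prod(diag(L))^2, so log det(R) = 2 * sum(log diag(L)).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
cov = A @ A.T                      # a random SPD matrix
d = np.sqrt(np.diag(cov))
R = cov / np.outer(d, d)           # normalize to a correlation matrix
L = np.linalg.cholesky(R)

# log det(R) computed directly vs. via the Cholesky diagonal
assert np.isclose(np.linalg.slogdet(R)[1], 2 * np.log(np.diag(L)).sum())
```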
I guess the point is that one can avoid storing them in the trace? That's an advantage I didn't think of. Personally, I like storing the whole thing so I can have square dimensions on it and easily look at results. Working with the packed lower triangle requires multi-indices, which ArviZ supports only barely.
Just to keep things as mysterious as possible, I don't see the double scan anymore in the current state of the PR.
I've ported this bijector from `tensorflow` and added it to `LKJCorr`. This ensures that initial samples drawn from `LKJCorr` are positive definite, which fixes #7101. Sampling now completes successfully with no divergences.

There are several parts I'm not comfortable with:

- How do we get the `n` parameter from `op` or `rv` without `eval`ing any pytensors?

@fonnesbeck @twiecki @jessegrabowski @velochy - please could you take a look? I would like to make sure that this fix makes sense before adding tests and making the linters pass.

Notes:

- `forward` in `tensorflow_probability` is `backward` in `pymc`

Description
Backward method
Forward method
log_jac_det
This was quite complicated to implement, so I used the symbolic Jacobian.
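To make the naming note above concrete, here is a minimal skeleton of a PyMC transform (a sketch only: the class name is hypothetical, the method bodies are elided, and the import path assumes a recent PyMC where the base class lives in `pymc.logprob.transforms`). PyMC's `backward` (unconstrained to constrained) corresponds to what tensorflow_probability calls `forward`.

```python
from pymc.logprob.transforms import Transform


class CholeskyCorrTransformSketch(Transform):
    name = "cholesky_corr_sketch"  # hypothetical name

    def forward(self, value, *inputs):
        # Constrained space (Cholesky factor of a correlation matrix)
        # -> unconstrained real vector. TFP calls this `inverse`.
        raise NotImplementedError

    def backward(self, value, *inputs):
        # Unconstrained real vector -> constrained Cholesky factor.
        # This is what TFP calls `forward`.
        raise NotImplementedError

    def log_jac_det(self, value, *inputs):
        # log|det J| of `backward`; per the note above, the PR first
        # obtained this via pytensor's symbolic jacobian.
        raise NotImplementedError
```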
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7380.org.readthedocs.build/en/7380/