
Conversation

@johncant

@johncant johncant commented Jun 21, 2024

I've ported this bijector from tensorflow and added it to LKJCorr. This ensures that initial samples drawn from LKJCorr are positive definite, which fixes #7101 . Sampling now completes successfully with no divergences.

There are several parts I'm not comfortable with:

@fonnesbeck @twiecki @jessegrabowski @velochy - please could you take a look? I would like to make sure that this fix makes sense before adding tests and making the linters pass.

Notes:

  • Tests not yet written, linters not yet run
  • The original tensorflow bijector is defined in the opposite sense to pymc transforms, i.e. forward in tensorflow_probability is backward in pymc
  • The original tensorflow bijector produces Cholesky factors, not actual correlation matrices, so in this implementation we have to do a Cholesky decomposition in the forward transform.
  • In the tensorflow bijector, the triangular elements of a matrix are filled in a clockwise spiral, as opposed to numpy, which fills indices in row-major order.

Description

Backward method

  1. Start with an identity matrix and fill the lower triangular elements with unconstrained real numbers.
  2. Normalize each row so its L2 norm is 1.
  3. This is now a Cholesky factor that will always produce positive definite correlation matrices.
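These backward steps can be sketched in plain NumPy (an illustrative sketch, not the actual PyMC transform; note it fills the triangle in NumPy row-major order rather than tfp's spiral order):

```python
import numpy as np

def backward_sketch(free_params, n):
    # Step 1: identity matrix with unconstrained reals in the strict lower triangle
    L = np.eye(n)
    L[np.tril_indices(n, k=-1)] = free_params
    # Step 2: normalize each row to unit L2 norm
    L /= np.linalg.norm(L, axis=1, keepdims=True)
    # Step 3: L is now a valid Cholesky factor of a correlation matrix
    return L

L = backward_sketch(np.array([0.3, -0.5, 0.8]), 3)
C = L @ L.T  # unit diagonal and positive definite by construction
```

Because every row of L has unit norm, C = L @ L.T always has ones on the diagonal, and it is positive definite whenever the diagonal of L is nonzero.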

Forward method

  1. Reconstruct the correlation matrix from its upper triangular elements.
  2. Perform a Cholesky decomposition to obtain L.
  3. The diagonal elements of L are the multipliers used to normalize the other elements.
  4. Extract those diagonal elements and divide to undo the backward method.
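A matching sketch of the forward direction (again illustrative NumPy, starting from a full correlation matrix rather than its packed upper triangle):

```python
import numpy as np

def forward_sketch(C):
    # Step 2: Cholesky decomposition of the correlation matrix
    L = np.linalg.cholesky(C)
    # Steps 3-4: the diagonal of L holds the normalizers; dividing each row
    # by its diagonal entry undoes the backward normalization
    Z = L / np.diag(L)[:, None]
    # the strict lower-triangular entries are the unconstrained values
    return Z[np.tril_indices(C.shape[0], k=-1)]
```

Round-tripping a vector through the backward construction and this function recovers the original unconstrained values.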

log_jac_det

This was quite complicated to implement, so I used the symbolic jacobian.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7380.org.readthedocs.build/en/7380/

@welcome

welcome bot commented Jun 21, 2024

💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

@johncant johncant changed the title Fix #7101 by implementing a transform to that LKJCorr samples are positive definite Fix #7101 by implementing a transform to ensure that LKJCorr samples are positive definite Jun 21, 2024
@ricardoV94 ricardoV94 changed the title Fix #7101 by implementing a transform to ensure that LKJCorr samples are positive definite Implement unconstraining transform for LKJCorr Jun 24, 2024
@johncant
Author

Hi, it's unlikely I'm going to have any time to work on this for the next 6 months. The hardest part is coming up with a closed-form solution for log_jac_det, which I don't think I'm very close to doing.

@twiecki
Member

twiecki commented Jul 30, 2024

Thanks for the update @johncant and for pushing this as far as you did.

@jessegrabowski jessegrabowski force-pushed the fix_lkjcorr_positive_definiteness branch from df723bc to cf8d9a8 Compare October 4, 2025 23:47
@github-actions github-actions bot added the bug and hackathon (Suitable for hackathon) labels Oct 4, 2025
@jessegrabowski
Member

This is updated to a fully working implementation of the transformer, with non-trivial tests.

It is currently blocked by #7910, because it uses pt.flip with negative indices, which are bugged in the current main branch. Tests pass locally on the current pytensor main.

There's some jank -- I'm not sure we need to pass n to the constructor. We can easily infer n from all the inputs. I guess the benefit is to be able to check that shapes haven't changed somewhere since the transformer was initialized? But these are so internal to PyMC that users shouldn't be able to make that happen anyway (and if one can, he's a super user so who cares).

If we got rid of the n property, it would also do away with the unit_diag arguments on the triangular_pack helpers.

On that note, those two triangular pack helpers could themselves be a separate transformation, since they define a Bijection with forward and backward methods. That might be conceptually better, but on the other hand I don't think these would be used anywhere else, so it also makes sense to keep them as part of this class.

@jessegrabowski jessegrabowski force-pushed the fix_lkjcorr_positive_definiteness branch from 3ac4970 to 1f29284 Compare October 5, 2025 05:20
@ricardoV94
Member

BTW once this is merged we should still explore the new transform by that Stan dev, it should have better sampling properties IIRC

@jessegrabowski
Member

Do you have a link to the actual implementation? I was rooting around in the stan/stan-math repo and couldn't find it.

@johncant
Author

johncant commented Oct 5, 2025

Wow - congratulations @jessegrabowski !

@jessegrabowski
Member

Wow - congratulations @jessegrabowski !

You had it 90% of the way there, we were just missing this weird spiral triangular construction that tfp was doing. I have no idea why they do it this way though, I just blindly copied xD

@jessegrabowski jessegrabowski force-pushed the fix_lkjcorr_positive_definiteness branch from 1f29284 to a6ab223 Compare October 29, 2025 17:16
@jessegrabowski
Member

jessegrabowski commented Oct 29, 2025

I switched the LKJCorr implementation to use scan to generate samples. This means:

  1. Both eta and n can be symbolic
  2. Larger matrices sample much faster.
  3. Small matrices are somewhat slower

We could potentially try to check n and unroll the loop if it's sufficiently small. It also made me think a scan_to_unrolled_loop rewrite/tool might be interesting/useful in general.

Here is the benchmark I ran:

import pymc as pm
import pytest

@pytest.mark.parametrize("n", [2, 10, 100])
def test_lkjcorr_sampling_benchmark(n, benchmark):
    d = pm.LKJCorr.dist(n=n, eta=2.0)
    sample_fn = pm.compile(inputs=[], outputs=[d], mode="FAST_RUN")
    sample_fn()  # warm-up call so compilation is excluded from the timing

    benchmark(sample_fn)

Before (unrolled loop):

----------------------------------------------------------------------------------------------------- benchmark: 3 tests -----------------------------------------------------------------------------------------------------                                                                                                                                              
Name (time in us)                             Min                    Max                  Mean                StdDev                Median                 IQR            Outliers           OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------                                                                                                                                              
test_lkjcorr_sampling_benchmark[2]         3.1670 (1.0)          38.4999 (1.0)          3.6752 (1.0)          0.5252 (1.0)          3.7500 (1.0)        0.3759 (1.0)       435;361  272,093.1823 (1.0)       17765           1                                                                                                                                              
test_lkjcorr_sampling_benchmark[10]       59.1250 (18.67)       250.4170 (6.50)        68.1735 (18.55)        4.7520 (9.05)        69.0001 (18.40)      4.0832 (10.86)    2420;568   14,668.4544 (0.05)      12006           1
test_lkjcorr_sampling_benchmark[100]     831.8330 (262.66)   21,313.4581 (553.60)   1,588.8558 (432.32)   1,978.1464 (>1000.0)  1,050.2914 (280.08)   179.7080 (478.07)     63;117      629.3837 (0.00)        862           1                                                                                                                                              
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------                                                                                                                                              

After (scan):

----------------------------------------------------------------------------------------------------- benchmark: 3 tests ----------------------------------------------------------------------------------------------------                                                                                                                                               
Name (time in us)                               Min                   Max                  Mean              StdDev                Median                 IQR            Outliers           OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------                                                                                                                                               
test_lkjcorr_sampling_benchmark[2]           7.9999 (1.0)         77.4580 (1.0)          9.2234 (1.0)        2.3224 (1.0)          8.5421 (1.0)        0.9161 (1.0)       691;911  108,419.6575 (1.0)        8120           1                                                                                                                                               
test_lkjcorr_sampling_benchmark[10]        229.1250 (28.64)      581.1669 (7.50)       253.4111 (27.47)     21.7310 (9.36)       249.2500 (29.18)     19.0000 (20.74)     283;121    3,946.1577 (0.04)       3424           1
test_lkjcorr_sampling_benchmark[100]     3,908.5420 (488.57)   6,156.4590 (79.48)    4,609.0084 (499.71)   446.9420 (192.45)   4,542.3959 (531.77)   660.9370 (721.49)       70;3      216.9664 (0.00)        208           1                                                                                                                                               
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------                                                                                                                                               

@jessegrabowski
Member

We can support estimation of eta now. We should be able to add broadcasting on these parameters, but I'll leave that for a future PR.

@ricardoV94
Member

ricardoV94 commented Oct 29, 2025

Those timings look good enough I wouldn't worry. It will hopefully fall out from our goal of speeding up Scan. Can you try with Numba mode out of curiosity?

@jessegrabowski
Member

Numba timings. With scan:

----------------------------------------------------------------------------------------------------- benchmark: 3 tests -----------------------------------------------------------------------------------------------------                                                                                                                                              
Name (time in us)                               Min                    Max                  Mean              StdDev                Median                 IQR            Outliers           OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------                                                                                                                                              
test_lkjcorr_sampling_benchmark[2]           8.5830 (1.0)          93.3340 (1.0)          9.6742 (1.0)        1.1092 (1.0)          9.8330 (1.0)        0.8750 (1.0)       215;186  103,367.4023 (1.0)       11911           1                                                                                                                                              
test_lkjcorr_sampling_benchmark[10]        199.0420 (23.19)    40,315.1250 (431.94)     272.4426 (28.16)    823.9673 (742.82)     225.4579 (22.93)     24.5001 (28.00)      23;326    3,670.4975 (0.04)       4113           1
test_lkjcorr_sampling_benchmark[100]     3,578.3750 (416.92)    5,509.3330 (59.03)    4,043.9366 (418.01)   256.9228 (231.62)   4,027.7915 (409.62)   316.2490 (361.44)       71;4      247.2838 (0.00)        238           1                                                                                                                                              
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------                                                                                                                                              

Unrolled loop (old implementation):

-------------------------------------------------------------------------------------------------- benchmark: 3 tests --------------------------------------------------------------------------------------------------                                                                                                                                                    
Name (time in us)                             Min                   Max                Mean              StdDev              Median                 IQR             Outliers  OPS (Kops/s)            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------                                                                                                                                                    
test_lkjcorr_sampling_benchmark[2]         4.7911 (1.0)         49.6670 (1.0)        5.5282 (1.0)        0.7295 (1.0)        5.5420 (1.0)        0.1250 (1.0)       174;2127      180.8892 (1.0)       10846           1                                                                                                                                                    
test_lkjcorr_sampling_benchmark[10]        7.5829 (1.58)        54.9171 (1.11)       8.7685 (1.59)       0.8032 (1.10)       8.8330 (1.59)       0.2501 (2.00)     4005;7133      114.0448 (0.63)      34385           1
test_lkjcorr_sampling_benchmark[100]     167.4580 (34.95)    2,400.6250 (48.33)    317.5335 (57.44)    235.1336 (322.34)   250.0410 (45.12)    105.5831 (844.46)     162;312        3.1493 (0.02)       2211           1                                                                                                                                                    
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------                                                                                                                                                    

Pretty bad!

@jessegrabowski
Member

JAX also doesn't work with the new scan implementation, I guess because n is symbolic, so pt.arange(2, n) (the length of the scan) is unknown at compile time

@ricardoV94
Member

ricardoV94 commented Oct 29, 2025

Length shouldn't matter if only the last state is used? Can you write the index as a recurring output?

I don't recall where arange shows up.

@ricardoV94
Member

timings are pretty nice, not worried about the small case

@jessegrabowski
Member

timings are pretty nice, not worried about the small case

No, it's bad. I updated the original timings with a compiled function instead of pm.draw and the scan is significantly worse across the board. I can't reproduce the original timings, so I might have been picking up first-run compile times or something.

@jessegrabowski
Member

I used pm.compile, doesn't that handle it for you?

@ricardoV94
Member

Ah so you were seeing longer compile times that made it look like scan was doing better? Unrolled 100 things or more is gonna be pretty ugly though

@ricardoV94
Member

Can you show the final numba graph?

@jessegrabowski jessegrabowski force-pushed the fix_lkjcorr_positive_definiteness branch 2 times, most recently from ce9dc2d to 655f433 Compare January 31, 2026 14:35
@codecov

codecov bot commented Jan 31, 2026

Codecov Report

❌ Patch coverage is 99.35065% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 91.05%. Comparing base (c868a84) to head (aac5607).

Files with missing lines Patch % Lines
pymc/distributions/multivariate.py 98.46% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7380      +/-   ##
==========================================
+ Coverage   90.89%   91.05%   +0.16%     
==========================================
  Files         123      123              
  Lines       19489    19503      +14     
==========================================
+ Hits        17714    17758      +44     
+ Misses       1775     1745      -30     
Files with missing lines Coverage Δ
pymc/distributions/transforms.py 100.00% <100.00%> (ø)
pymc/sampling/jax.py 93.51% <ø> (+9.66%) ⬆️
pymc/testing.py 91.09% <100.00%> (+0.45%) ⬆️
pymc/distributions/multivariate.py 94.23% <98.46%> (+0.32%) ⬆️

... and 1 file with indirect coverage changes


@jessegrabowski jessegrabowski force-pushed the fix_lkjcorr_positive_definiteness branch from 655f433 to 48b46f5 Compare February 1, 2026 16:46
@ricardoV94
Member

ricardoV94 commented Feb 1, 2026

I have some asks, test-wise

  • Test the transform directly: forward-backward is identity and log-jac-det is correct, like we do for other transforms
  • Test that the forward of a random vector is indeed the cholesky factor of a valid correlation matrix (we test that indirectly via the new finite logp check, but I want to test the transform independently, because the logp could be wrong / or changed in a way that doesn't check that, and we would lose coverage)

And follow-up, we should open issues for:

@jessegrabowski
Member

Right now there's some confusion in the code about what LKJCorr should return:

  • the transform works with Cholesky-decomposed matrices, so values will be a flat vector and the forward method returns L such that C = L @ L.T
  • But the forward draws return C, the correlation matrix itself, directly.

Tests are failing as a result. Working with L is easier internally, but users will expect C, so I guess there needs to be a deterministic helper and a flag for which to give back?

@ricardoV94
Member

ricardoV94 commented Feb 1, 2026

Aren't you mixing LKJCorr and _LKJCorr? LKJCorr is the user facing one that returns the correlation matrix (optionally?), under the hood it works with _LKJCorr, which works on whatever space it wants. That's the one that needs an unconstraining transform.

@jessegrabowski
Copy link
Member

Aren't you mixing LKJCorr and _LKJCorr? LKJCorr is the user facing one that returns the correlation matrix (optionally?), under the hood it works with _LKJCorr, which works on whatever space it wants. That's the one that needs an unconstraining transform.

I'm not sure we need _LKJCorr at all anymore

@jessegrabowski jessegrabowski force-pushed the fix_lkjcorr_positive_definiteness branch from 48b46f5 to aac5607 Compare February 1, 2026 22:53
@ricardoV94
Member

Aren't you mixing LKJCorr and _LKJCorr? LKJCorr is the user facing one that returns the correlation matrix (optionally?), under the hood it works with _LKJCorr, which works on whatever space it wants. That's the one that needs an unconstraining transform.

I'm not sure we need _LKJCorr at all anymore

What changed? I didn't see us changing anything in the RV side yet

@jessegrabowski
Member

I added the other requested test.

I am hitting a problem now where the RNG returned from my scan in _random_corr_matrix is None in the TestLKJCorr::check_draws_match_expected test. I think it is something to do with how the scan Op is being rebuilt?

@jessegrabowski
Member

What changed? I didn't see us changing anything in the RV side yet

_LKJCorr previously returned the packed lower triangle of the Cholesky-decomposed matrix. The logp function also acted on vectors, and we sampled vectors.

Now everything works on matrices natively, so there's no need for helpers to do all the packing/unpacking. The transform returns a matrix, the logp evaluates matrices, and the rv_op returns random matrices.

@ricardoV94
Member

So is the underlying distribution no longer working on the non-zero-triangular entries now?

@ricardoV94
Member

I don't know if it was worth it, but I understand that was done for memory efficiency, to avoid storing all those useless entries.

Also, could the logp be computationally more expensive now?

And what are we working on, the correlation matrix or still the Cholesky factor of it?

@jessegrabowski
Member

We're working on the Cholesky decomposition of the correlation matrix. This should be computationally cheaper across the board, but yes it will require more memory for random draws. But users can't do anything with the low memory draws anyway. If you actually want to compute with a correlation matrix, you are going to have to materialize it at some point.

The logp in this PR should be cheaper than what's on main. Here's the current code:

        shape = n * (n - 1) // 2
        tri_index = np.zeros((n, n), dtype="int32")
        tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
        tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)

        value = pt.take(value, tri_index)
        value = pt.fill_diagonal(value, 1)

        # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
        try:
            eta = float(get_underlying_scalar_constant_value(eta))
        except NotScalarConstantError:
            raise NotImplementedError("logp only implemented for constant `eta`")
        result = _lkj_normalizing_constant(eta, n)
        result += (eta - 1.0) * pt.log(det(value))
        return check_parameters(
            result,
            value >= -1,
            value <= 1,
            matrix_pos_def(value),
            eta > 0,
        )

And here's the new logp:

        result = _lkj_normalizing_constant(eta, n) + (eta - 1.0) * 2 * pt.diagonal(
            value, axis1=-2, axis2=-1
        ).log().sum(axis=-1)

        row_norms = pt.sum(value**2, axis=-1)

        return check_parameters(
            result,
            eta > 0,
            pt.isclose(row_norms, 1.0),
            msg="Invalid values passed to LKJCorr logp: value is not a valid correlation matrix or eta <= 0.",
        )

You can see that previously they did all this reshaping business to get the packed vector into a full correlation matrix and then computed the logp.

Now we just exploit the Cholesky structure to get a cheap logdet. Our rewrites definitely wouldn't have been able to infer any structure from what we were previously doing.
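The cheap-logdet claim is easy to sanity-check numerically (a standalone NumPy check of the identity, not PyMC code): for C = L @ L.T with L lower triangular, log det(C) = 2 * sum(log(diag(L))).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
C = A @ A.T + 4 * np.eye(4)  # an arbitrary positive definite matrix
L = np.linalg.cholesky(C)

# det(C) = det(L)**2 = prod(diag(L))**2, so the log-determinant comes
# straight from the Cholesky diagonal, with no dense determinant needed
logdet_cheap = 2 * np.log(np.diag(L)).sum()
logdet_dense = np.linalg.slogdet(C)[1]
```

The two values agree to floating-point precision, which is why evaluating the logp on the Cholesky factor avoids the dense det(value) call in the old code.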

@jessegrabowski
Member

But users can't do anything with the low memory draws anyway

I guess the point is that one can avoid storing them in the trace? That's an advantage I didn't think of. Personally I like storing the whole thing so I can have square dimensions on it and easily look at results. Working with the packed lower triangle requires multi-indices, which arviz supports only barely.
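For reference, the packed-triangle storage being discussed is just the strict lower triangle flattened to a vector; a NumPy sketch of the round trip (hypothetical helper code, not the PyMC internals):

```python
import numpy as np

# build an arbitrary valid correlation matrix to pack
rng = np.random.default_rng(1)
A = rng.normal(size=(4, 4))
B = A @ A.T
d = np.sqrt(np.diag(B))
C = B / np.outer(d, d)

idx = np.tril_indices(4, k=-1)
packed = C[idx]                        # 6 numbers instead of a 4x4 matrix

C_restored = np.eye(4)
C_restored[idx] = packed               # refill the lower triangle
C_restored[(idx[1], idx[0])] = packed  # mirror into the upper triangle
```

The packed form saves memory in the trace, at the cost of the multi-index bookkeeping mentioned above when inspecting results.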

@jessegrabowski
Member

Fix the double scan in numba mode

Just to keep things as mysterious as possible, I don't see the double scan anymore in the current state of the PR.


Labels

bug, enhancements, hackathon (Suitable for hackathon)


Development

Successfully merging this pull request may close these issues.

BUG: LKJCorr breaks when used as covariance with MvNormal

5 participants