@@ -57,6 +57,29 @@ def test_inverse_wishart_sample(df=7.0, dim=3, scale_factor=3.0, n_samples=10000
5757 mc_std = jnp .sqrt (iw .variance () / n_samples )
5858 assert jnp .allclose (samples .mean (axis = 0 ), iw .mean (), atol = num_std * mc_std )
5959
60+ def test_inverse_wishart_sample_non_diagonal_scale (n_samples : int = 10_000 , num_std = 3 ):
61+ """Test sample mean of an inverse-Wishart distr. w/ non-diagonal scale matrix."""
62+ k = 2
63+ 𝜈 = 5.5 # 𝜈 > k
64+ Ψ = jnp .array ([[20.712932 , 25.124634 ],
65+ [25.124634 , 32.814785 ]], dtype = jnp .float32 ) # k x k
66+ Ψ_diag = jnp .diagonal (Ψ )
67+ assert all (jnp .linalg .eigvals (Ψ ) > 0 ) # Is positive definite.
68+
69+ iw = InverseWishart (df = 𝜈 , scale = Ψ )
70+ Σs = iw .sample (sample_shape = n_samples , seed = jr .key (42 ))
71+ actual_Σ_avg = jnp .mean (Σs , axis = 0 )
72+
73+ # Closed form expression of mean.
74+ true_Σ_avg = Ψ / (𝜈 - k - 1 )
75+ # Closed form expression of variance.
76+ numerator = (𝜈 - k + 1 ) * Ψ ** 2 + (𝜈 - k - 1 ) * jnp .outer (Ψ_diag , Ψ_diag )
77+ denominator = (𝜈 - k ) * (𝜈 - k - 1 )** 2 * (𝜈 - k - 3 )
78+ true_Σ_var = numerator / denominator
79+
80+ mc_std = jnp .sqrt (true_Σ_var / n_samples )
81+ assert jnp .allclose (actual_Σ_avg , true_Σ_avg , atol = num_std * mc_std )
82+
6083
6184def test_normal_inverse_wishart_mode (loc = 0. , mean_conc = 1.0 , df = 7.0 , dim = 3 , scale_factor = 3.0 ):
6285 """
0 commit comments