Skip to content

Distributions integration tests#2617

Merged
wsmoses merged 4 commits intomainfrom
dist
Sep 23, 2025
Merged

Distributions integration tests#2617
wsmoses merged 4 commits intomainfrom
dist

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Sep 23, 2025

No description provided.

@github-actions
Copy link
Contributor

github-actions bot commented Sep 23, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/test/integration/Distributions/runtests.jl b/test/integration/Distributions/runtests.jl
index b8eb219..bf2c67a 100644
--- a/test/integration/Distributions/runtests.jl
+++ b/test/integration/Distributions/runtests.jl
@@ -42,29 +42,29 @@ end
 
 # Turn a distribution into a call to logpdf.
 function TestCase(d::Distribution, value, name, runtime_activity, broken, splat)
-    TestCase(x -> logpdf(d, x), value, name, runtime_activity, broken, splat)
+    return TestCase(x -> logpdf(d, x), value, name, runtime_activity, broken, splat)
 end
 
 # Defaults for name, runtime_activity and broken.
 function TestCase(
-    f, value;
-    name=nothing, runtime_activity=Neither, broken=Neither, splat=false
-)
+        f, value;
+        name = nothing, runtime_activity = Neither, broken = Neither, splat = false
+    )
     return TestCase(f, value, name, runtime_activity, broken, splat)
 end
 
 # Default name for a Distribution.
 function TestCase(
-    d::Distribution, value;
-    name=string(nameof(typeof(d))), runtime_activity=Neither, broken=Neither, splat=false
-)
+        d::Distribution, value;
+        name = string(nameof(typeof(d))), runtime_activity = Neither, broken = Neither, splat = false
+    )
     return TestCase(d, value, name, runtime_activity, broken, splat)
 end
 
 """
 Test Enzyme.gradient, both Forward and Reverse mode, against FiniteDifferences.grad.
 """
-function test_grad(case::TestCase; rtol=1e-6, atol=1e-6)
+function test_grad(case::TestCase; rtol = 1.0e-6, atol = 1.0e-6)
     @nospecialize
     f = case.func
     # We'll call the function as f(x...), so wrap in a singleton tuple if need be.
@@ -209,10 +209,10 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
         TestCase(LogUniform(0.15, 7.8), 7.1),
         TestCase(LogUniform(2.0, 3.0), 2.1),
         # TODO Broken tests, see https://github.com/EnzymeAD/Enzyme.jl/issues/1620
-        TestCase(NoncentralBeta(1.1, 1.1, 1.2), 0.8; broken=Both), # foreigncall (Rmath.dnbeta).
-        TestCase(NoncentralChisq(2, 3.0), 10.0; broken=Both), # foreigncall (Rmath.dnchisq).
-        TestCase(NoncentralF(2, 3, 1.1), 4.1; broken=Both), # foreigncall (Rmath.dnf).
-        TestCase(NoncentralT(1.3, 1.1), 0.1; broken=Both), # foreigncall (Rmath.dnt).
+        TestCase(NoncentralBeta(1.1, 1.1, 1.2), 0.8; broken = Both), # foreigncall (Rmath.dnbeta).
+        TestCase(NoncentralChisq(2, 3.0), 10.0; broken = Both), # foreigncall (Rmath.dnchisq).
+        TestCase(NoncentralF(2, 3, 1.1), 4.1; broken = Both), # foreigncall (Rmath.dnf).
+        TestCase(NoncentralT(1.3, 1.1), 0.1; broken = Both), # foreigncall (Rmath.dnt).
         TestCase(Normal(), 0.1),
         TestCase(Normal(0.0, 1.0), 1.0),
         TestCase(Normal(0.5, 1.0), 0.05),
@@ -220,7 +220,7 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
         TestCase(Normal(-0.1, 0.9), -0.3),
         # TODO Broken test, see https://github.com/EnzymeAD/Enzyme.jl/issues/1603
         # foreigncall -- https://github.com/JuliaMath/SpecialFunctions.jl/blob/be1fa06fee58ec019a28fb0cd2b847ca83a5af9a/src/bessel.jl#L265
-        TestCase(NormalInverseGaussian(0.0, 1.0, 0.2, 0.1), 0.1; broken=Both),
+        TestCase(NormalInverseGaussian(0.0, 1.0, 0.2, 0.1), 0.1; broken = Both),
         TestCase(Pareto(1.0, 1.0), 3.5),
         TestCase(Pareto(1.1, 0.9), 3.1),
         TestCase(Pareto(1.0, 1.0), 1.4),
@@ -231,7 +231,7 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
         TestCase(Rayleigh(0.9), 1.1),
         TestCase(Rayleigh(0.55), 0.63),
         # TODO Broken test, see https://github.com/EnzymeAD/Enzyme.jl/issues/1620
-        TestCase(Rician(0.5, 1.0), 2.1; broken=Both),  # foreigncall (Rmath.dnchisq).
+        TestCase(Rician(0.5, 1.0), 2.1; broken = Both),  # foreigncall (Rmath.dnchisq).
         TestCase(Semicircle(1.0), 0.9),
         TestCase(Semicircle(5.1), 5.05),
         TestCase(Semicircle(0.5), -0.1),
@@ -281,24 +281,26 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
         TestCase(MvNormal([0.2, 0.3], Symmetric(Diagonal([0.5, 0.4]))), [-0.1, 0.05]),
         TestCase(MvNormal([0.2, 0.3], Diagonal([0.5, 0.4])), [-0.1, 0.05]),
         # TODO Broken tests, see https://github.com/EnzymeAD/Enzyme.jl/issues/1991
-        TestCase(MvNormal([-0.15], _pdmat([1.1]')), [-0.05]; broken=Forward),
+        TestCase(MvNormal([-0.15], _pdmat([1.1]')), [-0.05]; broken = Forward),
         TestCase(
             MvNormal([0.2, -0.15], _pdmat([1.0 0.9; 0.7 1.1])), [0.05, -0.05];
-            broken=Forward
+            broken = Forward
         ),
         TestCase(MvNormal([0.2, -0.3], [0.5, 0.6]), [0.4, -0.3]),
         # TODO https://github.com/EnzymeAD/Enzyme.jl/issues/2618, trmv error
-        TestCase(MvNormalCanon([0.1, -0.1], _pdmat([0.5 0.4; 0.45 1.0])), [0.2, -0.25]; 
-            broken=(VERSION >= v"1.11" ? Forward : Neither)),
+        TestCase(
+            MvNormalCanon([0.1, -0.1], _pdmat([0.5 0.4; 0.45 1.0])), [0.2, -0.25];
+            broken = (VERSION >= v"1.11" ? Forward : Neither)
+        ),
         # TODO Broken tests, see https://github.com/EnzymeAD/Enzyme.jl/issues/1991
         TestCase(
             MvLogNormal(MvNormal([0.2, -0.1], _pdmat([1.0 0.9; 0.7 1.1]))), [0.5, 0.1];
-            broken=Forward
+            broken = Forward
         ),
         TestCase(product_distribution([Normal()]), [0.3]),
         TestCase(
             product_distribution([Normal(), Uniform()]), [-0.4, 0.3];
-            runtime_activity=Both
+            runtime_activity = Both
         ),
 
         #
@@ -315,12 +317,12 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
         TestCase(
             Wishart(5, _pdmat(randn(rng, 3, 3))),
             Symmetric(collect(_pdmat(randn(rng, 3, 3))));
-            broken=Forward
+            broken = Forward
         ),
         TestCase(
             InverseWishart(5, _pdmat(randn(rng, 3, 3))),
             Symmetric(collect(_pdmat(randn(rng, 3, 3))));
-            broken=Forward
+            broken = Forward
         ),
         # TODO Broken tests, see https://github.com/EnzymeAD/Enzyme.jl/issues/1820
         # getrf derivative needed
@@ -332,15 +334,15 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
                 _pdmat(randn(rng, 3, 3)),
             ),
             randn(rng, 2, 3);
-            broken=Both
+            broken = Both
         ),
-        TestCase(MatrixBeta(5, 6.0, 7.0), rand(rng, MatrixBeta(5, 6.0, 6.0)); broken=Both),
+        TestCase(MatrixBeta(5, 6.0, 7.0), rand(rng, MatrixBeta(5, 6.0, 6.0)); broken = Both),
         TestCase(
             MatrixFDist(6.0, 7.0, _pdmat(randn(rng, 5, 5))),
             rand(rng, MatrixFDist(6.0, 7.0, _pdmat(randn(rng, 5, 5))));
-            broken=Both
+            broken = Both
         ),
-        TestCase(LKJ(5, 1.1), rand(rng, LKJ(5, 1.1)); broken=Both),
+        TestCase(LKJ(5, 1.1), rand(rng, LKJ(5, 1.1)); broken = Both),
 
         #
         # Miscellaneous others
@@ -348,50 +350,50 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
 
         TestCase(
             (a, b, x) -> logpdf(InverseGamma(a, b), x), (1.5, 1.4, 0.4);
-            name="InverseGamma", splat=true
+            name = "InverseGamma", splat = true
         ),
         TestCase(
             (m, s, x) -> logpdf(NormalCanon(m, s), x), (0.1, 1.0, -0.5);
-            name="NormalCanon", splat=true
+            name = "NormalCanon", splat = true
         ),
-        TestCase(x -> logpdf(Categorical(x, 1 - x), 1), 0.3; name="Categorical"),
+        TestCase(x -> logpdf(Categorical(x, 1 - x), 1), 0.3; name = "Categorical"),
 
         # TODO Broken test, see https://github.com/EnzymeAD/Enzyme.jl/issues/1995
         # Forward mode runtime needed
         TestCase(
             (m, S, x) -> logpdf(MvLogitNormal(m, S), vcat(x, 1 - sum(x))),
             ([0.4, 0.6], _pdmat([0.9 0.4; 0.5 1.1]), [0.27, 0.24]);
-            name="MvLogitNormal", runtime_activity=Forward, broken=Forward, splat=true,
+            name = "MvLogitNormal", runtime_activity = Forward, broken = Forward, splat = true,
         ),
         TestCase(
             (a, b, α, β, x) -> logpdf(truncated(Beta(α, β), a, b), x),
             (0.1, 0.9, 1.1, 1.3, 0.4);
-            name="truncated Beta", splat=true
+            name = "truncated Beta", splat = true
         ),
         TestCase(
             (a, b, x) -> logpdf(truncated(Normal(), a, b), x),
             (-0.3, 0.3, 0.1);
-            name="allocs Normal", splat=true
+            name = "allocs Normal", splat = true
         ),
         TestCase(
             (a, b, α, β, x) -> logpdf(truncated(Uniform(α, β), a, b), x),
             (0.1, 0.9, -0.1, 1.1, 0.4);
-            name="allocs Uniform", splat=true
+            name = "allocs Uniform", splat = true
         ),
         TestCase(
             (a, x) -> logpdf(Dirichlet(a), [x, 1 - x]), ([1.5, 1.1], 0.6);
-            name="Dirichlet", splat=true, runtime_activity=Forward
+            name = "Dirichlet", splat = true, runtime_activity = Forward
         ),
         TestCase(
             x -> logpdf(reshape(product_distribution([Normal(), Uniform()]), 1, 2), x),
             [2.1 0.7];
-            name="reshape"
+            name = "reshape"
         ),
         # TODO Broken test, see https://github.com/EnzymeAD/Enzyme.jl/issues/1820
         # needs getrf derivative
         TestCase(
             x -> logpdf(vec(LKJ(2, 1.1)), x), [1.0, 0.489, 0.489, 1.0];
-            name="vec", broken=Both
+            name = "vec", broken = Both
         ),
         TestCase(
             function (X, v)
@@ -403,7 +405,7 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
                 return logpdf(LKJCholesky(2, v), C)
             end,
             (randn(rng, 2, 2), 1.1);
-            name="LKJCholesky", splat=true
+            name = "LKJCholesky", splat = true
         ),
     ]
 
@@ -412,4 +414,4 @@ _pdmat(A) = PDMat(_sym(A) + 5I)
     end
 end
 
-end
\ No newline at end of file
+end

@wsmoses wsmoses merged commit 0f1c3df into main Sep 23, 2025
15 checks passed
@wsmoses wsmoses deleted the dist branch September 23, 2025 04:23
@giordano giordano mentioned this pull request Nov 4, 2025
13 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants