Skip to content

Commit 6aeb4d7

Browse files
committed
Use ForwardDiff 1.0.1, fix LKJCholesky Jacobian test
1 parent c6facf9 commit 6aeb4d7

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ Distributions = "0.25.33"
5151
DistributionsAD = "0.6"
5252
DocStringExtensions = "0.9"
5353
EnzymeCore = "0.8.4"
54-
ForwardDiff = "0.10, 1"
54+
ForwardDiff = "0.10, 1.0.1"
5555
Functors = "0.1, 0.2, 0.3, 0.4, 0.5"
5656
InverseFunctions = "0.1"
5757
IrrationalConstants = "0.1, 0.2"

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Enzyme = "0.13.12"
3636
EnzymeTestUtils = "0.2.1"
3737
FillArrays = "1"
3838
FiniteDifferences = "0.11, 0.12"
39-
ForwardDiff = "1"
39+
ForwardDiff = "1.0.1"
4040
Functors = "0.1, 0.2, 0.3, 0.4, 0.5"
4141
InverseFunctions = "0.1"
4242
LazyArrays = "1, 2"

test/transform.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -238,16 +238,16 @@ end
238238
@testset "uplo: $uplo" for uplo in [:L, :U]
239239
dist = LKJCholesky(3, 1, uplo)
240240
single_sample_tests(dist)
241-
242241
x = rand(dist)
243-
244-
inds = [
245-
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
246-
(uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
247-
]
248242
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
243+
# Remove columns of Jacobian that are all zero (i.e. those
244+
# corresponding to entries above the diagonal for uplo = :U, or below
245+
# the diagonal for uplo = :L). This slightly unscientific approach
246+
# based on filter() is needed to handle both ForwardDiff 0.10 and 1 as
247+
# the exact indices will differ for the two versions.
248+
inds = filter(i -> !all(iszero, J[:, i]), 1:size(J, 2))
249249
J = J[:, inds]
250-
logpdf_turing = logpdf_with_trans(dist, x, true)
250+
logpdf_turing = logpdf_with_trans(dist, x, true)
251251
@test logpdf(dist, x) - _logabsdet(J) logpdf_turing
252252
end
253253
end

0 commit comments

Comments
 (0)