Skip to content

Commit 0f36d68

Browse files
Merge pull request #94 from sharanry/sy/type_instability
Make function transformation type stable
2 parents 806fdf0 + a7432b5 commit 0f36d68

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Integrals"
22
uuid = "de52edbc-65ea-441a-8357-d3a637375a31"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "3.1.1"
4+
version = "3.1.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -38,7 +38,8 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
3838
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3939
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4040
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
41+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4142
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4243

4344
[targets]
44-
test = ["SciMLSensitivity", "FiniteDiff", "Pkg", "SafeTestsets", "Test"]
45+
test = ["SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test"]

src/Integrals.jl

+15-12
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function scale_x(ub, lb, x)
3838
end
3939

4040
function v_inf(t)
41-
return t ./ (1 .- t .^ 2)
41+
return map(t -> t / (1 - t^2), t)
4242
end
4343

4444
function v_semiinf(t, a, upto_inf)
@@ -76,10 +76,13 @@ function transform_inf(t, p, f, lb, ub)
7676
semilw = lbb .& .!ubb
7777

7878
function v(t)
79-
return t .* _none + v_inf(t) .* _inf + v_semiinf(t, lb, 1) .* semiup +
80-
v_semiinf(t, ub, 0) .* semilw
79+
t .* _none + v_inf(t) .* _inf + v_semiinf(t, lb, 1) .* semiup +
80+
v_semiinf(t, ub, 0) .* semilw
8181
end
82-
jac = Zygote.@ignore ForwardDiff.jacobian(x -> v(x), t)
82+
jac = ChainRulesCore.@ignore_derivatives ForwardDiff.jacobian(x -> v(x),
83+
t |> Vector)::Matrix{
84+
eltype(t)
85+
}
8386
j = det(jac)
8487
f(v(t), p) * (j)
8588
end
@@ -180,9 +183,9 @@ function __solvebp_call(prob::IntegralProblem, ::HCubatureJL, sensealg, lb, ub,
180183

181184
if isinplace(prob)
182185
dx = zeros(prob.nout)
183-
f = (x) -> (prob.f(dx, x, p); dx)
186+
f = x -> (prob.f(dx, x, prob.p); dx)
184187
else
185-
f = (x) -> prob.f(x, p)
188+
f = x -> prob.f(x, prob.p)
186189
end
187190
@assert prob.batch == 0
188191

@@ -207,16 +210,16 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p,
207210
if prob.batch == 0
208211
if isinplace(prob)
209212
dx = zeros(prob.nout)
210-
f = (x) -> (prob.f(dx, x, p); dx)
213+
f = x -> (f(dx, x, p); dx)
211214
else
212-
f = (x) -> prob.f(x, p)
215+
f = x -> prob.f(x, prob.p)
213216
end
214217
else
215218
if isinplace(prob)
216219
dx = zeros(prob.batch)
217-
f = (x) -> (prob.f(dx, x', p); dx)
220+
f = x -> (f(dx, x', p); dx)
218221
else
219-
f = (x) -> prob.f(x', p)
222+
f = x -> f(x', p)
220223
end
221224
end
222225
val, err, chi = vegas(f, lb, ub, rtol = reltol, atol = abstol,
@@ -232,7 +235,7 @@ function ChainRulesCore.rrule(::typeof(__solvebp), prob, alg, sensealg, lb, ub,
232235
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ
233236
if isinplace(prob)
234237
dx = zeros(prob.nout)
235-
_f = (x) -> prob.f(dx, x, p)
238+
_f = x -> prob.f(dx, x, p)
236239
if sensealg.vjp isa ZygoteVJP
237240
dfdp = function (dx, x, p)
238241
_, back = Zygote.pullback(p) do p
@@ -252,7 +255,7 @@ function ChainRulesCore.rrule(::typeof(__solvebp), prob, alg, sensealg, lb, ub,
252255
error("TODO")
253256
end
254257
else
255-
_f = (x) -> prob.f(x, p)
258+
_f = x -> prob.f(x, p)
256259
if sensealg.vjp isa ZygoteVJP
257260
if prob.batch > 0
258261
dfdp = function (x, p)

test/inf_integral_tests.jl

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Integrals, Distributions, Test
1+
using Integrals, Distributions, Test, StaticArrays
22

33
μ = [0.00, 0.00]
44
Σ = [0.4 0.0; 0.00 0.4]
@@ -35,3 +35,19 @@ prob = IntegralProblem(f, 0, Inf)
3535
sol = solve(prob, HCubatureJL(), reltol = 1e-3, abstol = 1e-3)
3636
@test (pi / 2 - sol.u)^2 < 1e-6
3737
@test_nowarn @inferred Integrals.transformation_if_inf(prob, Val(true))
38+
39+
# Type stability
40+
μ = [0.00, 0.00]
41+
Σ = [0.4 0.0; 0.00 0.4]
42+
d = MvNormal(μ, Σ)
43+
m2 = let d = d
44+
(x, p) -> pdf(d, x)
45+
end
46+
47+
prob = IntegralProblem(m2, SVector(-Inf, -Inf), SVector(Inf, Inf))
48+
@test_nowarn @inferred solve(prob, HCubatureJL(); do_inf_transformation = Val(true))
49+
50+
prob = @test_nowarn @inferred Integrals.transformation_if_inf(prob, Val(true))
51+
@test_nowarn @inferred Integrals.__solvebp_call(prob, HCubatureJL(),
52+
Integrals.ReCallVJP(Integrals.ZygoteVJP()),
53+
prob.lb, prob.ub, prob.p)

0 commit comments

Comments
 (0)