@@ -38,7 +38,7 @@ function scale_x(ub, lb, x)
38
38
end
39
39
40
40
function v_inf (t)
41
- return t . / (1 . - t .^ 2 )
41
+ return map (t -> t / (1 - t^ 2 ), t )
42
42
end
43
43
44
44
function v_semiinf (t, a, upto_inf)
@@ -76,10 +76,13 @@ function transform_inf(t, p, f, lb, ub)
76
76
semilw = lbb .& .! ubb
77
77
78
78
function v (t)
79
- return t .* _none + v_inf (t) .* _inf + v_semiinf (t, lb, 1 ) .* semiup +
80
- v_semiinf (t, ub, 0 ) .* semilw
79
+ t .* _none + v_inf (t) .* _inf + v_semiinf (t, lb, 1 ) .* semiup +
80
+ v_semiinf (t, ub, 0 ) .* semilw
81
81
end
82
- jac = Zygote. @ignore ForwardDiff. jacobian (x -> v (x), t)
82
+ jac = ChainRulesCore. @ignore_derivatives ForwardDiff. jacobian (x -> v (x),
83
+ t |> Vector):: Matrix {
84
+ eltype (t)
85
+ }
83
86
j = det (jac)
84
87
f (v (t), p) * (j)
85
88
end
@@ -180,9 +183,9 @@ function __solvebp_call(prob::IntegralProblem, ::HCubatureJL, sensealg, lb, ub,
180
183
181
184
if isinplace (prob)
182
185
dx = zeros (prob. nout)
183
- f = (x) -> (prob. f (dx, x, p); dx)
186
+ f = x -> (prob. f (dx, x, prob . p); dx)
184
187
else
185
- f = (x) -> prob. f (x, p)
188
+ f = x -> prob. f (x, prob . p)
186
189
end
187
190
@assert prob. batch == 0
188
191
@@ -207,16 +210,16 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p,
207
210
if prob. batch == 0
208
211
if isinplace (prob)
209
212
dx = zeros (prob. nout)
210
- f = (x) -> (prob . f (dx, x, p); dx)
213
+ f = x -> (f (dx, x, p); dx)
211
214
else
212
- f = (x) -> prob. f (x, p)
215
+ f = x -> prob. f (x, prob . p)
213
216
end
214
217
else
215
218
if isinplace (prob)
216
219
dx = zeros (prob. batch)
217
- f = (x) -> (prob . f (dx, x' , p); dx)
220
+ f = x -> (f (dx, x' , p); dx)
218
221
else
219
- f = (x) -> prob . f (x' , p)
222
+ f = x -> f (x' , p)
220
223
end
221
224
end
222
225
val, err, chi = vegas (f, lb, ub, rtol = reltol, atol = abstol,
@@ -232,7 +235,7 @@ function ChainRulesCore.rrule(::typeof(__solvebp), prob, alg, sensealg, lb, ub,
232
235
y = typeof (Δ) <: Array{<:Number, 0} ? Δ[1 ] : Δ
233
236
if isinplace (prob)
234
237
dx = zeros (prob. nout)
235
- _f = (x) -> prob. f (dx, x, p)
238
+ _f = x -> prob. f (dx, x, p)
236
239
if sensealg. vjp isa ZygoteVJP
237
240
dfdp = function (dx, x, p)
238
241
_, back = Zygote. pullback (p) do p
@@ -252,7 +255,7 @@ function ChainRulesCore.rrule(::typeof(__solvebp), prob, alg, sensealg, lb, ub,
252
255
error (" TODO" )
253
256
end
254
257
else
255
- _f = (x) -> prob. f (x, p)
258
+ _f = x -> prob. f (x, p)
256
259
if sensealg. vjp isa ZygoteVJP
257
260
if prob. batch > 0
258
261
dfdp = function (x, p)
0 commit comments