Skip to content

Commit 7246fba

Browse files
Merge pull request SciML#3356 from SciML/fix-wrapfun-iip-inference
Fix wrapfun_iip type inference with FunctionWrappersWrappers v1.0+
2 parents 446cc1a + fb67fdd commit 7246fba

File tree

6 files changed

+37
-48
lines changed

6 files changed

+37
-48
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ FastClosures = "0.3.2"
124124
FillArrays = "1.13"
125125
FiniteDiff = "2.27"
126126
ForwardDiff = "0.10.38, 1"
127-
FunctionWrappersWrappers = "0.1.3, 1"
127+
FunctionWrappersWrappers = "1"
128128
InteractiveUtils = "1.9"
129129
JLArrays = "0.2, 0.3"
130130
LineSearches = "7.4"

lib/DiffEqBase/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ FastPower = "1.1"
8787
FlexUnits = "0.4"
8888
ForwardDiff = "0.10, 1"
8989
FunctionWrappers = "1.0"
90-
FunctionWrappersWrappers = "0.1, 1"
90+
FunctionWrappersWrappers = "1"
9191
GTPSA = "1.4"
9292
GeneralizedGenerated = "0.3"
9393
InteractiveUtils = "1.9"

lib/DiffEqBase/ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ function wrapfun_oop(ff, inputs::Tuple = ())
6161
)
6262
end
6363

64+
# Construct FunctionWrappersWrapper bypassing the convenience constructor.
65+
# The convenience constructor's `map` doesn't infer when the callable has many
66+
# type parameters (e.g. ODEFunction with 20+), because the FunctionWrapper
67+
# constructor in the separately-precompiled FunctionWrappers package can't be
68+
# traced through for complex types. Using Type{A} dispatch binds the arglist
69+
# types as type parameters, making FW{Nothing, A}(vff) fully inferrable.
70+
function _make_fww(
71+
@nospecialize(vff),
72+
::Type{A1}, ::Type{A2}, ::Type{A3}, ::Type{A4},
73+
) where {A1, A2, A3, A4}
74+
FW = FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper
75+
fwt = (
76+
FW{Nothing, A1}(vff), FW{Nothing, A2}(vff),
77+
FW{Nothing, A3}(vff), FW{Nothing, A4}(vff),
78+
)
79+
cs = FunctionWrappersWrappers.SingleCacheStorage()
80+
return FunctionWrappersWrappers.FunctionWrappersWrapper{
81+
typeof(fwt), FunctionWrappersWrappers.AllowNonIsBits, typeof(cs),
82+
}(fwt, cs)
83+
end
84+
6485
function wrapfun_iip(
6586
ff,
6687
inputs::Tuple{T1, T2, T3, T4}
@@ -71,16 +92,13 @@ function wrapfun_iip(
7192
dualT2 = ArrayInterface.promote_eltype(T2, dualT)
7293
dualT4 = dualgen(promote_type(T, T4))
7394

74-
iip_arglists = (
95+
return _make_fww(
96+
Void(ff),
7597
Tuple{T1, T2, T3, T4},
7698
Tuple{dualT1, dualT2, T3, T4},
7799
Tuple{dualT1, T2, T3, dualT4},
78-
Tuple{dualT1, dualT2, T3, dualT4},
100+
Tuple{dualT1, dualT2, T3, dualT4}
79101
)
80-
81-
iip_returnlists = ntuple(x -> Nothing, 4)
82-
83-
return FunctionWrappersWrappers.FunctionWrappersWrapper(Void(ff), iip_arglists, iip_returnlists)
84102
end
85103

86104
# 3-arg version: compile FunctionWrapper variants with the specified chunk size.
@@ -103,45 +121,15 @@ function wrapfun_iip(
103121
dualT1_time = ArrayInterface.promote_eltype(T1, dualT_time)
104122
dualT4_time = dualgen(promote_type(T, T4))
105123

106-
iip_arglists = (
107-
Tuple{T1, T2, T3, T4}, # plain
108-
Tuple{dualT1_jac, dualT2_jac, T3, T4}, # Jacobian (u dual, chunk=CS)
109-
Tuple{dualT1_time, T2, T3, dualT4_time}, # time derivative (chunk=1)
110-
Tuple{dualT1_jac, dualT2_jac, T3, dualT4_time}, # both
124+
return _make_fww(
125+
Void(ff),
126+
Tuple{T1, T2, T3, T4},
127+
Tuple{dualT1_jac, dualT2_jac, T3, T4},
128+
Tuple{dualT1_time, T2, T3, dualT4_time},
129+
Tuple{dualT1_jac, dualT2_jac, T3, dualT4_time}
111130
)
112-
113-
iip_returnlists = ntuple(x -> Nothing, 4)
114-
115-
return FunctionWrappersWrappers.FunctionWrappersWrapper(Void(ff), iip_arglists, iip_returnlists)
116131
end
117132

118-
const iip_arglists_default = (
119-
Tuple{
120-
Vector{Float64}, Vector{Float64}, Vector{Float64},
121-
Float64,
122-
},
123-
Tuple{
124-
Vector{Float64}, Vector{Float64},
125-
SciMLBase.NullParameters,
126-
Float64,
127-
},
128-
Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT},
129-
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, dualT},
130-
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64},
131-
Tuple{
132-
Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters,
133-
Float64,
134-
},
135-
Tuple{
136-
Vector{dualT}, Vector{Float64},
137-
SciMLBase.NullParameters, dualT,
138-
},
139-
)
140-
const iip_returnlists_default = ntuple(x -> Nothing, length(iip_arglists_default))
141-
142-
function wrapfun_iip(@nospecialize(ff))
143-
return FunctionWrappersWrappers.FunctionWrappersWrapper(Void(ff), iip_arglists_default, iip_returnlists_default)
144-
end
145133

146134
function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, tspan, prob, kwargs)
147135
if (haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])) ||

lib/DiffEqBase/test/downstream/inference.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ sol = solve(prob, Tsit5(), save_idxs = 1)
1717

1818
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz, u0, tspan)
1919

20-
@inferred SciMLBase.wrapfun_iip(prob.f)
20+
@inferred SciMLBase.wrapfun_iip(prob.f, (u0, u0, Float64[], tspan[1]))
2121
@inferred remake(prob, u0 = [1.0; 0.0; 0.0])
2222
@inferred remake(prob, u0 = Float32[1.0; 0.0; 0.0])
2323
@test_broken @inferred(solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0])) ==
@@ -29,7 +29,8 @@ prob = ODEProblem(lorenz, Float32[1.0; 0.0; 0.0], tspan)
2929
solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0])
3030
remake(prob, u0 = [1.0; 0.0; 0.0])
3131

32-
@inferred SciMLBase.wrapfun_iip(prob.f)
32+
u0_32 = Float32[1.0; 0.0; 0.0]
33+
@inferred SciMLBase.wrapfun_iip(prob.f, (u0_32, u0_32, Float32[], tspan[1]))
3334
@test_broken @inferred(
3435
ODEFunction{
3536
isinplace(prob), SciMLBase.FunctionWrapperSpecialize,

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Accessors = "0.1.36"
5757
ConcreteStructs = "0.2"
5858
StaticArraysCore = "1.4.3"
5959
SciMLStructures = "1.7"
60-
FunctionWrappersWrappers = "0.1, 1"
60+
FunctionWrappersWrappers = "1"
6161
FastBroadcast = "1.3"
6262
Random = "<0.0.1, 1"
6363
DiffEqDevTools = "2.44.4"

lib/OrdinaryDiffEqDifferentiation/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ DiffEqBase = {path = "../DiffEqBase"}
1212
[compat]
1313
Pkg = "1"
1414
ForwardDiff = "0.10.38, 1"
15-
FunctionWrappersWrappers = "0.1, 1"
15+
FunctionWrappersWrappers = "1"
1616
FastBroadcast = "1.3"
1717
Random = "<0.0.1, 1"
1818
DiffEqDevTools = "2.44.4"

0 commit comments

Comments
 (0)