Skip to content

Commit be3b472

Browse files
Merge pull request #92 from ChrisRackauckas-Claude/interface-check-20251229-143358
Fix BigFloat compatibility in MVector size parameter
2 parents a620841 + 165ac0a commit be3b472

File tree

8 files changed

+96
-9
lines changed

8 files changed

+96
-9
lines changed

src/euler/gpueuler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export GPUSimpleEuler
1717
t = tspan[1]
1818
tf = prob.tspan[2]
1919
ts = tspan[1]:dt:tspan[2]
20-
us = MVector{length(ts), typeof(u0)}(undef)
20+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
2121
us[1] = u0
2222
u = u0
2323

src/rk4/gpurk4.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export GPUSimpleRK4
1717
t = tspan[1]
1818
tf = prob.tspan[2]
1919
ts = tspan[1]:dt:tspan[2]
20-
us = MVector{length(ts), typeof(u0)}(undef)
20+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
2121
us[1] = u0
2222
u = u0
2323
half = convert(eltype(u0), 1 // 2)

src/tsit5/atsit5.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleATsit5;
7171
else
7272
ts = saveat
7373
cur_t = 1
74-
us = MVector{length(ts), typeof(u0)}(undef)
74+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
7575
if prob.tspan[1] == ts[1]
7676
cur_t += 1
7777
us[1] = u0

src/tsit5/gpuatsit5.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ export GPUSimpleTsit5
2626
else
2727
ts = saveat
2828
cur_t = 1
29-
us = MVector{length(ts), typeof(u0)}(undef)
29+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
3030
if prob.tspan[1] == ts[1]
3131
cur_t += 1
3232
us[1] = u0
@@ -125,7 +125,7 @@ SciMLBase.isadaptive(alg::GPUSimpleATsit5) = true
125125
else
126126
ts = saveat
127127
cur_t = 1
128-
us = MVector{length(ts), typeof(u0)}(undef)
128+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
129129
if prob.tspan[1] == ts[1]
130130
cur_t += 1
131131
us[1] = u0

src/verner/gpuvern7.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ export GPUSimpleVern7
2626
else
2727
ts = saveat
2828
cur_t = 1
29-
us = MVector{length(ts), typeof(u0)}(undef)
29+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
3030
if prob.tspan[1] == ts[1]
3131
cur_t += 1
3232
us[1] = u0
@@ -195,7 +195,7 @@ SciMLBase.isadaptive(alg::GPUSimpleAVern7) = true
195195
else
196196
ts = saveat
197197
cur_t = 1
198-
us = MVector{length(ts), typeof(u0)}(undef)
198+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
199199
if prob.tspan[1] == ts[1]
200200
cur_t += 1
201201
us[1] = u0

src/verner/gpuvern9.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ export GPUSimpleVern9
2626
else
2727
ts = saveat
2828
cur_t = 1
29-
us = MVector{length(ts), typeof(u0)}(undef)
29+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
3030
if prob.tspan[1] == ts[1]
3131
cur_t += 1
3232
us[1] = u0
@@ -280,7 +280,7 @@ SciMLBase.isadaptive(alg::GPUSimpleAVern9) = true
280280
else
281281
ts = saveat
282282
cur_t = 1
283-
us = MVector{length(ts), typeof(u0)}(undef)
283+
us = MVector{Int(length(ts)), typeof(u0)}(undef)
284284
if prob.tspan[1] == ts[1]
285285
cur_t += 1
286286
us[1] = u0

test/interface_tests.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using SimpleDiffEq, StaticArrays, Test
2+
3+
# Test interface compatibility with different number types
4+
5+
function decay(u, p, t)
6+
return -u
7+
end
8+
9+
function decay!(du, u, p, t)
10+
du .= -u
11+
return nothing
12+
end
13+
14+
@testset "BigFloat scalar support" begin
15+
u0 = BigFloat(1.0)
16+
tspan = (BigFloat(0.0), BigFloat(1.0))
17+
dt = BigFloat(0.01)
18+
prob = ODEProblem{false}(decay, u0, tspan, nothing)
19+
20+
sol = solve(prob, SimpleEuler(), dt = dt)
21+
@test eltype(sol.u) == BigFloat
22+
23+
sol = solve(prob, SimpleRK4(), dt = dt)
24+
@test eltype(sol.u) == BigFloat
25+
26+
sol = solve(prob, SimpleTsit5(), dt = dt)
27+
@test eltype(sol.u) == BigFloat
28+
29+
sol = solve(prob, SimpleATsit5(), dt = dt)
30+
@test eltype(sol.u) == BigFloat
31+
end
32+
33+
@testset "BigFloat Vector support (OOP)" begin
34+
u0 = BigFloat[1.0, 2.0, 3.0]
35+
tspan = (BigFloat(0.0), BigFloat(1.0))
36+
dt = BigFloat(0.01)
37+
prob = ODEProblem{false}(decay, u0, tspan, nothing)
38+
39+
sol = solve(prob, SimpleEuler(), dt = dt)
40+
@test eltype(sol.u[end]) == BigFloat
41+
42+
sol = solve(prob, SimpleRK4(), dt = dt)
43+
@test eltype(sol.u[end]) == BigFloat
44+
45+
sol = solve(prob, SimpleTsit5(), dt = dt)
46+
@test eltype(sol.u[end]) == BigFloat
47+
end
48+
49+
@testset "BigFloat Vector support (IIP)" begin
50+
u0 = BigFloat[1.0, 2.0, 3.0]
51+
tspan = (BigFloat(0.0), BigFloat(1.0))
52+
dt = BigFloat(0.01)
53+
prob = ODEProblem{true}(decay!, u0, tspan, nothing)
54+
55+
sol = solve(prob, SimpleEuler(), dt = dt)
56+
@test eltype(sol.u[end]) == BigFloat
57+
58+
sol = solve(prob, SimpleRK4(), dt = dt)
59+
@test eltype(sol.u[end]) == BigFloat
60+
61+
sol = solve(prob, SimpleTsit5(), dt = dt)
62+
@test eltype(sol.u[end]) == BigFloat
63+
end
64+
65+
@testset "SVector{BigFloat} support" begin
66+
u0 = SVector{3, BigFloat}(BigFloat(1.0), BigFloat(2.0), BigFloat(3.0))
67+
tspan = (BigFloat(0.0), BigFloat(1.0))
68+
dt = BigFloat(0.01)
69+
prob = ODEProblem{false}(decay, u0, tspan, nothing)
70+
71+
sol = solve(prob, SimpleEuler(), dt = dt)
72+
@test eltype(sol.u[end]) == BigFloat
73+
74+
sol = solve(prob, SimpleRK4(), dt = dt)
75+
@test eltype(sol.u[end]) == BigFloat
76+
77+
sol = solve(prob, SimpleTsit5(), dt = dt)
78+
@test eltype(sol.u[end]) == BigFloat
79+
80+
# GPU-style solvers with SVector{BigFloat}
81+
sol = solve(prob, GPUSimpleTsit5(), dt = dt)
82+
@test eltype(sol.u[end]) == BigFloat
83+
84+
sol = solve(prob, GPUSimpleATsit5(), dt = dt)
85+
@test eltype(sol.u[end]) == BigFloat
86+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ using SimpleDiffEq, SafeTestsets, Test
99
@time @safetestset "SimpleRK4 Tests" include("simplerk4_tests.jl")
1010
@time @safetestset "SimpleEuler Tests" include("simpleeuler_tests.jl")
1111
@time @safetestset "GPU Compatible ODE Tests" include("gpu_ode_regression.jl")
12+
@time @safetestset "Interface Tests" include("interface_tests.jl")
1213
end

0 commit comments

Comments
 (0)