Skip to content

Commit 336ac96

Browse files
Merge pull request #94 from ChrisRackauckas-Claude/interface-check-20260101-030232
Add JLArray tests for GPU-like array interface compliance
2 parents e9d7815 + bc9e9ea commit 336ac96

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,20 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1414

1515
[compat]
1616
DiffEqBase = "6.122"
17+
JLArrays = "0.1, 0.2, 0.3"
1718
MuladdMacro = "0.2"
19+
OrdinaryDiffEq = "6"
1820
Parameters = "0.12"
1921
RecursiveArrayTools = "2, 3"
2022
Reexport = "0.2, 1.0"
2123
StaticArrays = "0.10, 0.11, 0.12, 1.0"
2224
julia = "1.6"
2325

2426
[extras]
27+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2528
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2629
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2730
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2831

2932
[targets]
30-
test = ["OrdinaryDiffEq", "SafeTestsets", "Test"]
33+
test = ["JLArrays", "OrdinaryDiffEq", "SafeTestsets", "Test"]

test/interface_tests.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SimpleDiffEq, StaticArrays, Test
1+
using SimpleDiffEq, StaticArrays, JLArrays, Test
22

33
# Test interface compatibility with different number types
44

@@ -84,3 +84,50 @@ end
8484
sol = solve(prob, GPUSimpleATsit5(), dt = dt)
8585
@test eltype(sol.u[end]) == BigFloat
8686
end
87+
88+
# Test JLArray support for GPU-like array interface compliance
89+
# JLArrays provide a GPU-like array that catches interface violations
90+
# such as improper scalar indexing or type hardcoding
91+
92+
@testset "JLArray support (OOP)" begin
93+
u0 = JLArray([1.0, 2.0, 3.0])
94+
tspan = (0.0, 1.0)
95+
dt = 0.01
96+
prob = ODEProblem{false}(decay, u0, tspan, nothing)
97+
98+
# Fixed step solvers
99+
sol = solve(prob, SimpleEuler(), dt = dt)
100+
@test sol.u[end] isa JLArray
101+
102+
sol = solve(prob, SimpleRK4(), dt = dt)
103+
@test sol.u[end] isa JLArray
104+
105+
sol = solve(prob, SimpleTsit5(), dt = dt)
106+
@test sol.u[end] isa JLArray
107+
108+
# Adaptive solver
109+
sol = solve(prob, SimpleATsit5(), dt = dt)
110+
@test sol.u[end] isa JLArray
111+
112+
# GPU-optimized solvers
113+
sol = solve(prob, GPUSimpleTsit5(), dt = dt)
114+
@test sol.u[end] isa JLArray
115+
116+
sol = solve(prob, GPUSimpleATsit5(), dt = dt)
117+
@test sol.u[end] isa JLArray
118+
end
119+
120+
@testset "JLArray scalar support (OOP)" begin
121+
# Test with scalar wrapped in JLArray-compatible manner
122+
# GPU solvers should handle scalar problems correctly
123+
u0 = 1.0
124+
tspan = (0.0, 1.0)
125+
dt = 0.01
126+
prob = ODEProblem{false}(decay, u0, tspan, nothing)
127+
128+
sol = solve(prob, GPUSimpleTsit5(), dt = dt)
129+
@test eltype(sol.u) == Float64
130+
131+
sol = solve(prob, GPUSimpleATsit5(), dt = dt)
132+
@test eltype(sol.u) == Float64
133+
end

0 commit comments

Comments
 (0)