|
1 | | -using SimpleDiffEq, StaticArrays, Test |
| 1 | +using SimpleDiffEq, StaticArrays, JLArrays, Test |
2 | 2 |
|
3 | 3 | # Test interface compatibility with different number types |
4 | 4 |
|
|
84 | 84 | sol = solve(prob, GPUSimpleATsit5(), dt = dt) |
85 | 85 | @test eltype(sol.u[end]) == BigFloat |
86 | 86 | end |
| 87 | + |
| 88 | +# Test JLArray support for GPU-like array interface compliance |
| 89 | +# JLArrays provide a GPU-like array that catches interface violations |
| 90 | +# such as improper scalar indexing or type hardcoding |
| 91 | + |
| 92 | +@testset "JLArray support (OOP)" begin |
| 93 | + u0 = JLArray([1.0, 2.0, 3.0]) |
| 94 | + tspan = (0.0, 1.0) |
| 95 | + dt = 0.01 |
| 96 | + prob = ODEProblem{false}(decay, u0, tspan, nothing) |
| 97 | + |
| 98 | + # Fixed step solvers |
| 99 | + sol = solve(prob, SimpleEuler(), dt = dt) |
| 100 | + @test sol.u[end] isa JLArray |
| 101 | + |
| 102 | + sol = solve(prob, SimpleRK4(), dt = dt) |
| 103 | + @test sol.u[end] isa JLArray |
| 104 | + |
| 105 | + sol = solve(prob, SimpleTsit5(), dt = dt) |
| 106 | + @test sol.u[end] isa JLArray |
| 107 | + |
| 108 | + # Adaptive solver |
| 109 | + sol = solve(prob, SimpleATsit5(), dt = dt) |
| 110 | + @test sol.u[end] isa JLArray |
| 111 | + |
| 112 | + # GPU-optimized solvers |
| 113 | + sol = solve(prob, GPUSimpleTsit5(), dt = dt) |
| 114 | + @test sol.u[end] isa JLArray |
| 115 | + |
| 116 | + sol = solve(prob, GPUSimpleATsit5(), dt = dt) |
| 117 | + @test sol.u[end] isa JLArray |
| 118 | +end |
| 119 | + |
| 120 | +@testset "JLArray scalar support (OOP)" begin |
| 121 | + # Test with scalar wrapped in JLArray-compatible manner |
| 122 | + # GPU solvers should handle scalar problems correctly |
| 123 | + u0 = 1.0 |
| 124 | + tspan = (0.0, 1.0) |
| 125 | + dt = 0.01 |
| 126 | + prob = ODEProblem{false}(decay, u0, tspan, nothing) |
| 127 | + |
| 128 | + sol = solve(prob, GPUSimpleTsit5(), dt = dt) |
| 129 | + @test eltype(sol.u) == Float64 |
| 130 | + |
| 131 | + sol = solve(prob, GPUSimpleATsit5(), dt = dt) |
| 132 | + @test eltype(sol.u) == Float64 |
| 133 | +end |
0 commit comments