From e6d076a128f7ba1daab214e650cbd81801cb05a0 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 10 Aug 2022 21:07:10 +0800 Subject: [PATCH] Fix possible stack-overflow within broadcast. (#510) * Fix possible stack-overflow within broadcast. * Bump * Fix test on 1.0 --- Project.toml | 2 +- src/gpu_support.jl | 1 + test/issues/runtests.jl | 10 ++++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dc812646..ecf98a5e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Interpolations" uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -version = "0.14.3" +version = "0.14.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/gpu_support.jl b/src/gpu_support.jl index 4990f5bd..d9f7f7ea 100644 --- a/src/gpu_support.jl +++ b/src/gpu_support.jl @@ -56,6 +56,7 @@ end This function returns the type of the root cofficients array of an `AbstractInterpolation`. Some array wrappers, like `OffsetArray`, should be skipped. """ +root_storage_type(::Type{T}) where {T<:AbstractInterpolation} = Array{eltype(T),ndims(T)} # fallback to `Array` by default. root_storage_type(::Type{T}) where {T<:Extrapolation} = root_storage_type(fieldtype(T, 1)) root_storage_type(::Type{T}) where {T<:ScaledInterpolation} = root_storage_type(fieldtype(T, 1)) root_storage_type(::Type{T}) where {T<:BSplineInterpolation} = root_storage_type(fieldtype(T, 1)) diff --git a/test/issues/runtests.jl b/test/issues/runtests.jl index 209adb72..10223a6b 100644 --- a/test/issues/runtests.jl +++ b/test/issues/runtests.jl @@ -167,4 +167,14 @@ using Interpolations, Test, ForwardDiff @test_throws ErrorException Interpolations.symsize(Val(2)) @test_throws ErrorException Interpolations.symsize(Val(33)) end + @testset "issue 509" begin + @eval struct CPUITP{T,N} <: AbstractInterpolation{T,N,NTuple{N,NoInterp}} end + @test Broadcast.BroadcastStyle(CPUITP{Int,2}) == Broadcast.BroadcastStyle(Matrix{Int}) + # example in #509 + percentile_values = [0.0, 0.01, 0.1, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 99.9, 99.99, 100.0] + y = sort(randn(length(percentile_values))) + itp_cdf = extrapolate(interpolate(y, percentile_values, SteffenMonotonicInterpolation()), Flat()) + t = -3.0:0.01:3.0 + @test itp_cdf.(t) isa Vector{Float64} + end end