Skip to content

Commit 8ed70d8

Browse files
authored
Sparse vector polynomial (#536)
* WIP, add SparseVectorPolynomial * add SparseVector container * SparseVectorPolynomial needs zero(T) * NaN poisons, Inf propogates. No good rational, but ...
1 parent a368302 commit 8ed70d8

9 files changed

+364
-28
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@ name = "Polynomials"
22
uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
33
license = "MIT"
44
author = "JuliaMath"
5-
version = "4.0.2"
5+
version = "4.0.3"
66

77
[deps]
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
1010
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1111
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
12+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1213

1314
[weakdeps]
1415
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/Polynomials.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ module Polynomials
44
using LinearAlgebra
55
import Base: evalpoly
66
using Setfield
7+
using SparseArrays
78

89
include("abstract.jl")
910
include("show.jl")
@@ -20,8 +21,9 @@ include("polynomial-container-types/mutable-dense-view-polynomial.jl")
2021
include("polynomial-container-types/mutable-dense-laurent-polynomial.jl")
2122
include("polynomial-container-types/immutable-dense-polynomial.jl")
2223
include("polynomial-container-types/mutable-sparse-polynomial.jl")
24+
include("polynomial-container-types/mutable-sparse-vector-polynomial.jl")
2325
const PolynomialContainerTypes = (:MutableDensePolynomial, :MutableDenseViewPolynomial, :ImmutableDensePolynomial,
24-
:MutableDenseLaurentPolynomial, :MutableSparsePolynomial) # useful for some purposes
26+
:MutableDenseLaurentPolynomial, :MutableSparsePolynomial, :MutableSparseVectorPolynomial) # useful for some purposes
2527
const ZeroBasedDensePolynomialContainerTypes = (:MutableDensePolynomial, :MutableDenseViewPolynomial, :ImmutableDensePolynomial)
2628

2729
include("polynomials/standard-basis/standard-basis.jl")
@@ -30,6 +32,7 @@ include("polynomials/standard-basis/pn-polynomial.jl")
3032
include("polynomials/standard-basis/laurent-polynomial.jl")
3133
include("polynomials/standard-basis/immutable-polynomial.jl")
3234
include("polynomials/standard-basis/sparse-polynomial.jl")
35+
include("polynomials/standard-basis/sparse-vector-polynomial.jl")
3336

3437
include("polynomials/ngcd.jl")
3538
include("polynomials/multroot.jl")

src/common.jl

+11-9
Original file line numberDiff line numberDiff line change
@@ -311,15 +311,17 @@ end
311311
312312
In-place version of [`truncate`](@ref)
313313
"""
314-
function truncate!(p::AbstractPolynomial{T};
314+
truncate!(p::AbstractPolynomial; kwargs...) = _truncate!(p; kwargs...)
315+
316+
function _truncate!(p::AbstractPolynomial{T};
315317
rtol::Real = Base.rtoldefault(real(T)),
316318
atol::Real = 0,) where {T}
317-
truncate!(p.coeffs, rtol=rtol, atol=atol)
319+
_truncate!(p.coeffs, rtol=rtol, atol=atol)
318320
chop!(p, rtol = rtol, atol = atol)
319321
end
320322

321-
## truncate! underlying storage type
322-
function truncate!(ps::Vector{T};
323+
## _truncate! underlying storage type
324+
function _truncate!(ps::Vector{T};
323325
rtol::Real = Base.rtoldefault(real(T)),
324326
atol::Real = 0,) where {T}
325327
max_coeff = norm(ps, Inf)
@@ -332,7 +334,7 @@ function truncate!(ps::Vector{T};
332334
nothing
333335
end
334336

335-
function truncate!(ps::Dict{S,T};
337+
function _truncate!(ps::Dict{S,T};
336338
rtol::Real = Base.rtoldefault(real(T)),
337339
atol::Real = 0,) where {S,T}
338340

@@ -348,7 +350,7 @@ function truncate!(ps::Dict{S,T};
348350
nothing
349351
end
350352

351-
truncate!(ps::NTuple; kwargs...) = throw(ArgumentError("`truncate!` not defined."))
353+
_truncate!(ps::NTuple; kwargs...) = throw(ArgumentError("`truncate!` not defined for tuples."))
352354

353355
# _truncate(ps::NTuple{0}; kwargs...) = ps
354356
# function _truncate(ps::NTuple{N,T};
@@ -370,7 +372,7 @@ Rounds off coefficients close to zero, as determined by `rtol` and `atol`, and t
370372
function Base.truncate(p::AbstractPolynomial{T};
371373
rtol::Real = Base.rtoldefault(real(T)),
372374
atol::Real = 0,) where {T}
373-
truncate!(deepcopy(p), rtol = rtol, atol = atol)
375+
_truncate!(deepcopy(p), rtol = rtol, atol = atol)
374376
end
375377

376378
"""
@@ -455,8 +457,8 @@ end
455457

456458

457459
# for generic usage, as immutable types are not mutable
458-
chop!!(p::AbstractPolynomial; kwargs...) = (p = chop!(p); p)
459-
truncate!!(p::AbstractPolynomial; kwargs...) = truncate!(p)
460+
chop!!(p::AbstractPolynomial; kwargs...) = (p = chop!(p; kwargs...); p)
461+
truncate!!(p::AbstractPolynomial; kwargs...) = _truncate!(p; kwargs...)
460462

461463
## --------------------------------------------------
462464

src/polynomial-container-types/immutable-dense-polynomial.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ end
128128
chop!(p::ImmutableDensePolynomial; kwargs...) = chop(p; kwargs...) ## misnamed, should be chop!!
129129
chop!!(p::ImmutableDensePolynomial; kwargs...) = chop(p; kwargs...)
130130

131-
function truncate!(p::ImmutableDensePolynomial{B,T,X,N};
131+
function _truncate!(p::ImmutableDensePolynomial{B,T,X,N};
132132
rtol::Real = Base.rtoldefault(real(T)),
133133
atol::Real = 0) where {B,T,X,N}
134134

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
"""
2+
MutableSparseVectorPolynomial{B,T,X}
3+
4+
This polynomial type uses an `SparseVector{T,Int}` to store the coefficients of a polynomial relative to the basis `B` with indeterminate `X`.
5+
The type `T` should have `zero(T)` defined.
6+
7+
8+
"""
9+
struct MutableSparseVectorPolynomial{B,T,X} <: AbstractUnivariatePolynomial{B, T,X}
10+
coeffs::SparseVector{T, Int}
11+
function MutableSparseVectorPolynomial{B,T,X}(cs::SparseVector{S,Int}, order::Int=0) where {B,T,S,X}
12+
new{B,T,Symbol(X)}(cs)
13+
end
14+
end
15+
16+
MutableSparseVectorPolynomial{B,T,X}(check::Val{:false}, coeffs::SparseVector{Int,S}) where {B,T,S,X} =
17+
MutableSparseVectorPolynomial{B,T,X}(coeffs)
18+
MutableSparseVectorPolynomial{B,T,X}(checked::Val{:true}, coeffs::SparseVector{Int,T}) where {B,T,X<:Symbol} =
19+
MutableSparseVectorPolynomial{B,T,X}(coeffs)
20+
21+
# ---
22+
function MutableSparseVectorPolynomial{B,T}(coeffs::SparseVector{S,Int}, var::SymbolLike=Var(:x)) where {B,T,S}
23+
MutableSparseVectorPolynomial{B,T,Symbol(var)}(coeffs)
24+
end
25+
26+
function MutableSparseVectorPolynomial{B}(cs::SparseVector{T,Int}, var::SymbolLike=Var(:x)) where {B,T}
27+
MutableSparseVectorPolynomial{B,T,Symbol(var)}(cs)
28+
end
29+
30+
# From a Dictionary
31+
function MutableSparseVectorPolynomial{B,X}(cs::AbstractDict{Int, T}) where {B,T,X}
32+
N = maximum(keys(cs)) + 1
33+
v = SparseVector(N, 1 .+ keys(cs), collect(values(cs)))
34+
MutableSparseVectorPolynomial{B,T,X}(v)
35+
end
36+
37+
function MutableSparseVectorPolynomial{B}(cs::AbstractDict{Int, T}, var::SymbolLike=Var(:x)) where {B,T}
38+
MutableSparseVectorPolynomial{B,Symbol(var)}(cs)
39+
end
40+
41+
42+
# abstract vector has order/symbol
43+
function MutableSparseVectorPolynomial{B,T,X}(coeffs::AbstractVector{S}, order::Int=0) where {B,T,S,X}
44+
if Base.has_offset_axes(coeffs)
45+
@warn "ignoring the axis offset of the coefficient vector"
46+
coeffs = parent(coeffs)
47+
end
48+
49+
MutableSparseVectorPolynomial{B,T,X}(convert(SparseVector, coeffs))
50+
end
51+
52+
53+
# # cs iterable of pairs; ensuring tight value of T
54+
# function MutableSparseVectorPolynomial{B}(cs::Tuple, var::SymbolLike=:x) where {B}
55+
# isempty(cs) && throw(ArgumentError("No type attached"))
56+
# X = Var(var)
57+
# if length(cs) == 1
58+
# c = only(cs)
59+
# d = Dict(first(c) => last(c))
60+
# T = eltype(last(c))
61+
# return MutableSparseVectorPolynomial{B,T,X}(d)
62+
# else
63+
# c₁, c... = cs
64+
# T = typeof(last(c₁))
65+
# for b ∈ c
66+
# T = promote_type(T, typeof(b))
67+
# end
68+
# ks = 0:length(cs)-1
69+
# vs = cs
70+
# d = Dict{Int,T}(Base.Generator(=>, ks, vs))
71+
# return MutableSparseVectorPolynomial{B,T,X}(d)
72+
# end
73+
# end
74+
75+
constructorof(::Type{<:MutableSparseVectorPolynomial{B}}) where {B <: AbstractBasis} = MutableSparseVectorPolynomial{B}
76+
@poly_register MutableSparseVectorPolynomial
77+
78+
function Base.map(fn, p::P, args...) where {B,T,X, P<:MutableSparseVectorPolynomial{B,T,X}}
79+
xs = map(fn, p.coeffs)
80+
R = eltype(xs)
81+
return MutableSparseVectorPolynomial{B, R, X}(xs)
82+
end
83+
84+
function Base.map!(fn, q::Q, p::P, args...) where {B,T,X, P<:MutableSparseVectorPolynomial{B,T,X},S,Q<:MutableSparseVectorPolynomial{B,S,X}}
85+
map!(fn, p.coeffs, p.coeffs)
86+
nothing
87+
end
88+
89+
## ---
90+
Base.collect(p::MutableSparseVectorPolynomial) = collect(p.coeffs)
91+
Base.collect(::Type{T}, p::MutableSparseVectorPolynomial) where {T} = collect(T, p.coeffs)
92+
minimumexponent(::Type{<:MutableSparseVectorPolynomial}) = 0
93+
94+
Base.length(p::MutableSparseVectorPolynomial) = length(p.coeffs)
95+
96+
function degree(p::MutableSparseVectorPolynomial)
97+
idx = findall(!iszero, p.coeffs)
98+
isempty(idx) && return -1
99+
n = maximum(idx)
100+
n - 1
101+
end
102+
103+
Base.copy(p::MutableSparseVectorPolynomial{B,T,X}) where {B,T,X} = MutableSparseVectorPolynomial{B,T,X}(copy(p.coeffs))
104+
105+
function Base.convert(::Type{MutableSparseVectorPolynomial{B,T,X}}, p::MutableSparseVectorPolynomial{B,S,X}) where {B,T,S,X}
106+
cs = convert(SparseVector{T,Int}, p.coeffs)
107+
MutableSparseVectorPolynomial{B,T,X}(cs)
108+
end
109+
110+
function Base.:(==)(p1::P, p2::P) where {P <: MutableSparseVectorPolynomial}
111+
iszero(p1) && iszero(p2) && return true
112+
113+
ks1 = findall(!iszero, p1.coeffs)
114+
ks2 = findall(!iszero, p2.coeffs)
115+
length(ks1) == length(ks2) || return false
116+
idx = sortperm(ks1)
117+
for i idx
118+
ks1[i] == ks2[i] || return false
119+
p1.coeffs[ks1[i]] == p2.coeffs[ks2[i]] || return false
120+
end
121+
122+
return true
123+
# # eachindex(p1) == eachindex(p2) || return false
124+
# # coeffs(p1) == coeffs(p2), but non-allocating
125+
# p1val = (p1[i] for i in eachindex(p1))
126+
# p2val = (p2[i] for i in eachindex(p2))
127+
# all(((a,b),) -> a == b, zip(p1val, p2val))
128+
end
129+
130+
# ---
131+
132+
Base.firstindex(p::MutableSparseVectorPolynomial) = 0
133+
function Base.lastindex(p::MutableSparseVectorPolynomial)
134+
isempty(p.coeffs) && return 0
135+
maximum(keys(p.coeffs))
136+
end
137+
138+
function Base.getindex(p::MutableSparseVectorPolynomial{B,T,X}, i::Int) where {B,T,X}
139+
get(p.coeffs, i + 1, zero(T))
140+
end
141+
142+
# errors if extending
143+
function Base.setindex!(p::MutableSparseVectorPolynomial{B,T,X}, value, i::Int) where {B,T,X}
144+
p.coeffs[i+1] = value
145+
end
146+
147+
148+
function Base.pairs(p::MutableSparseVectorPolynomial)
149+
ks, vs = findnz(p.coeffs)
150+
idx = sortperm(ks) # guarantee order here
151+
Base.Generator(=>, ks[idx] .- 1, vs)
152+
end
153+
Base.keys(p::MutableSparseVectorPolynomial) = Base.Generator(first, pairs(p))
154+
Base.values(p::MutableSparseVectorPolynomial) = Base.Generator(last, pairs(p))
155+
156+
basis(P::Type{<:MutableSparseVectorPolynomial{B, T, X}}, i::Int) where {B,T,X} = P(SparseVector(1+i, [i+1], [1]))
157+
158+
# return coeffs as a vector
159+
function coeffs(p::MutableSparseVectorPolynomial{B,T}) where {B,T}
160+
d = degree(p)
161+
ps = p.coeffs
162+
[ps[i] for i 1:(d+1)]
163+
end
164+
165+
166+
hasnan(p::MutableSparseVectorPolynomial) = any(hasnan, values(p.coeffs))::Bool
167+
168+
169+
offset(p::MutableSparseVectorPolynomial) = 1
170+
171+
function keys_union(p::MutableSparseVectorPolynomial, q::MutableSparseVectorPolynomial)
172+
# IterTools.distinct(Base.Iterators.flatten((keys(p), keys(q)))) may allocate less
173+
unique(Base.Iterators.flatten((keys(p), keys(q))))
174+
end
175+
176+
177+
178+
## ---
179+
180+
chop_exact_zeros!(d::SparseVector{T, Int}) where {T} = d
181+
182+
183+
function _truncate!(v::SparseVector{T,X};
184+
rtol::Real = Base.rtoldefault(real(T)),
185+
atol::Real = 0) where {T,X}
186+
isempty(v) && return v
187+
δ = something(rtol,0)
188+
ϵ = something(atol,0)
189+
τ = max(ϵ, norm(values(v),2) * δ)
190+
for (i,pᵢ) pairs(v)
191+
abs(pᵢ) τ && (v[i] = zero(T))
192+
end
193+
v
194+
end
195+
196+
197+
chop!(p::MutableSparseVectorPolynomial; kwargs...) = (chop!(p.coeffs; kwargs...); p)
198+
function chop!(d::SparseVector{T, Int}; atol=nothing, rtol=nothing) where {T}
199+
isempty(d) && return d
200+
δ = something(rtol,0)
201+
ϵ = something(atol,0)
202+
τ = max(ϵ, norm(values(d),2) * δ)
203+
for (i, pᵢ) Base.Iterators.reverse(pairs(d))
204+
abs(pᵢ) τ && break
205+
d[i] = zero(T)
206+
end
207+
d
208+
end
209+
210+
## ---
211+
212+
_zeros(::Type{MutableSparseVectorPolynomial{B,T,X}}, z::S, N) where {B,T,X,S} = zeros(T, N)
213+
214+
Base.zero(::Type{MutableSparseVectorPolynomial{B,T,X}}) where {B,T,X} = MutableSparseVectorPolynomial{B,T,X}(spzeros(T,0))
215+
216+
## ---
217+
218+
function isconstant(p::MutableSparseVectorPolynomial)
219+
degree(p) <= 0
220+
end
221+
222+
Base.:+(p::MutableSparseVectorPolynomial{B,T,X}, q::MutableSparseVectorPolynomial{B,S,X}) where{B,X,T,S} =
223+
_sparse_vector_combine(+, p, q)
224+
Base.:-(p::MutableSparseVectorPolynomial{B,T,X}, q::MutableSparseVectorPolynomial{B,S,X}) where{B,X,T,S} =
225+
_sparse_vector_combine(-, p, q)
226+
227+
# embed into bigger vector
228+
function _embed(v::SparseVector{T, Int}, l) where {T}
229+
l == length(v) && return v
230+
ks,vs = findnz(v)
231+
SparseVector(l, ks, vs)
232+
end
233+
234+
235+
function _sparse_vector_combine(op, p::MutableSparseVectorPolynomial{B,T,X}, q::MutableSparseVectorPolynomial{B,S,X}) where{B,X,T,S}
236+
R = promote_type(T,S)
237+
ps, qs = p.coeffs, q.coeffs
238+
m = max(length(ps), length(qs))
239+
ps′, qs′ = _embed(ps, m), _embed(qs, m)
240+
cs = op(ps′, qs′)
241+
MutableSparseVectorPolynomial{B,R,X}(cs)
242+
end

src/polynomials/standard-basis/sparse-polynomial.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ export SparsePolynomial
4545

4646
_typealias(::Type{P}) where {P<:SparsePolynomial} = "SparsePolynomial"
4747

48-
function evalpoly(x, p::MutableSparsePolynomial)
48+
function evalpoly(x, p::SparsePolynomial)
4949
tot = zero(p[0]*x)
50-
for (i, cᵢ) p.coeffs
50+
for (i, cᵢ) pairs(p)
5151
tot = muladd(cᵢ, x^i, tot)
5252
end
5353
return tot

0 commit comments

Comments
 (0)