Skip to content

Commit 712aabd

Browse files
committed
Release constraint on real
1 parent 9c9c317 commit 712aabd

File tree

6 files changed

+17
-8
lines changed

6 files changed

+17
-8
lines changed

examples/integration.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using TaylorDiff: TaylorDiff, TaylorScalar, make_seed, flatten, get_coefficient,
22
set_coefficient, append_coefficient
33
using TaylorSeries, TaylorIntegration
44
using ODEProblemLibrary, OrdinaryDiffEq, BenchmarkTools, Symbolics
5+
SymbolicUtils.ENABLE_HASHCONSING[] = true
56

67
# There are two ways to compute the Taylor coefficients of a ODE solution
78
# 1. Using naive repeated differentiation
@@ -109,3 +110,11 @@ function simplify_array_test()
109110
fast_oop, fast_iip = build_jetcoeffs(prob.f, prob.p, Val(P), length(prob.u0))
110111
@btime $fast_oop($prob.u0, $t0)
111112
end
113+
114+
@generated function evaluate_polynomial(t::TaylorScalar{T, P}, z) where {T, P}
115+
ex = :(v[$(P + 1)])
116+
for i in P:-1:1
117+
ex = :(v[$i] + z * $ex)
118+
end
119+
return :($(Expr(:meta, :inline)); v = flatten(t); $ex)
120+
end

src/TaylorDiff.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ module TaylorDiff
44
TaylorDiff.can_taylorize(V::Type)
55
66
Determines whether the type V is allowed as the scalar type in a
7-
Dual. By default, only `<:Real` types are allowed.
7+
Dual. By default, only `<:Number` types are allowed.
88
"""
9-
can_taylorize(::Type{<:Real}) = true
9+
can_taylorize(::Type{<:Number}) = true
1010
can_taylorize(::Type) = false
1111

1212
@noinline function throw_cannot_taylorize(V::Type)

src/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ Base.@propagate_inbounds function Base.setindex!(
6565
end
6666

6767
Base.@propagate_inbounds function Base.setindex!(
68-
a::TaylorArray, s::Real, i::Int...)
68+
a::TaylorArray, s, i::Int...)
6969
value(a)[i...] = s
7070
return a
7171
end

src/chainrules.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function rrule(::typeof(Base.getindex), a::TaylorArray, i::Int...)
5757
end
5858

5959
function rrule(::typeof(*), A::AbstractMatrix{S},
60-
t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
60+
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
6161
project_A = ProjectTo(A)
6262
function gemv_pullback(x̄)
6363
= reinterpret(reshape, T, x̄)
@@ -68,7 +68,7 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
6868
end
6969

7070
function rrule(::typeof(*), A::AbstractMatrix{S},
71-
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
71+
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
7272
project_A = ProjectTo(A)
7373
project_B = ProjectTo(B)
7474
function gemm_pullback(x̄)

src/derivative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export derivative, derivative!, derivatives
22

33
# Added to help Zygote infer types
4-
@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
4+
@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Number, P} = TaylorScalar{P}(x, l)
55
@inline make_seed(x::A, l::A, p) where {A <: AbstractArray} = broadcast(make_seed, x, l, p)
66

77
"""

src/primitive.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ sincos(t::TaylorScalar) = (sin(t), cos(t))
119119
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
120120
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)
121121

122-
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)
122+
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, Complex, RoundingMode)
123123

124124
for op in [:>, :<, :(==), :(>=), :(<=)]
125125
for R in AMBIGUOUS_TYPES
@@ -164,7 +164,7 @@ end
164164
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-1}) = inv(x)
165165
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-2}) = (i = inv(x); i * i)
166166

167-
for R in (Integer, Real)
167+
for R in (Integer, Real, Complex)
168168
@eval @immutable function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
169169
f = flatten(t)
170170
v[0] = f[0]^n

0 commit comments

Comments
 (0)