Skip to content

Commit 4603cc1

Browse files
authored
Using a strong zero in Dual(0.0, 1)^0 to avoid NaN (#84)
* test and fix for nan epsilon of dual^UInt64(0) * "0.6.3" -> "0.6.4" * removed NaNs, codified behaviour with Ints * improve testing, remove type instability * handle Dual(Integer(1), n)^Integer * version semver bump 0.6.4 -> 0.6.5
1 parent a1128cf commit 4603cc1

File tree

3 files changed

+69
-14
lines changed

3 files changed

+69
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DualNumbers"
22
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
3-
version = "0.6.4"
3+
version = "0.6.5"
44

55
[deps]
66
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"

src/dual.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -245,29 +245,40 @@ Base.:/(z::Number, w::Dual) = Dual(z/value(w), -z*epsilon(w)/value(w)^2)
245245
Base.:/(z::Dual, x::Number) = Dual(value(z)/x, epsilon(z)/x)
246246

247247
for f in [:(Base.:^), :(NaNMath.pow)]
248-
@eval function ($f)(z::Dual, w::Dual)
249-
if epsilon(w) == 0.0
250-
return $f(z, value(w))
251-
end
248+
@eval function ($f)(z::Dual{T1}, w::Dual{T2}) where {T1, T2}
249+
T = promote_type(T1, T2) # for type stability in ? : statements
252250
val = $f(value(z), value(w))
253251

254-
du = epsilon(z) * value(w) * $f(value(z), value(w) - 1) +
255-
epsilon(w) * $f(value(z), value(w)) * log(value(z))
252+
ezvw = epsilon(z) * value(w) # for using in ? : statement
253+
du1 = iszero(ezvw) ? zero(T) : ezvw * $f(value(z), value(w) - 1)
254+
ew = epsilon(w) # for using in ? : statement
255+
# the float is for type stability because log promotes to floats
256+
du2 = iszero(ew) ? zero(float(T)) : ew * val * log(value(z))
257+
du = du1 + du2
256258

257259
Dual(val, du)
258260
end
259261
end
260262

261263
Base.mod(z::Dual, n::Number) = Dual(mod(value(z), n), epsilon(z))
262264

263-
# these two definitions are needed to fix ambiguity warnings
264-
Base.:^(z::Dual, n::Unsigned) = z^Signed(n)
265-
Base.:^(z::Dual, n::Integer) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
266-
Base.:^(z::Dual, n::Rational) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
265+
# introduce a boolean !iszero(n) for hard zero behaviour to combat NaNs
266+
function pow(z::Dual, n::AbstractFloat)
267+
return Dual(value(z)^n, !iszero(n) * (epsilon(z) * n * value(z)^(n - 1)))
268+
end
269+
function pow(z::Dual{T}, n::Integer) where T
270+
iszero(n) && return Dual(one(T), zero(T)) # avoid DomainError Int^(negative Int)
271+
isone(z) && return Dual(one(T), epsilon(z) * n)
272+
return Dual(value(z)^n, epsilon(z) * n * value(z)^(n - 1))
273+
end
274+
# these first two definitions are needed to fix ambiguity warnings
275+
for T1 ∈ (:Integer, :Rational, :Number)
276+
@eval Base.:^(z::Dual{T}, n::$T1) where T = pow(z, n)
277+
end
278+
267279

268-
Base.:^(z::Dual, n::Number) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
269-
NaNMath.pow(z::Dual, n::Number) = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
270-
NaNMath.pow(z::Number, w::Dual) = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))
280+
NaNMath.pow(z::Dual{T}, n::Number) where T = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
281+
NaNMath.pow(z::Number, w::Dual{T}) where T = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))
271282

272283
Base.inv(z::Dual) = dual(inv(value(z)),-epsilon(z)/value(z)^2)
273284

test/automatic_differentiation_test.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,52 @@ y = x^3.0
1515
@test value(y) ≈ 2.0^3
1616
@test epsilon(y) ≈ 3.0*2^2
1717

18+
# taking care with divides by zero where there shouldn't be any on paper
19+
for (y, n) ∈ Iterators.product((float(x), Dual(0.0, 1)), (0, 0.0))
20+
z = y^n
21+
@test value(z) == 1
22+
@test !isnan(epsilon(z))
23+
@test epsilon(z) == 0
24+
end
25+
26+
# acting on floats works as expected
27+
for (y, n) ∈ ((float(x), Dual(0.0, 1)), -1:1)
28+
@test float(y)^n == float(y)^float(n)
29+
end
30+
31+
@test !isnan(epsilon(Dual(0, 1)^1))
32+
@test Dual(0, 1)^1 == Dual(0, 1)
33+
34+
# power_by_squaring error for integers
35+
# needs to be wrapped to make n a literal
36+
powwrap(z, n, epspart=0) = Dual(z, epspart)^n
37+
@test_throws DomainError powwrap(0, -1)
38+
@test_throws DomainError powwrap(2, -1)
39+
@test_throws DomainError powwrap(123, -1) # etc
40+
# these ones don't DomainError
41+
@test powwrap(0, 0, 0) == Dual(1, 0) # special case is handled
42+
@test powwrap(0, 0, 1) == Dual(1, 0) # special case is handled
43+
@test powwrap(1, -1) == powwrap(1.0, -1) # special case is handled
44+
@test powwrap(1, -2) == powwrap(1.0, -2) # special case is handled
45+
@test powwrap(1, -123) == powwrap(1.0, -123) # special case is handled
46+
@test powwrap(1, 0) == Dual(1, 1)
47+
@test powwrap(123, 0) == Dual(1, 1)
48+
for i ∈ -3:3
49+
@test powwrap(1, i) == Dual(1, i)
50+
end
51+
52+
# this no longer throws 1/0 DomainError
53+
@test powwrap(0, Dual(0, 1)) == Dual(1, 0)
54+
# this never did DomainError because it starts off with a float
55+
@test 0.0^Dual(0, 1) == Dual(1.0, NaN)
56+
# and Dual^Dual uses a log and is now type stable
57+
# because the log promotes ints to floats for all values
58+
@test typeof(value(powwrap(0, Dual(0, 1)))) == Float64
59+
@test Dual(0, 1)^Dual(0, 1) == Dual(1, 0)
60+
1861
y = Dual(2.0, 1)^UInt64(0)
1962
@test !isnan(epsilon(y))
63+
@test epsilon(y) == 0
2064

2165
y = sin(x)+exp(x)
2266
@test value(y) ≈ sin(2)+exp(2)

0 commit comments

Comments
 (0)