Skip to content

Commit 88ad24c

Browse files
simsuracedevmotion
andauthored
Add iszero(x) branches to xlogy and xlog1py (#85)
* Add `iszero(x)` branch to `xlogy` and `xlog1py` * Treat `y` singularity * Update src/rules.jl Co-authored-by: David Widmann <[email protected]> * Update src/rules.jl Co-authored-by: David Widmann <[email protected]> * Remove whitespace * Bump version * Add tests Co-authored-by: David Widmann <[email protected]>
1 parent 2c29d7d commit 88ad24c

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DiffRules"
22
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
3-
version = "1.11.0"
3+
version = "1.11.1"
44

55
[deps]
66
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"

src/rules.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,14 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
248248
@define_diffrule LogExpFunctions.logmxp1(x) = :((1 - $x) / $x)
249249

250250
# binary
251-
@define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y)
251+
@define_diffrule LogExpFunctions.xlogy(x, y) =
252+
:(log($y)),
253+
:(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : z)
252254
@define_diffrule LogExpFunctions.logaddexp(x, y) =
253255
:(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y)))
254256
@define_diffrule LogExpFunctions.logsubexp(x, y) =
255257
:(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)),
256258
:(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z))
257-
258-
@define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :($x / (1 + $y))
259+
@define_diffrule LogExpFunctions.xlog1py(x, y) =
260+
:(log1p($y)),
261+
:(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z)

test/runtests.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,40 @@ for xtype in [:Float64, :BigFloat]
141141
end
142142
end
143143

144+
# Test `iszero(x)` branch of `xlogy`
145+
derivs = DiffRules.diffrule(:LogExpFunctions, :xlogy, :x, :y)
146+
for xytype in [:Float32, :Float64, :BigFloat]
147+
@eval begin
148+
let
149+
x = zero($xytype)
150+
y = rand($xytype)
151+
dx, dy = $(derivs[1]), $(derivs[2])
152+
@test iszero(dy)
153+
154+
y = one($xytype)
155+
dx, dy = $(derivs[1]), $(derivs[2])
156+
@test iszero(dy)
157+
end
158+
end
159+
end
160+
161+
# Test `iszero(x)` branch of `xlog1py`
162+
derivs = DiffRules.diffrule(:LogExpFunctions, :xlog1py, :x, :y)
163+
for xytype in [:Float32, :Float64, :BigFloat]
164+
@eval begin
165+
let
166+
x = zero($xytype)
167+
y = rand($xytype)
168+
dx, dy = $(derivs[1]), $(derivs[2])
169+
@test iszero(dy)
170+
171+
y = -one($xytype)
172+
dx, dy = $(derivs[1]), $(derivs[2])
173+
@test iszero(dy)
174+
end
175+
end
176+
end
177+
144178
end
145179

146180
@testset "diffrules" begin

0 commit comments

Comments
 (0)