Skip to content

Commit fd22a71

Browse files
committed
Handle Base.hypot with more than 2 arguments
1 parent 0923b1e commit fd22a71

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

src/rulesets/Base/base.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,40 @@ function rrule(::typeof(hypot), z::Complex)
119119
return (Ω, hypot_pullback)
120120
end
121121

122+
# Note that `hypot` with two arguments has rules in `fastmath_able.jl`
123+
124+
function frule(
125+
(_, Δx, Δy, Δz, Δxs...),
126+
::typeof(hypot),
127+
x::T,
128+
y::T,
129+
z::T,
130+
xs::Vararg{T, N}
131+
) where {T<:Union{Real,Complex}, N}
132+
Ω = hypot(x, y, z, xs...)
133+
n = ifelse(iszero(Ω), one(Ω), Ω)
134+
∂Ωxyz = realdot(x, Δx) + realdot(y, Δy) + realdot(z, Δz)
135+
∂Ωxs = sum(realdot(xi, Δxi) for (xi, Δxi) in zip(xs, Δxs); init=zero(∂Ωxyz))
136+
∂Ω = (∂Ωxyz + ∂Ωxs) / n
137+
return Ω, ∂Ω
138+
end
139+
140+
function rrule(
141+
::typeof(hypot),
142+
x::T,
143+
y::T,
144+
z::T,
145+
xs::Vararg{T, N}
146+
) where {T<:Union{Real,Complex}, N}
147+
Ω = hypot(x, y, z, xs...)
148+
function hypot_pullback(ΔΩ)
149+
c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)
150+
return (NoTangent(), c * x, c * y, c * z, (c * xi for xi in xs)...)
151+
end
152+
return (Ω, hypot_pullback)
153+
end
154+
155+
122156
@scalar_rule fma(x, y, z) (y, x, true)
123157
@scalar_rule muladd(x, y, z) (y, x, true)
124158
@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true)

test/rulesets/Base/base.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,4 +263,15 @@ end
263263
test_rrule(merge, (; a=1.0), (; b=2.0))
264264
test_rrule(merge, (; a=1.0), (; a=2.0))
265265
end
266+
267+
@testset "hypot(x, y, z, xs...)" begin
268+
test_frule(hypot, 1.0, 2.0, 3.0)
269+
test_rrule(hypot, 1.0, 2.0, 3.0)
270+
test_frule(hypot, 1.0, 2.0, 3.0, 4.0)
271+
test_rrule(hypot, 1.0, 2.0, 3.0, 4.0)
272+
test_frule(hypot, 1.0+5.0im, 2.0+6.0im, 3.0+7.0im)
273+
test_rrule(hypot, 1.0+5.0im, 2.0+6.0im, 3.0+7.0im)
274+
test_frule(hypot, 1.0+5.0im, 2.0+6.0im, 3.0+7.0im, 4.0+8.0im)
275+
test_rrule(hypot, 1.0+5.0im, 2.0+6.0im, 3.0+7.0im, 4.0+8.0im)
276+
end
266277
end

0 commit comments

Comments
 (0)