Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.72.5"
version = "1.72.6"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
34 changes: 34 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,40 @@ function rrule(::typeof(hypot), z::Complex)
return (Ω, hypot_pullback)
end

# Note that `hypot` with two arguments has rules in `fastmath_able.jl`

function frule(
(_, Δx, Δy, Δz, Δxs...),
::typeof(hypot),
x::T,
y::T,
z::T,
xs::Vararg{T, N}
) where {T<:Union{Real,Complex}, N}
Ω = hypot(x, y, z, xs...)
n = ifelse(iszero(Ω), one(Ω), Ω)
∂Ωxyz = realdot(x, Δx) + realdot(y, Δy) + realdot(z, Δz)
∂Ωxs = sum(realdot(xi, Δxi) for (xi, Δxi) in zip(xs, Δxs); init=zero(∂Ωxyz))
∂Ω = (∂Ωxyz + ∂Ωxs) / n
return Ω, ∂Ω
end

function rrule(
::typeof(hypot),
x::T,
y::T,
z::T,
xs::Vararg{T, N}
) where {T<:Union{Real,Complex}, N}
Ω = hypot(x, y, z, xs...)
function hypot_pullback(ΔΩ)
c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)
return (NoTangent(), c * x, c * y, c * z, (c * xi for xi in xs)...)
end
return (Ω, hypot_pullback)
end


@scalar_rule fma(x, y, z) (y, x, true)
@scalar_rule muladd(x, y, z) (y, x, true)
@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true)
Expand Down
11 changes: 11 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,15 @@ end
test_rrule(merge, (; a=1.0), (; b=2.0))
test_rrule(merge, (; a=1.0), (; a=2.0))
end

@testset "hypot(x, y, z, xs...)" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add tests with mixed Real/Complex inputs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to test mixed inputs. The new tests are passing; CI failures are the same as before.

test_frule(hypot, 1.0, 2.0, 3.0)
test_rrule(hypot, 1.0, 2.0, 3.0)
test_frule(hypot, 1.0, 2.0, 3.0, 4.0)
test_rrule(hypot, 1.0, 2.0, 3.0, 4.0)
test_frule(hypot, 1.0+5.0im, 2.0+6.0im, 3.0+7.0im)
test_rrule(hypot, 1.0+5.0im, 2.0+6.0im, 3.0+7.0im)
test_frule(hypot, 1.0+5.0im, 2.0+6.0im, 3.0+7.0im, 4.0+8.0im)
test_rrule(hypot, 1.0+5.0im, 2.0+6.0im, 3.0+7.0im, 4.0+8.0im)
end
end
Loading