@@ -119,6 +119,40 @@ function rrule(::typeof(hypot), z::Complex)
119119 return (Ω, hypot_pullback)
120120end
121121
122+ # Note that `hypot` with two arguments has rules in `fastmath_able.jl`
123+
124+ function frule (
125+ (_, Δx, Δy, Δz, Δxs... ),
126+ :: typeof (hypot),
127+ x:: T ,
128+ y:: T ,
129+ z:: T ,
130+ xs:: Vararg{T, N}
131+ ) where {T<: Union{Real,Complex} , N}
132+ Ω = hypot (x, y, z, xs... )
133+ n = ifelse (iszero (Ω), one (Ω), Ω)
134+ ∂Ωxyz = realdot (x, Δx) + realdot (y, Δy) + realdot (z, Δz)
135+ ∂Ωxs = sum (realdot (xi, Δxi) for (xi, Δxi) in zip (xs, Δxs); init= zero (∂Ωxyz))
136+ ∂Ω = (∂Ωxyz + ∂Ωxs) / n
137+ return Ω, ∂Ω
138+ end
139+
140+ function rrule (
141+ :: typeof (hypot),
142+ x:: T ,
143+ y:: T ,
144+ z:: T ,
145+ xs:: Vararg{T, N}
146+ ) where {T<: Union{Real,Complex} , N}
147+ Ω = hypot (x, y, z, xs... )
148+ function hypot_pullback (ΔΩ)
149+ c = real (ΔΩ) / ifelse (iszero (Ω), one (Ω), Ω)
150+ return (NoTangent (), c * x, c * y, c * z, (c * xi for xi in xs). .. )
151+ end
152+ return (Ω, hypot_pullback)
153+ end
154+
155+
122156@scalar_rule fma (x, y, z) (y, x, true )
123157@scalar_rule muladd (x, y, z) (y, x, true )
124158@scalar_rule muladd (x:: Union{Number, ZeroTangent} , y:: Union{Number, ZeroTangent} , z:: Union{Number, ZeroTangent} ) (y, x, true )
0 commit comments