@@ -618,10 +618,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number,
618618 # derivatives
619619 _a, _b, _x = map (float, promote (a, b, x))
620620 _, dIa, dIb, dIx = _beta_inc_grad (_a, _b, _x)
621- _Δa = Δa isa Real ? Δa : zero (Δa)
622- _Δb = Δb isa Real ? Δb : zero (Δb)
623- _Δx = Δx isa Real ? Δx : zero (Δx)
624- Δp = dIa * _Δa + dIb * _Δb + dIx * _Δx
621+ Δp = muladd (dIx, Δx, muladd (dIb, Δb, dIa * Δa))
625622 Δq = - Δp
626623 Tout = typeof ((p, q))
627624 return (p, q), ChainRulesCore. Tangent {Tout} (Δp, Δq)
@@ -648,11 +645,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Nu
648645 p, q = beta_inc (a, b, x, y)
649646 _a, _b, _x, _y = map (float, promote (a, b, x, y))
650647 _, dIa, dIb, dIx = _beta_inc_grad (_a, _b, _x)
651- _Δa = Δa isa Real ? Δa : zero (Δa)
652- _Δb = Δb isa Real ? Δb : zero (Δb)
653- _Δx = Δx isa Real ? Δx : zero (Δx)
654- _Δy = Δy isa Real ? Δy : zero (Δy)
655- Δp = dIa * _Δa + dIb * _Δb + dIx * (_Δx - _Δy)
648+ Δp = muladd (dIx, Δx, muladd (- dIx, Δy, muladd (dIb, Δb, dIa * Δa)))
656649 Δq = - Δp
657650 Tout = typeof ((p, q))
658651 return (p, q), ChainRulesCore. Tangent {Tout} (Δp, Δq)
@@ -690,10 +683,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Num
690683 dx_da = - dIa * inv_dIx
691684 dx_db = - dIb * inv_dIx
692685 dx_dp = inv_dIx
693- _Δa = Δa isa Real ? Δa : zero (Δa)
694- _Δb = Δb isa Real ? Δb : zero (Δb)
695- _Δp = Δp isa Real ? Δp : zero (Δp)
696- Δx = dx_da * _Δa + dx_db * _Δb + dx_dp * _Δp
686+ Δx = muladd (dx_dp, Δp, muladd (dx_db, Δb, dx_da * Δa))
697687 Δy = - Δx
698688 Tout = typeof ((x, y))
699689 return (x, y), ChainRulesCore. Tangent {Tout} (Δx, Δy)
0 commit comments