Skip to content

Commit d7d26d4

Browse files
committed
Move to muladd
1 parent 3cbdb3a commit d7d26d4

File tree

1 file changed

+3
-13
lines changed

1 file changed

+3
-13
lines changed

ext/SpecialFunctionsChainRulesCoreExt.jl

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -618,10 +618,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number,
618618
# derivatives
619619
_a, _b, _x = map(float, promote(a, b, x))
620620
_, dIa, dIb, dIx = _beta_inc_grad(_a, _b, _x)
621-
_Δa = Δa isa Real ? Δa : zero(Δa)
622-
_Δb = Δb isa Real ? Δb : zero(Δb)
623-
_Δx = Δx isa Real ? Δx : zero(Δx)
624-
Δp = dIa * _Δa + dIb * _Δb + dIx * _Δx
621+
Δp = muladd(dIx, Δx, muladd(dIb, Δb, dIa * Δa))
625622
Δq = -Δp
626623
Tout = typeof((p, q))
627624
return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq)
@@ -648,11 +645,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Nu
648645
p, q = beta_inc(a, b, x, y)
649646
_a, _b, _x, _y = map(float, promote(a, b, x, y))
650647
_, dIa, dIb, dIx = _beta_inc_grad(_a, _b, _x)
651-
_Δa = Δa isa Real ? Δa : zero(Δa)
652-
_Δb = Δb isa Real ? Δb : zero(Δb)
653-
_Δx = Δx isa Real ? Δx : zero(Δx)
654-
_Δy = Δy isa Real ? Δy : zero(Δy)
655-
Δp = dIa * _Δa + dIb * _Δb + dIx * (_Δx - _Δy)
648+
Δp = muladd(dIx, Δx, muladd(-dIx, Δy, muladd(dIb, Δb, dIa * Δa)))
656649
Δq = -Δp
657650
Tout = typeof((p, q))
658651
return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq)
@@ -690,10 +683,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Num
690683
dx_da = -dIa * inv_dIx
691684
dx_db = -dIb * inv_dIx
692685
dx_dp = inv_dIx
693-
_Δa = Δa isa Real ? Δa : zero(Δa)
694-
_Δb = Δb isa Real ? Δb : zero(Δb)
695-
_Δp = Δp isa Real ? Δp : zero(Δp)
696-
Δx = dx_da * _Δa + dx_db * _Δb + dx_dp * _Δp
686+
Δx = muladd(dx_dp, Δp, muladd(dx_db, Δb, dx_da * Δa))
697687
Δy = -Δx
698688
Tout = typeof((x, y))
699689
return (x, y), ChainRulesCore.Tangent{Tout}(Δx, Δy)

0 commit comments

Comments
 (0)