@@ -176,22 +176,14 @@ _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
176176_dual_safearg (x:: Union{Type,Val,Symbol} ) = true # non-differentiable types
177177_dual_safearg (x) = false
178178
179- # This is Broadcast.combine_eltypes but with dual eltypes:
180- _combine_dual_eltypes (f, args:: Tuple ) =
181- Broadcast. promote_typejoin_union (Base. _return_type (f, map (_dual_eltype, args)))
182- _dual_eltype (x:: Numeric{T} ) where {T<: Real } = Dual{Nothing, T, 1 } # typeof(Dual(one(T),true))
183- _dual_eltype (x) = eltype (x)
184-
185179@adjoint function broadcasted (:: AbstractArrayStyle , f:: F , args... ) where {F}
186- TD = _combine_dual_eltypes (f, args)
180+ T = Broadcast . combine_eltypes (f, args)
187181 # Avoid generic broadcasting in two easy cases:
188- if TD <: Dual && isconcretetype (TD)
189- if _dual_purefun (F) && all (_dual_safearg, args)
190- y, back = broadcast_forward (f, args... )
191- return y, ȳ -> (nothing , nothing , back (ȳ)... )
192- end
193- elseif TD <: Real && isconcretetype (TD)
194- return f .(args... ), _-> nothing
182+ if T == Bool
183+ return f .(args... ), _-> nothing
184+ elseif T <: Real && isconcretetype (T) && _dual_purefun (F) && all (_dual_safearg, args)
185+ y, back = broadcast_forward (f, args... )
186+ return y, ȳ -> (nothing , nothing , back (ȳ)... )
195187 end
196188 len = inclen (args)
197189 y∂b = _broadcast ((x... ) -> _pullback (__context__, f, x... ), args... )
0 commit comments