@@ -167,25 +167,31 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
167167collapse_nothings (xs:: AbstractArray{Nothing} ) = nothing
168168collapse_nothings (xs) = xs
169169
170- _purefun (:: Type{F} ) where {F<: Function } = Base. issingletontype (F)
171- _purefun (:: Type ) = false
172- _purefun (:: Type{typeof(^)} ) = false # fix @testset "power" & @testset "diagonal hessian"
170+ _dual_purefun (:: Type{F} ) where {F<: Function } = Base. issingletontype (F)
171+ _dual_purefun (:: Type ) = false
172+ _dual_purefun (:: Type{typeof(^)} ) = false # avoid DomainError from negative powers
173173
174- _dualsafe (x:: Numeric{<:Real} ) = true
175- _dualsafe (x:: Ref{<:Numeric{<:Real}} ) = true
176- _dualsafe (x:: Val ) = true
177- _dualsafe (x:: Type ) = true
178- _dualsafe (x:: Symbol ) = true
179- _dualsafe (x) = false
174+ _dual_safearg (x:: Numeric{<:Real} ) = true
175+ _dual_safearg (x:: Ref{<:Numeric{<:Real}} ) = true
176+ _dual_safearg (x:: Union{Type,Val,Symbol} ) = true # non-differentiable types
177+ _dual_safearg (x) = false
178+
179+ # This is Broadcast.combine_eltypes but with dual eltypes:
180+ _combine_dual_eltypes (f, args:: Tuple ) =
181+ Broadcast. promote_typejoin_union (Base. _return_type (f, map (_dual_eltype, args)))
182+ _dual_eltype (x:: Numeric{T} ) where {T<: Real } = Dual{Nothing, T, 1 } # typeof(Dual(one(T),true))
183+ _dual_eltype (x) = eltype (x)
180184
181185@adjoint function broadcasted (:: AbstractArrayStyle , f:: F , args... ) where {F}
182- T = Broadcast . combine_eltypes (f, args)
186+ TD = _combine_dual_eltypes (f, args)
183187 # Avoid generic broadcasting in two easy cases:
184- if T == Bool
188+ if TD <: Dual && isconcretetype (TD)
189+ if _dual_purefun (F) && all (_dual_safearg, args)
190+ y, back = broadcast_forward (f, args... )
191+ return y, ȳ -> (nothing , nothing , back (ȳ)... )
192+ end
193+ elseif TD <: Real && isconcretetype (TD)
185194 return f .(args... ), _-> nothing
186- elseif T <: Real && isconcretetype (T) && _purefun (F) && all (_dualsafe, args)
187- y, back = broadcast_forward (f, args... )
188- return y, ȳ -> (nothing , nothing , back (ȳ)... )
189195 end
190196 len = inclen (args)
191197 y∂b = _broadcast ((x... ) -> _pullback (__context__, f, x... ), args... )
0 commit comments