Skip to content

Commit 7ebb75b

Browse files
committed
rename, tidy, improve
1 parent 72eb681 commit 7ebb75b

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

src/lib/broadcast.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -167,25 +167,31 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
167167
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
168168
collapse_nothings(xs) = xs
169169

170-
_purefun(::Type{F}) where {F<:Function} = Base.issingletontype(F)
171-
_purefun(::Type) = false
172-
_purefun(::Type{typeof(^)}) = false # fix @testset "power" & @testset "diagonal hessian"
170+
_dual_purefun(::Type{F}) where {F<:Function} = Base.issingletontype(F)
171+
_dual_purefun(::Type) = false
172+
_dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powers
173173

174-
_dualsafe(x::Numeric{<:Real}) = true
175-
_dualsafe(x::Ref{<:Numeric{<:Real}}) = true
176-
_dualsafe(x::Val) = true
177-
_dualsafe(x::Type) = true
178-
_dualsafe(x::Symbol) = true
179-
_dualsafe(x) = false
174+
_dual_safearg(x::Numeric{<:Real}) = true
175+
_dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
176+
_dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types
177+
_dual_safearg(x) = false
178+
179+
# This is Broadcast.combine_eltypes but with dual eltypes:
180+
_combine_dual_eltypes(f, args::Tuple) =
181+
Broadcast.promote_typejoin_union(Base._return_type(f, map(_dual_eltype, args)))
182+
_dual_eltype(x::Numeric{T}) where {T<:Real} = Dual{Nothing, T, 1} # typeof(Dual(one(T),true))
183+
_dual_eltype(x) = eltype(x)
180184

181185
@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
182-
T = Broadcast.combine_eltypes(f, args)
186+
TD = _combine_dual_eltypes(f, args)
183187
# Avoid generic broadcasting in two easy cases:
184-
if T == Bool
188+
if TD <: Dual && isconcretetype(TD)
189+
if _dual_purefun(F) && all(_dual_safearg, args)
190+
y, back = broadcast_forward(f, args...)
191+
return y, ȳ -> (nothing, nothing, back(ȳ)...)
192+
end
193+
elseif TD <: Real && isconcretetype(TD)
185194
return f.(args...), _->nothing
186-
elseif T <: Real && isconcretetype(T) && _purefun(F) && all(_dualsafe, args)
187-
y, back = broadcast_forward(f, args...)
188-
return y, ȳ -> (nothing, nothing, back(ȳ)...)
189195
end
190196
len = inclen(args)
191197
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)

test/features.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,4 +535,10 @@ end
535535

536536
@test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],)
537537
@test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)
538+
539+
# negative powers
540+
@test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], [1,-1,2])[1] [1.0, -0.25, 8.0]
541+
@test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], -1)[1] [-1.0, -0.25, -0.0625]
542+
@test gradient((x,p) -> sum(z -> z^p, x), [1.0,2.0,4.0], -1)[1] [-1.0, -0.25, -0.0625]
543+
@test gradient((x,p) -> mapreduce(z -> z^p, +, x), [1.0,2.0,4.0], -1)[1] [-1.0, -0.25, -0.0625]
538544
end

0 commit comments

Comments
 (0)