Commit 359de07

Simplify: remove IsInactive, always use *_nongen
1 parent c387d29 commit 359de07

5 files changed: +120 -303 lines changed

lib/EnzymeCore/src/EnzymeCore.jl

Lines changed: 13 additions & 38 deletions
@@ -506,60 +506,35 @@ function autodiff_thunk end
 function autodiff_deferred_thunk end
 
 """
-    make_zero(
-        prev::T, ::Val{copy_if_inactive}=Val(false), ::Val{runtime_inactive}=Val(false)
-    )::T
-    make_zero(
-        ::Type{T},
-        seen::IdDict,
-        prev::T,
-        ::Val{copy_if_inactive}=Val(false),
-        ::Val{runtime_inactive}=Val(false),
-    )::T
+    make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T
+    make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T
 
 Recursively make a copy of the value `prev::T` in which all differentiable values are
 zeroed. The argument `copy_if_inactive` specifies what to do if the type `T` or any
 of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s
 instance (the default) or make a copy.
 
-The argument `runtime_inactive` specifies whether each constituent type is checked for being
-guaranteed inactive at runtime for every call to `make_zero`, or if this can be checked once
-at compile-time and reused across multiple calls to `make_zero` and related functions (the
-default). Runtime checks are necessary to pick up recently added methods to
-`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually
-not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have
-previously been passed to `make_zero` or related functions.
-
 Extending this method for custom types is rarely needed. If you implement a new type, such
-as a GPU array type, for which `make_zero` should directly invoke `zero` rather than
-iterate/broadcast when the eltype is scalar, it is sufficient to implement `Base.zero` and
-make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is not appropriate,
-extend [`EnzymeCore.isvectortype`](@ref) directly instead.)
+as a GPU array type, on which `make_zero` should directly invoke `zero` when the eltype is
+scalar, it is sufficient to implement `Base.zero` and make sure your type subtypes
+`DenseArray`. (If subtyping `DenseArray` is not appropriate, extend
+[`EnzymeCore.isvectortype`](@ref) instead.)
 """
 function make_zero end
 
 """
-    make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive}=Val(false))::Nothing
+    make_zero!(val::T, [seen::IdDict])::Nothing
 
 Recursively set a variable's differentiable values to zero. Only applicable for types `T`
 that are mutable or hold all differentiable values in mutable storage (e.g.,
 `Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over
 parts of `val` that are guaranteed to be inactive.
 
-The argument `runtime_inactive` specifies whether each constituent type is checked for being
-guaranteed inactive at runtime for every call to `make_zero!`, or if this can be checked once
-at compile-time and reused across multiple calls to `make_zero!` and related functions (the
-default). Runtime checks are necessary to pick up recently added methods to
-`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually
-not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have
-previously been passed to `make_zero!` or related functions.
-
 Extending this method for custom types is rarely needed. If you implement a new mutable
-type, such as a GPU array type, for which `make_zero!` should directly invoke
-`fill!(x, false)` rather than iterate/broadcast when the eltype is scalar, it is sufficient
-to implement `Base.zero`, `Base.fill!`, and make sure your type subtypes `DenseArray`. (If
-subtyping `DenseArray` is not appropriate, extend [`EnzymeCore.isvectortype`](@ref) directly
-instead.)
+type, such as a GPU array type, on which `make_zero!` should directly invoke
+`fill!(x, false)` when the eltype is scalar, it is sufficient to implement `Base.zero`,
+`Base.fill!`, and make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is
+not appropriate, extend [`EnzymeCore.isvectortype`](@ref) instead.)
 """
 function make_zero! end
 
@@ -569,7 +544,7 @@ function make_zero! end
 Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref)
 and [`make_zero!`](@ref) recurse through an object.
 
-By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or
+By default, `isvectortype(T) == true` when `isscalartype(T) == true` or when
 `T <: DenseArray{U}` where `U` is a bitstype and `isscalartype(U) == true`.
 
 A new vector type, such as a GPU array type, should normally subtype `DenseArray` and
@@ -607,7 +582,7 @@ in-place mutation through any Julia API; `isscalartype(BigFloat) == true` ensure
 `make_zero!` will not try to mutate `BigFloat` values.[^BigFloat]
 
 By default, `isscalartype(T) == true` and `isscalartype(Complex{T}) == true` for concrete
-`T <: AbstractFloat`.
+types `T <: AbstractFloat`.
 
 A hypothetical new real number type with Enzyme support should usually subtype
 `AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate,
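
Reviewer note: the simplified call forms documented in this file can be exercised roughly as follows. This is a minimal sketch, assuming the Enzyme package is installed and that the functions are reachable as `Enzyme.make_zero` and `Enzyme.make_zero!`. The variable names and the sample `NamedTuple` are illustrative only and do not come from the commit.

```julia
using Enzyme  # assumption: Enzyme.jl is available in the active environment

# One-argument form: returns a fresh object with differentiable values zeroed.
x = [1.0, 2.0, 3.0]
dx = Enzyme.make_zero(x)

# The only remaining optional argument is Val(copy_if_inactive).
# Here `tag` is a String, which is expected to be treated as inactive.
nt = (weights = [1.0, 2.0], tag = "layer1")  # illustrative data
z  = Enzyme.make_zero(nt)             # inactive parts reused (the default)
zc = Enzyme.make_zero(nt, Val(true))  # inactive parts copied instead

# In-place variant: zeroes the differentiable storage of `x` itself.
Enzyme.make_zero!(x)
```

After this commit there is no trailing `Val{runtime_inactive}` argument to pass, so any call sites that supplied one would need to drop it.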

src/analyses/activity.jl

Lines changed: 2 additions & 2 deletions
@@ -414,7 +414,7 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_c
     return res
 end
 
-Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world)::Bool where {T}
+Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world=nothing)::Bool where {T}
     rt = active_reg_inner(T, (), world)
     res = rt == AnyState
     return res
@@ -427,7 +427,7 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_n
     return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
 end
 
-Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T}
+Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world=nothing)::Bool where {T}
     rt = Enzyme.Compiler.active_reg_inner(T, (), world)
     return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
 end
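
Reviewer note: the only change in this file is the `world=nothing` default, which lets callers in `recursive_add.jl` write `guaranteed_const_nongen(T)` with a single argument. The sketch below shows the underlying Julia mechanism in isolation. `toy_activity` and `toy_guaranteed_const` are hypothetical stand-ins, not Enzyme functions, and the activity rule inside them is made up for illustration.

```julia
# Hypothetical stand-in for an activity query; the real logic lives in Enzyme.
toy_activity(::Type{T}, world) where {T} = T <: AbstractFloat ? :active : :any

# A default positional argument makes Julia define a one-argument method
# that forwards to the two-argument method with `world = nothing`.
function toy_guaranteed_const(::Type{T}, world=nothing)::Bool where {T}
    return toy_activity(T, world) == :any
end

# Both call forms are now valid and agree with each other.
one_arg = toy_guaranteed_const(String)
two_arg = toy_guaranteed_const(String, nothing)
```

Because the default only adds a forwarding method, existing two-argument call sites keep working unchanged.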

src/typeutils/recursive_add.jl

Lines changed: 16 additions & 30 deletions
@@ -1,7 +1,7 @@
 using .RecursiveMaps: RecursiveMaps, recursive_map, recursive_map!
 
 """
-    recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const)
+    recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const_nongen)
 
 !!! warning
     Internal function, documented for developer convenience but not covered by semver API
@@ -18,7 +18,9 @@ The optional argument `forcelhs` takes a callable such that if `forcelhs(S) == t
 such that `zi = xi + f(yi)` applies to differentiable values, while `zi = xi` applies to
 non-differentiable values.
 """
-function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const) where {T,F,L}
+function recursive_add(
+    x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const_nongen
+) where {T,F,L}
     function addf(xi::S, yi::S) where {S}
         @assert EnzymeCore.isvectortype(S)
         return ((xi + f(yi))::S,)
@@ -27,8 +29,7 @@ function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const)
 end
 
 """
-    accumulate_seen!(f, seen::IdDict, ::Val{runtime_inactive}=Val(false))
-    accumulate_seen!(f, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive)
+    accumulate_seen!(f, seen::IdDict)
 
 !!! warning
     Internal function, documented for developer convenience but not covered by semver API
@@ -42,35 +43,17 @@ vector type instance mappping to another object of the same type and structure.
 returned value.
 
 The recursion stops at instances of types that are themselves cached by `make_zero`
-(`recursive_map`), as these objects should have their own entries in `seen`.
-
-Inactive objects that would be shared/copied rather than zeroed by `make_zero` are skipped.
-If the optional `::Val{runtime_inactive}` argument was passed to `make_zero`, the same value
-should be passed to `accumulate_seen` for consistency. If needed, a custom
-`RecursiveMaps.IsInactive` instance can be provided instead.
+(`recursive_map`), as these objects should have their own entries in `seen`. The recursion
+also stops at inactive objects that not be zeroed by `make_zero`.
 """
-function accumulate_seen!(
-    f::F, seen::IdDict, ::Val{runtime_inactive}=Val(false)
-) where {F,runtime_inactive}
-    accumulate_seen!(f, seen, RecursiveMaps.IsInactive{runtime_inactive}())
-    return nothing
-end
-
-function accumulate_seen!(
-    f::F, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive
-) where {F}
-    isinactivetype_or_seen = RecursiveMaps.IsInactive(
-        isinactivetype, RecursiveMaps.iscachedtype
-    )
+function accumulate_seen!(f::F, seen::IdDict) where {F}
     for (k, v) in seen
-        _accumulate_seen_item!(f, k, v, isinactivetype, isinactivetype_or_seen)
+        _accumulate_seen_item!(f, k, v)
     end
     return nothing
 end
 
-function _accumulate_seen_item!(
-    f::F, k::T, v::T, isinactivetype, isinactivetype_or_seen
-) where {F,T}
+function _accumulate_seen_item!(f::F, k::T, v::T) where {F,T}
     function addf!!(ki::S, vi::S) where {S}
         @assert EnzymeCore.isvectortype(S)
         return ((ki .+ f.(vi))::S,)
@@ -82,10 +65,13 @@ function _accumulate_seen_item!(
         ki .+= f.(vi)
         return (ki::S,)
     end
-    RecursiveMaps.check_nonactive(T, isinactivetype)
-    if !isinactivetype(T)
+    @inline function isinactive_or_cachedtype(::Type{T}) where {T}
+        return guaranteed_const_nongen(T) || RecursiveMaps.iscachedtype(T)
+    end
+    RecursiveMaps.check_nonactive(T)
+    if !guaranteed_const_nongen(T)
         newks = RecursiveMaps.recursive_map_inner(
-            nothing, addf!!, (k,), (k, v), Val(false), isinactivetype_or_seen
+            nothing, addf!!, (k,), (k, v), Val(false), isinactive_or_cachedtype
         )
         @assert only(newks) === k
     end
