Skip to content

Commit 1ad2238

Browse files
committed
fix
1 parent 5f76f86 commit 1ad2238

File tree

4 files changed

+68
-50
lines changed

4 files changed

+68
-50
lines changed

src/analyses/activity.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,22 +358,36 @@ Base.@nospecializeinfer @inline function active_reg_inner(
358358

359359
ty = AnyState
360360

361-
for f in fieldcount(nT)
361+
for f in 1:typed_fieldcount(nT)
362362
subT = typed_fieldtype(nT, f)
363363

364364
if justActive && ismutabletype(subT)
365365
# AnyState
366366
continue
367367
end
368368

369-
ty |= active_reg_inner(
369+
sub = active_reg_inner(
370370
subT,
371371
seen2,
372372
world,
373373
justActive,
374374
UnionSret,
375375
AbstractIsMixed,
376376
)
377+
378+
if sub == AnyState
379+
continue
380+
end
381+
382+
if sub == DupState && justActive
383+
continue
384+
end
385+
386+
if reftype
387+
sub = DupState
388+
end
389+
390+
ty |= sub
377391
end
378392

379393
return ty
@@ -448,7 +462,14 @@ end
448462
# check if a value is guaranteed to be not contain active[register] data
449463
# (aka not either mixed or active)
450464
Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive(::Type{T})::Bool where {T}
451-
rt = Enzyme.Compiler.active_reg_nothrow(T)
465+
rt = active_reg_nothrow(T)
466+
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
467+
end
468+
469+
# check if a value is guaranteed to be not contain active[register] data
470+
# (aka not either mixed or active)
471+
@inline function guaranteed_nonactive(@nospecialize(T::Type), world::UInt)::Bool
472+
rt = active_reg(T; justActive=true)
452473
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
453474
end
454475

test/activity.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using Enzyme, Test
2+
3+
struct Ints{A, B}
4+
v::B
5+
q::Int
6+
end
7+
8+
mutable struct MInts{A, B}
9+
v::B
10+
q::Int
11+
end
12+
13+
@testset "Activity Tests" begin
14+
@static if VERSION < v"1.11-"
15+
else
16+
@test Enzyme.Compiler.active_reg(Memory{Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
17+
end
18+
@test Enzyme.Compiler.active_reg(Type{Array}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
19+
@test Enzyme.Compiler.active_reg(Ints{<:Any, Integer}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
20+
@test Enzyme.Compiler.active_reg(Ints{<:Any, Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
21+
@test Enzyme.Compiler.active_reg(Ints{Integer, <:Any}, Base.get_world_counter()) == Enzyme.Compiler.DupState
22+
@test Enzyme.Compiler.active_reg(Ints{Integer, <:Integer}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
23+
@test Enzyme.Compiler.active_reg(Ints{Integer, <:AbstractFloat}, Base.get_world_counter()) == Enzyme.Compiler.DupState
24+
@test Enzyme.Compiler.active_reg(Ints{Integer, Float64}, Base.get_world_counter()) == Enzyme.Compiler.ActiveState
25+
@test Enzyme.Compiler.active_reg(MInts{Integer, Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
26+
27+
@test Enzyme.Compiler.active_reg(Tuple{Float32,Float32,Int}, Base.get_world_counter()) == Enzyme.Compiler.ActiveState
28+
@test Enzyme.Compiler.active_reg(Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
29+
@test Enzyme.Compiler.active_reg(Base.RefValue{Float32}, Base.get_world_counter()) == Enzyme.Compiler.DupState
30+
@test Enzyme.Compiler.active_reg(Ptr, Base.get_world_counter()) == Enzyme.Compiler.DupState
31+
@test Enzyme.Compiler.active_reg(Base.RefValue{Float32}, Base.get_world_counter()) == Enzyme.Compiler.DupState
32+
@test Enzyme.Compiler.active_reg(Colon, Base.get_world_counter()) == Enzyme.Compiler.AnyState
33+
@test Enzyme.Compiler.active_reg(Symbol, Base.get_world_counter()) == Enzyme.Compiler.AnyState
34+
@test Enzyme.Compiler.active_reg(String, Base.get_world_counter()) == Enzyme.Compiler.AnyState
35+
@test Enzyme.Compiler.active_reg(Tuple{Any,Int64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
36+
@test Enzyme.Compiler.active_reg(Tuple{S,Int64} where S, Base.get_world_counter()) == Enzyme.Compiler.DupState
37+
@test Enzyme.Compiler.active_reg(Union{Float64,Nothing}, Base.get_world_counter()) == Enzyme.Compiler.DupState
38+
@test Enzyme.Compiler.active_reg(Union{Float64,Nothing}, Base.get_world_counter(), UnionSret=true) == Enzyme.Compiler.ActiveState
39+
@test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter()) == Enzyme.Compiler.DupState
40+
@test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter(); AbstractIsMixed=true) == Enzyme.Compiler.MixedState
41+
@test Enzyme.Compiler.active_reg(Tuple{A,A} where A, Base.get_world_counter(), AbstractIsMixed=true) == Enzyme.Compiler.MixedState
42+
end

test/mixedrrule.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,6 @@ function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typ
6969
return AugmentedReturn(primal, nothing, vec)
7070
end
7171

72-
# check if a value is guaranteed to be not contain active[register] data
73-
# (aka not either mixed or active)
74-
@inline function guaranteed_nonactive(::Type{T}) where T
75-
rt = Enzyme.Compiler.active_reg_nothrow(T)
76-
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
77-
end
78-
7972
function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(recmixfnc)},
8073
dret::Active, tape, tup)
8174
prev = tup.dval[]
@@ -86,7 +79,7 @@ function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(recmi
8679
pv = getfield(prev, i)
8780
if i == 1
8881
next = (7 * tape[1] * dret.val, 31 * tape[1] * dret.val)
89-
Enzyme.Compiler.recursive_add(pv, next, identity, guaranteed_nonactive)
82+
Enzyme.Compiler.recursive_add(pv, next, identity, Enzyme.Compiler.guaranteed_nonactive)
9083
else
9184
pv
9285
end

test/runtests.jl

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ end
7373

7474
include("abi.jl")
7575
include("typetree.jl")
76+
include("activity.jl")
7677
include("passes.jl")
7778
include("optimize.jl")
7879
include("make_zero.jl")
@@ -101,46 +102,7 @@ function vrec(start, x)
101102
end
102103
end
103104

104-
struct Ints{A, B}
105-
v::B
106-
q::Int
107-
end
108-
109-
mutable struct MInts{A, B}
110-
v::B
111-
q::Int
112-
end
113-
114105
@testset "Internal tests" begin
115-
@static if VERSION < v"1.11-"
116-
else
117-
@assert Enzyme.Compiler.active_reg(Memory{Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
118-
end
119-
@assert Enzyme.Compiler.active_reg(Type{Array}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
120-
@assert Enzyme.Compiler.active_reg(Ints{<:Any, Integer}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
121-
@assert Enzyme.Compiler.active_reg(Ints{<:Any, Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
122-
@assert Enzyme.Compiler.active_reg(Ints{Integer, <:Any}, Base.get_world_counter()) == Enzyme.Compiler.DupState
123-
@assert Enzyme.Compiler.active_reg(Ints{Integer, <:Integer}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
124-
@assert Enzyme.Compiler.active_reg(Ints{Integer, <:AbstractFloat}, Base.get_world_counter()) == Enzyme.Compiler.DupState
125-
@assert Enzyme.Compiler.active_reg(Ints{Integer, Float64}, Base.get_world_counter()) == Enzyme.Compiler.ActiveState
126-
@assert Enzyme.Compiler.active_reg(MInts{Integer, Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
127-
128-
@assert Enzyme.Compiler.active_reg(Tuple{Float32,Float32,Int}, Base.get_world_counter()) == Enzyme.Compiler.ActiveState
129-
@assert Enzyme.Compiler.active_reg(Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
130-
@assert Enzyme.Compiler.active_reg(Base.RefValue{Float32}, Base.get_world_counter()) == Enzyme.Compiler.DupState
131-
@assert Enzyme.Compiler.active_reg(Ptr, Base.get_world_counter(), Base.get_world_counter()) == Enzyme.Compiler.DupState
132-
@assert Enzyme.Compiler.active_reg(Base.RefValue{Float32}, Base.get_world_counter()) == Enzyme.Compiler.DupState
133-
@assert Enzyme.Compiler.active_reg(Colon, Base.get_world_counter()) == Enzyme.Compiler.AnyState
134-
@assert Enzyme.Compiler.active_reg(Symbol, Base.get_world_counter()) == Enzyme.Compiler.AnyState
135-
@assert Enzyme.Compiler.active_reg(String, Base.get_world_counter()) == Enzyme.Compiler.AnyState
136-
@assert Enzyme.Compiler.active_reg(Tuple{Any,Int64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
137-
@assert Enzyme.Compiler.active_reg(Tuple{S,Int64} where S, Base.get_world_counter()) == Enzyme.Compiler.DupState
138-
@assert Enzyme.Compiler.active_reg(Union{Float64,Nothing}, Base.get_world_counter()) == Enzyme.Compiler.DupState
139-
@assert Enzyme.Compiler.active_reg(Union{Float64,Nothing}, Base.get_world_counter(), #=justActive=#Val(false), #=unionSret=#Val(true)) == Enzyme.Compiler.ActiveState
140-
@test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter()) == Enzyme.Compiler.DupState
141-
@test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter(), #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState
142-
@test Enzyme.Compiler.active_reg(Tuple{A,A} where A, Base.get_world_counter(), #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState
143-
144106
# issue #1935
145107
struct Incomplete
146108
x::Float64

0 commit comments

Comments
 (0)