Skip to content

Commit f0e48e8

Browse files
authored
Merge branch 'main' into vc/blocking_ring
2 parents 8bd1ff5 + 115e3e0 commit f0e48e8

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Enzyme"
22
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
4-
version = "0.13.96"
4+
version = "0.13.97"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -43,9 +43,9 @@ ADTypes = "1"
4343
BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
4444
CEnum = "0.4, 0.5"
4545
ChainRulesCore = "1"
46-
DynamicPPL = "0.35, 0.36, 0.37, 0.38"
46+
DynamicPPL = "0.35 - 0.39"
4747
EnzymeCore = "0.8.15"
48-
Enzyme_jll = "0.0.207"
48+
Enzyme_jll = "0.0.208"
4949
GPUArraysCore = "0.1.6, 0.2"
5050
GPUCompiler = "1.6.2"
5151
LLVM = "9.1"

src/typetree.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ else
305305
dl,
306306
seen::TypeTreeTable,
307307
) where {kind,T}
308-
tt = copy(typetree(Ptr{T}, ctx, dl, seen))
308+
tt = copy(typetree(typed_fieldtype(AT, 1), ctx, dl, seen))
309309
shift!(tt, dl, 0, sizeof(Int), 0)
310310

311311
for f = 2:fieldcount(AT)

src/utils.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,10 +462,23 @@ end
462462

463463
else
464464

465+
function is_memory_ref_field2_an_offset(@nospecialize(T::Type{<:GenericMemoryRef}))
466+
ET = eltype(T)
467+
468+
# 0 = inlinealloc
469+
# 1 = isboxed
470+
# 2 = isbitsunion
471+
return (Base.datatype_arrayelem(T.types[2]) == 2) || Base.datatype_layoutsize(T.types[2]) == 0
472+
end
473+
465474
@inline function typed_fieldtype(@nospecialize(T::Type), i::Int)::Type
466475
if T <: GenericMemoryRef && i == 1 || T <: GenericMemory && i == 2
467-
eT = eltype(T)
468-
Ptr{eT}
476+
if T <: GenericMemoryRef && i == 1 && is_memory_ref_field2_an_offset(T)
477+
Int
478+
else
479+
eT = eltype(T)
480+
Ptr{eT}
481+
end
469482
else
470483
fieldtype(T, i)
471484
end

test/typetree.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ struct UnionStruct1{T}
4848
y::Any
4949
end
5050

51+
struct EmptyStruct
52+
end
53+
5154
@testset "TypeTree" begin
5255
@test tt(Float16) == "{[-1]:Float@half}"
5356
@test tt(Float32) == "{[-1]:Float@float}"
@@ -99,6 +102,14 @@ end
99102
"{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,0]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,0]:Pointer, [24,4]:Float@float, [24,8]:Float@double}"
100103
end
101104

105+
@static if VERSION >= v"1.11-"
106+
if Sys.WORD_SIZE == 64
107+
@test tt(MemoryRef{EmptyStruct}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Pointer, [8,0]:Integer, [8,1]:Integer, [8,2]:Integer, [8,3]:Integer, [8,4]:Integer, [8,5]:Integer, [8,6]:Integer, [8,7]:Integer, [8,8]:Pointer}"
108+
@test tt(MemoryRef{Union{Float64, Int64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Pointer, [8,0]:Integer, [8,1]:Integer, [8,2]:Integer, [8,3]:Integer, [8,4]:Integer, [8,5]:Integer, [8,6]:Integer, [8,7]:Integer, [8,8]:Pointer}"
109+
@test tt(MemoryRef{Float64}) == "{[-1]:Pointer, [0,-1]:Float@double, [8,0]:Integer, [8,1]:Integer, [8,2]:Integer, [8,3]:Integer, [8,4]:Integer, [8,5]:Integer, [8,6]:Integer, [8,7]:Integer, [8,8]:Pointer, [8,8,-1]:Float@double}"
110+
end
111+
end
112+
102113
@test tt(UnionStruct1{Float32}) == "{[0]:Float@float, [4]:Integer, [8]:Pointer}"
103114

104115
if Sys.WORD_SIZE == 64

0 commit comments

Comments
 (0)