Skip to content

Commit b7622d7

Browse files
authored
Fix abs typeof when zero-sized struct element (#2487)
* Fix abs typeof when zero-sized struct element * fix * fix * fix
1 parent ae3d9f1 commit b7622d7

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

src/absint.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,9 +601,10 @@ function abs_typeof(
601601
@assert Base.isconcretetype(typ)
602602
seen = false
603603
lasti = 1
604+
604605
for i in 1:typed_fieldcount(typ)
605606
fo = typed_fieldoffset(typ, i)
606-
if fo == offset
607+
if fo == offset && (i == typed_fieldcount(typ) || typed_fieldoffset(typ, i + 1) != offset)
607608
offset = 0
608609
typ = typed_fieldtype(typ, i)
609610
if !Base.allocatedinline(typ)

src/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4508,7 +4508,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
45084508
ctx = LLVM.context(mod)
45094509
for f in functions(mod), bb in blocks(f), inst in instructions(bb)
45104510
fn = isa(inst, LLVM.CallInst) ? LLVM.called_operand(inst) : nothing
4511-
4511+
45124512
if !API.HasFromStack(inst) && isa(inst, LLVM.AllocaInst)
45134513

45144514
calluse = nothing

src/llvm/transforms.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ end
310310
# amenable to caching analysis infrastructure
311311
function memcpy_alloca_to_loadstore(mod::LLVM.Module)
312312
dl = datalayout(mod)
313+
ctx = context(mod)
314+
seen = TypeTreeTable()
313315
for f in functions(mod)
314316
if length(blocks(f)) != 0
315317
bb = first(blocks(f))
@@ -413,6 +415,41 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module)
413415
bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src))))
414416

415417
src = load!(B, elty, src)
418+
419+
T_jlvalue = LLVM.StructType(LLVMType[])
420+
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
421+
422+
legal, source_typ, byref = abs_typeof(src)
423+
codegen_typ = value_type(src)
424+
if legal
425+
if codegen_typ isa LLVM.PointerType || codegen_typ isa LLVM.IntegerType
426+
else
427+
@assert byref == GPUCompiler.BITS_VALUE
428+
source_typ
429+
end
430+
431+
ec = typetree(source_typ, ctx, string(dl), seen)
432+
if byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF
433+
ec = copy(ec)
434+
merge!(ec, TypeTree(API.DT_Pointer, ctx))
435+
only!(ec, -1)
436+
end
437+
metadata(src)["enzyme_type"] = to_md(ec, ctx)
438+
metadata(src)["enzymejl_source_type_$(source_typ)"] = MDNode(LLVM.Metadata[])
439+
metadata(src)["enzymejl_byref_$(byref)"] = MDNode(LLVM.Metadata[])
440+
441+
@static if VERSION < v"1.11-"
442+
else
443+
legal2, obj = absint(src)
444+
if legal2 obj isa Memory && obj == typeof(obj).instance
445+
metadata(src)["nonnull"] = MDNode(LLVM.Metadata[])
446+
end
447+
end
448+
449+
elseif codegen_typ == T_prjlvalue
450+
metadata(src)["enzyme_type"] =
451+
to_md(typetree(Ptr{Cvoid}, ctx, dl, seen), ctx)
452+
end
416453
FT = LLVM.FunctionType(
417454
LLVM.VoidType(),
418455
[LLVM.IntType(64), value_type(dst0)],

0 commit comments

Comments
 (0)