Skip to content

Commit cdc4f86

Browse files
committed
fix
1 parent 294904d commit cdc4f86

File tree

4 files changed

+74
-40
lines changed

4 files changed

+74
-40
lines changed

src/compiler.jl

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -513,8 +513,8 @@ function prepare_llvm(interp, mod::LLVM.Module, job, meta)
513513

514514
RT = return_type(interp, mi)
515515

516-
_, _, returnRoots = get_return_info(RT)
517-
returnRoots = returnRoots !== nothing
516+
_, _, returnRoots0 = get_return_info(RT)
517+
returnRoots = returnRoots0 !== nothing
518518

519519
attributes = function_attributes(llvmfn)
520520
push!(
@@ -529,7 +529,7 @@ function prepare_llvm(interp, mod::LLVM.Module, job, meta)
529529
push!(attributes, LLVM.StringAttribute("enzyme_LocalReadOnlyOrThrow"))
530530
end
531531
if returnRoots
532-
attr = StringAttribute("enzymejl_returnRoots", "")
532+
attr = StringAttribute("enzymejl_returnRoots", string(length(eltype(returnRoots0).parameters[1])))
533533
push!(parameter_attributes(llvmfn, 2), attr)
534534
for u in LLVM.uses(llvmfn)
535535
u = LLVM.user(u)
@@ -3907,7 +3907,7 @@ function lower_convention(
39073907
if !in(0, parmsRemoved)
39083908
sretPtr = alloca!(
39093909
builder,
3910-
eltype(value_type(parameters(entry_f)[1])),
3910+
sret_ty(entry_f, 1),
39113911
"innersret",
39123912
)
39133913
ctx = LLVM.context(entry_f)
@@ -3924,7 +3924,7 @@ function lower_convention(
39243924
if returnRoots && !in(1, parmsRemoved)
39253925
retRootPtr = alloca!(
39263926
builder,
3927-
eltype(value_type(parameters(entry_f)[1+sret])),
3927+
sret_ty(entry_f, 1+sret),
39283928
"innerreturnroots",
39293929
)
39303930
# retRootPtr = alloca!(builder, parameters(wrapper_f)[1])
@@ -3968,7 +3968,7 @@ function lower_convention(
39683968
if LLVM.addrspace(ty) != 0
39693969
ptr = addrspacecast!(builder, ptr, ty)
39703970
end
3971-
@assert eltype(ty) == value_type(wrapparm)
3971+
@assert elty == value_type(wrapparm)
39723972
store!(builder, wrapparm, ptr)
39733973
push!(wrapper_args, ptr)
39743974
push!(
@@ -4707,10 +4707,10 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
47074707
end
47084708
end
47094709

4710-
_, _, returnRoots = get_return_info(rt)
4711-
returnRoots = returnRoots !== nothing
4710+
_, _, returnRoots0 = get_return_info(rt)
4711+
returnRoots = returnRoots0 !== nothing
47124712
if returnRoots
4713-
attr = StringAttribute("enzymejl_returnRoots", "")
4713+
attr = StringAttribute("enzymejl_returnRoots", string(length(eltype(returnRoots0).parameters[1])))
47144714
push!(parameter_attributes(wrapper_f, 2), attr)
47154715
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(2), attr)
47164716
end
@@ -5880,7 +5880,10 @@ end
58805880
EnumAttribute("sret")
58815881
end
58825882
LLVM.API.LLVMAddCallSiteAttribute(r, LLVM.API.LLVMAttributeIndex(1), attr)
5883-
r = load!(builder, eltype(value_type(callparams[1])), callparams[1])
5883+
if !LLVM.is_opaque(value_type(callparams[1]))
5884+
@assert eltype(value_type(callparams[1])) == jltype
5885+
end
5886+
r = load!(builder, jltype, callparams[1])
58845887
end
58855888

58865889
if T_ret != T_void

src/llvm/transforms.jl

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,7 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module)
455455
LLVM.VoidType(),
456456
[LLVM.IntType(64), value_type(dst0)],
457457
)
458-
lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT)
458+
lifetimestart, _ = get_function!(mod, LLVM.name(LLVM.Intrinsic("llvm.lifetime.start"), [value_type(dst0)]), FT)
459459
call!(
460460
B,
461461
FT,
@@ -835,7 +835,7 @@ function nodecayed_phis!(mod::LLVM.Module)
835835
end
836836
nv, noffset, nhasload =
837837
getparent(b, operands(v)[1], offset, hasload, phicache)
838-
if eltype(value_type(nv)) != eltype(value_type(v))
838+
if !is_opaque(value_type(nv)) && eltype(value_type(nv)) != eltype(value_type(v))
839839
nv = bitcast!(
840840
b,
841841
nv,
@@ -1774,23 +1774,7 @@ function propagate_returned!(mod::LLVM.Module)
17741774
B = IRBuilder()
17751775
position!(B, first(instructions(first(blocks(fn)))))
17761776

1777-
# TODO try to get sret element type if possible
1778-
# note currently opaque pointers has this break [and we need to doa check if opaque
1779-
# and if so get inner piece]
1780-
1781-
if LLVM.is_opaque(value_type(arg))
1782-
msg = sprint() do io
1783-
println(io, "Needed element type of pointer to replace with intervening alloca\n")
1784-
println(io, "arg = ", string(arg))
1785-
println(io, "i = ", string(i))
1786-
println(io, "argn = ", string(argn))
1787-
println(io, "fn = ", string(fn))
1788-
end
1789-
throw(OpaquePointerError(msg))
1790-
end
1791-
1792-
argeltype = eltype(value_type(arg))
1793-
1777+
argeltype = sret_ty(fn, i)
17941778
al = alloca!(B, argeltype)
17951779
if value_type(al) != value_type(arg)
17961780
al = addrspacecast!(B, al, value_type(arg))

src/rules/customrules.jl

Lines changed: 31 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -336,7 +336,12 @@ function enzyme_custom_setup_args(
336336
LLVM.ConstantInt(LLVM.IntType(32), 0),
337337
],
338338
)
339-
if value_type(val) != eltype(value_type(ptr))
339+
340+
if !is_opaque(value_type(ptr))
341+
@assert eltype(value_type(ptr)) == arty
342+
end
343+
344+
if value_type(val) != arty
340345
if overwritten[end]
341346
bt = GPUCompiler.backtrace(orig)
342347
msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
@@ -347,6 +352,19 @@ function enzyme_custom_setup_args(
347352
"As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.\n"*msg2,
348353
)
349354
end
355+
356+
msg = sprint() do io
357+
print(io, "custom rule lower failure fwd\n")
358+
print(io, "Ty = $Ty\n")
359+
print(io, "llty = $llty\n")
360+
print(io, "arty = $arty\n")
361+
print(io, "al0 = $al0\n")
362+
print(io, "ptr = $ptr\n")
363+
print(io, "val = $val\n")
364+
print(io, "arg = $arg\n")
365+
end
366+
throw(OpaquePointerError(msg))
367+
350368
if arty == eltype(value_type(val))
351369
val = load!(B, arty, val)
352370
else
@@ -441,7 +459,12 @@ function enzyme_custom_setup_args(
441459
],
442460
)
443461
needsload = false
444-
if value_type(val) != eltype(value_type(ptr))
462+
463+
if !is_opaque(value_type(ptr))
464+
@assert eltype(value_type(ptr)) == arty
465+
end
466+
467+
if value_type(val) != arty
445468
val = load!(B, arty, val)
446469
if !mixed
447470
ptr_val = ival
@@ -719,13 +742,14 @@ end
719742
end
720743

721744
if sret !== nothing
745+
sty = sret_ty(llvmf, 1)
722746
if LLVM.version().major >= 12
723-
attr = TypeAttribute("sret", eltype(value_type(parameters(llvmf)[1])))
747+
attr = TypeAttribute("sret", sty)
724748
else
725749
attr = EnumAttribute("sret")
726750
end
727751
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1), attr)
728-
res = load!(B, eltype(value_type(parameters(llvmf)[1])), sret)
752+
res = load!(B, sty, sret)
729753
end
730754
if swiftself
731755
attr = EnumAttribute("swiftself")
@@ -1436,8 +1460,9 @@ function enzyme_custom_common_rev(
14361460
end
14371461

14381462
if sret !== nothing
1463+
sty = sret_ty(llvmf, 1+swiftself)
14391464
if LLVM.version().major >= 12
1440-
attr = TypeAttribute("sret", eltype(value_type(parameters(llvmf)[1+swiftself])))
1465+
attr = TypeAttribute("sret", sty)
14411466
else
14421467
attr = EnumAttribute("sret")
14431468
end
@@ -1446,7 +1471,7 @@ function enzyme_custom_common_rev(
14461471
LLVM.API.LLVMAttributeIndex(1 + swiftself),
14471472
attr,
14481473
)
1449-
res = load!(B, eltype(value_type(parameters(llvmf)[1+swiftself])), sret)
1474+
res = load!(B, sty, sret)
14501475
API.SetMustCache!(res)
14511476
end
14521477
if swiftself

src/utils.jl

Lines changed: 27 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -503,12 +503,34 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
503503

504504
vt = LLVM.value_type(LLVM.parameters(fn)[idx])
505505

506+
sretkind = LLVM.kind(if LLVM.version().major >= 12
507+
LLVM.TypeAttribute("sret", LLVM.Int32Type())
508+
else
509+
LLVM.EnumAttribute("sret")
510+
end)
506511

507-
for attr in collect(parameter_attributes(fn, idx))
512+
for attr in collect(LLVM.parameter_attributes(fn, idx))
508513
ekind = LLVM.kind(attr)
509-
510-
if ekind == "sret"
511-
return value(attr)
514+
515+
if ekind == sretkind
516+
res = LLVM.value(attr)
517+
if !LLVM.is_opaque(vt)
518+
@assert eltype(vt) == res
519+
end
520+
return res
521+
end
522+
523+
if ekind == "enzymejl_returnRoots"
524+
nroots = parse(Int, LLVM.value(attr))
525+
526+
T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
527+
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
528+
529+
res = LLVM.ArrayType(T_prjlvalue, nroots)
530+
if !LLVM.is_opaque(vt)
531+
@assert eltype(vt) == res
532+
end
533+
return res
512534
end
513535

514536
if ekind == "enzyme_sret" || ekind == "enzyme_sret_v"
@@ -526,7 +548,7 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
526548
end
527549
end
528550

529-
throw(AssertionError("Function requesting sret type was not an sret"))
551+
throw(AssertionError("Function requesting sret type was not an sret\nidx=$idx\nfn=$(string(fn))"))
530552
end
531553

532554
export sret_ty

0 commit comments

Comments
 (0)